diff --git a/cmd/thv/app/config.go b/cmd/thv/app/config.go index eefb73ec9..8f1dd2651 100644 --- a/cmd/thv/app/config.go +++ b/cmd/thv/app/config.go @@ -130,7 +130,8 @@ func setCACertCmdFunc(_ *cobra.Command, args []string) error { } func getCACertCmdFunc(_ *cobra.Command, _ []string) error { - cfg := config.GetConfig() + configProvider := config.NewDefaultProvider() + cfg := configProvider.GetConfig() if cfg.CACertificatePath == "" { fmt.Println("No CA certificate is currently configured.") @@ -148,7 +149,8 @@ func getCACertCmdFunc(_ *cobra.Command, _ []string) error { } func unsetCACertCmdFunc(_ *cobra.Command, _ []string) error { - cfg := config.GetConfig() + configProvider := config.NewDefaultProvider() + cfg := configProvider.GetConfig() if cfg.CACertificatePath == "" { fmt.Println("No CA certificate is currently configured.") @@ -171,9 +173,11 @@ func setRegistryCmdFunc(_ *cobra.Command, args []string) error { input := args[0] registryType, cleanPath := config.DetectRegistryType(input) + provider := config.NewDefaultProvider() + switch registryType { case config.RegistryTypeURL: - err := config.SetRegistryURL(cleanPath, allowPrivateRegistryIp) + err := provider.SetRegistryURL(cleanPath, allowPrivateRegistryIp) if err != nil { return err } @@ -188,14 +192,15 @@ func setRegistryCmdFunc(_ *cobra.Command, args []string) error { } return nil case config.RegistryTypeFile: - return config.SetRegistryFile(cleanPath) + return provider.SetRegistryFile(cleanPath) default: return fmt.Errorf("unsupported registry type") } } func getRegistryCmdFunc(_ *cobra.Command, _ []string) error { - url, localPath, _, registryType := config.GetRegistryConfig() + provider := config.NewDefaultProvider() + url, localPath, _, registryType := provider.GetRegistryConfig() switch registryType { case config.RegistryTypeURL: @@ -213,14 +218,15 @@ func getRegistryCmdFunc(_ *cobra.Command, _ []string) error { } func unsetRegistryCmdFunc(_ *cobra.Command, _ []string) error { - url, localPath, _, registryType := config.GetRegistryConfig() + provider := config.NewDefaultProvider() + url, localPath, _, registryType := provider.GetRegistryConfig() if registryType == "default" { fmt.Println("No custom registry is currently configured.") return nil } - err := config.UnsetRegistry() + err := provider.UnsetRegistry() if err != nil { return fmt.Errorf("failed to update configuration: %w", err) } diff --git a/cmd/thv/app/otel.go b/cmd/thv/app/otel.go index acf7eed23..f53e1b312 100644 --- a/cmd/thv/app/otel.go +++ b/cmd/thv/app/otel.go @@ -136,7 +136,8 @@ func setOtelEndpointCmdFunc(_ *cobra.Command, args []string) error { } func getOtelEndpointCmdFunc(_ *cobra.Command, _ []string) error { - cfg := config.GetConfig() + configProvider := config.NewDefaultProvider() + cfg := configProvider.GetConfig() if cfg.OTEL.Endpoint == "" { fmt.Println("No OpenTelemetry endpoint is currently configured.") @@ -148,7 +149,8 @@ func getOtelEndpointCmdFunc(_ *cobra.Command, _ []string) error { } func unsetOtelEndpointCmdFunc(_ *cobra.Command, _ []string) error { - cfg := config.GetConfig() + configProvider := config.NewDefaultProvider() + cfg := configProvider.GetConfig() if cfg.OTEL.Endpoint == "" { fmt.Println("No OpenTelemetry endpoint is currently configured.") @@ -191,7 +193,8 @@ func setOtelSamplingRateCmdFunc(_ *cobra.Command, args []string) error { } func getOtelSamplingRateCmdFunc(_ *cobra.Command, _ []string) error { - cfg := config.GetConfig() + configProvider := config.NewDefaultProvider() + cfg := configProvider.GetConfig() if cfg.OTEL.SamplingRate == 0.0 { fmt.Println("No OpenTelemetry sampling rate is currently configured.") @@ -203,7 +206,8 @@ func getOtelSamplingRateCmdFunc(_ *cobra.Command, _ []string) error { } func unsetOtelSamplingRateCmdFunc(_ *cobra.Command, _ []string) error { - cfg := config.GetConfig() + configProvider := config.NewDefaultProvider() + cfg := configProvider.GetConfig() if cfg.OTEL.SamplingRate == 0.0 { fmt.Println("No OpenTelemetry sampling rate is currently configured.") @@ -243,7 +247,8 @@ func setOtelEnvVarsCmdFunc(_ *cobra.Command, args []string) error { } func getOtelEnvVarsCmdFunc(_ *cobra.Command, _ []string) error { - cfg := config.GetConfig() + configProvider := config.NewDefaultProvider() + cfg := configProvider.GetConfig() if len(cfg.OTEL.EnvVars) == 0 { fmt.Println("No OpenTelemetry environment variables are currently configured.") @@ -255,7 +260,8 @@ func getOtelEnvVarsCmdFunc(_ *cobra.Command, _ []string) error { } func unsetOtelEnvVarsCmdFunc(_ *cobra.Command, _ []string) error { - cfg := config.GetConfig() + configProvider := config.NewDefaultProvider() + cfg := configProvider.GetConfig() if len(cfg.OTEL.EnvVars) == 0 { fmt.Println("No OpenTelemetry environment variables are currently configured.") diff --git a/cmd/thv/app/run_flags.go b/cmd/thv/app/run_flags.go index 6b128a87b..e7aa834e0 100644 --- a/cmd/thv/app/run_flags.go +++ b/cmd/thv/app/run_flags.go @@ -286,7 +286,8 @@ func setupOIDCConfiguration(cmd *cobra.Command, runFlags *RunFlags) (*auth.Token // setupTelemetryConfiguration sets up telemetry configuration with config fallbacks func setupTelemetryConfiguration(cmd *cobra.Command, runFlags *RunFlags) *telemetry.Config { - config := cfg.GetConfig() + configProvider := cfg.NewDefaultProvider() + config := configProvider.GetConfig() finalOtelEndpoint, finalOtelSamplingRate, finalOtelEnvironmentVariables := getTelemetryFromFlags(cmd, config, runFlags.OtelEndpoint, runFlags.OtelSamplingRate, runFlags.OtelEnvironmentVariables) @@ -306,7 +307,8 @@ func setupRuntimeAndValidation(ctx context.Context) (runtime.Deployer, runner.En if process.IsDetached() || runtime.IsKubernetesRuntime() { envVarValidator = &runner.DetachedEnvVarValidator{} } else { - envVarValidator = &runner.CLIEnvVarValidator{} + cfgProvider := cfg.NewDefaultProvider() + envVarValidator = runner.NewCLIEnvVarValidator(cfgProvider) } return rt, envVarValidator, nil diff --git a/cmd/thv/app/secret.go b/cmd/thv/app/secret.go index bb9e94b8a..05b0ae462 100644 --- a/cmd/thv/app/secret.go +++ b/cmd/thv/app/secret.go @@ -167,7 +167,9 @@ Note that some providers (like 1Password) are read-only and do not support setti // Check if the provider supports writing secrets if !manager.Capabilities().CanWrite { - providerType, _ := config.GetConfig().Secrets.GetProviderType() + configProvider := config.NewDefaultProvider() + cfg := configProvider.GetConfig() + providerType, _ := cfg.Secrets.GetProviderType() fmt.Fprintf(os.Stderr, "Error: The %s secrets provider does not support setting secrets (read-only)\n", providerType) return } @@ -250,7 +252,9 @@ If your provider is read-only or doesn't support deletion, this command returns // Check if the provider supports deleting secrets if !manager.Capabilities().CanDelete { - providerType, _ := config.GetConfig().Secrets.GetProviderType() + configProvider := config.NewDefaultProvider() + cfg := configProvider.GetConfig() + providerType, _ := cfg.Secrets.GetProviderType() fmt.Fprintf(os.Stderr, "Error: The %s secrets provider does not support deleting secrets\n", providerType) return } @@ -284,7 +288,9 @@ If descriptions exist for the secrets, the command displays them alongside the n // Check if the provider supports listing secrets if !manager.Capabilities().CanList { - providerType, _ := config.GetConfig().Secrets.GetProviderType() + configProvider := config.NewDefaultProvider() + cfg := configProvider.GetConfig() + providerType, _ := cfg.Secrets.GetProviderType() fmt.Fprintf(os.Stderr, "Error: The %s secrets provider does not support listing secrets\n", providerType) return } @@ -344,7 +350,8 @@ This command only works with the 'encrypted' secrets provider.`, } func getSecretsManager() (secrets.Provider, error) { - cfg := config.GetConfig() + configProvider := config.NewDefaultProvider() + cfg := configProvider.GetConfig() // Check if secrets setup has been completed if !cfg.Secrets.SetupCompleted { diff --git a/pkg/api/v1/groups_test.go b/pkg/api/v1/groups_test.go index 717f6cdf5..29ee5f895 100644 --- a/pkg/api/v1/groups_test.go +++ b/pkg/api/v1/groups_test.go @@ -277,17 +277,21 @@ func TestGroupsRouter_Integration(t *testing.T) { logger.Initialize() // Test with real managers (integration test) + // Use a test config provider to avoid modifying the real config file + configProvider, cleanup := CreateTestConfigProvider(t, nil) + t.Cleanup(cleanup) + groupManager, err := groups.NewManager() if err != nil { t.Skip("Skipping integration test: failed to create group manager") } - workloadManager, err := workloads.NewManager(context.Background()) + workloadManager, err := workloads.NewManagerWithProvider(context.Background(), configProvider) if err != nil { t.Skip("Skipping integration test: failed to create workload manager") } - clientManager, err := client.NewManager(context.Background()) + clientManager, err := client.NewManagerWithProvider(context.Background(), configProvider) if err != nil { t.Skip("Skipping integration test: failed to create client manager") } diff --git a/pkg/api/v1/healtcheck_test.go b/pkg/api/v1/healtcheck_test.go index 11af48616..4b08cedbb 100644 --- a/pkg/api/v1/healtcheck_test.go +++ b/pkg/api/v1/healtcheck_test.go @@ -15,20 +15,19 @@ import ( func TestGetHealthcheck(t *testing.T) { t.Parallel() - // Create a new gomock controller - ctrl := gomock.NewController(t) - t.Cleanup(func() { - ctrl.Finish() - }) - - // Create a mock runtime - mockRuntime := mocks.NewMockRuntime(ctrl) - - // Create healthcheck routes with the mock runtime - routes := &healthcheckRoutes{containerRuntime: mockRuntime} - t.Run("returns 204 when runtime is running", func(t *testing.T) { t.Parallel() + // Create a new gomock controller for this subtest + ctrl := gomock.NewController(t) + t.Cleanup(func() { + ctrl.Finish() + }) + + // Create a mock runtime + mockRuntime := mocks.NewMockRuntime(ctrl) + + // Create healthcheck routes with the mock runtime + routes := &healthcheckRoutes{containerRuntime: mockRuntime} // Setup mock to return nil (no error) when IsRunning is called mockRuntime.EXPECT(). @@ -49,6 +48,17 @@ func TestGetHealthcheck(t *testing.T) { t.Run("returns 503 when runtime is not running", func(t *testing.T) { t.Parallel() + // Create a new gomock controller for this subtest + ctrl := gomock.NewController(t) + t.Cleanup(func() { + ctrl.Finish() + }) + + // Create a mock runtime + mockRuntime := mocks.NewMockRuntime(ctrl) + + // Create healthcheck routes with the mock runtime + routes := &healthcheckRoutes{containerRuntime: mockRuntime} // Create an error to return expectedError := errors.New("container runtime is not available") diff --git a/pkg/api/v1/registry.go b/pkg/api/v1/registry.go index 4f3425318..85a1f2dee 100644 --- a/pkg/api/v1/registry.go +++ b/pkg/api/v1/registry.go @@ -30,8 +30,8 @@ const ( ) // getRegistryInfo returns the registry type and the source -func getRegistryInfo() (RegistryType, string) { - return getRegistryInfoWithProvider(config.NewDefaultProvider()) +func (rr *RegistryRoutes) getRegistryInfo() (RegistryType, string) { + return getRegistryInfoWithProvider(rr.configProvider) } // getRegistryInfoWithProvider returns the registry type and the source using the provided config provider @@ -50,9 +50,9 @@ func getRegistryInfoWithProvider(configProvider config.Provider) (RegistryType, return RegistryTypeDefault, "" } -// getCurrentProvider returns the current registry provider -func getCurrentProvider(w http.ResponseWriter) (registry.Provider, bool) { - provider, err := registry.GetDefaultProvider() +// getCurrentProvider returns the current registry provider using the injected config +func (rr *RegistryRoutes) getCurrentProvider(w http.ResponseWriter) (registry.Provider, bool) { + provider, err := registry.GetDefaultProviderWithConfig(rr.configProvider) if err != nil { http.Error(w, "Failed to get registry provider", http.StatusInternalServerError) logger.Errorf("Failed to get registry provider: %v", err) @@ -62,11 +62,28 @@ func getCurrentProvider(w http.ResponseWriter) (registry.Provider, bool) { } // RegistryRoutes defines the routes for the registry API. -type RegistryRoutes struct{} +type RegistryRoutes struct { + configProvider config.Provider +} + +// NewRegistryRoutes creates a new RegistryRoutes with the default config provider +func NewRegistryRoutes() *RegistryRoutes { + return &RegistryRoutes{ + configProvider: config.NewDefaultProvider(), + } +} + +// NewRegistryRoutesWithProvider creates a new RegistryRoutes with a custom config provider +// This is useful for testing +func NewRegistryRoutesWithProvider(provider config.Provider) *RegistryRoutes { + return &RegistryRoutes{ + configProvider: provider, + } +} // RegistryRouter creates a new router for the registry API. func RegistryRouter() http.Handler { - routes := RegistryRoutes{} + routes := NewRegistryRoutes() r := chi.NewRouter() r.Get("/", routes.listRegistries) @@ -91,8 +108,8 @@ func RegistryRouter() http.Handler { // @Produce json // @Success 200 {object} registryListResponse // @Router /api/v1beta/registry [get] -func (*RegistryRoutes) listRegistries(w http.ResponseWriter, _ *http.Request) { - provider, ok := getCurrentProvider(w) +func (rr *RegistryRoutes) listRegistries(w http.ResponseWriter, _ *http.Request) { + provider, ok := rr.getCurrentProvider(w) if !ok { return } @@ -103,7 +120,7 @@ func (*RegistryRoutes) listRegistries(w http.ResponseWriter, _ *http.Request) { return } - registryType, source := getRegistryInfo() + registryType, source := rr.getRegistryInfo() registries := []registryInfo{ { @@ -149,7 +166,7 @@ func (*RegistryRoutes) addRegistry(w http.ResponseWriter, _ *http.Request) { // @Success 200 {object} getRegistryResponse // @Failure 404 {string} string "Not Found" // @Router /api/v1beta/registry/{name} [get] -func (*RegistryRoutes) getRegistry(w http.ResponseWriter, r *http.Request) { +func (rr *RegistryRoutes) getRegistry(w http.ResponseWriter, r *http.Request) { name := chi.URLParam(r, "name") // Only "default" registry is supported currently @@ -158,7 +175,7 @@ func (*RegistryRoutes) getRegistry(w http.ResponseWriter, r *http.Request) { return } - provider, ok := getCurrentProvider(w) + provider, ok := rr.getCurrentProvider(w) if !ok { return } @@ -169,7 +186,7 @@ func (*RegistryRoutes) getRegistry(w http.ResponseWriter, r *http.Request) { return } - registryType, source := getRegistryInfo() + registryType, source := rr.getRegistryInfo() response := getRegistryResponse{ Name: defaultRegistryName, @@ -202,7 +219,7 @@ func (*RegistryRoutes) getRegistry(w http.ResponseWriter, r *http.Request) { // @Failure 400 {string} string "Bad Request" // @Failure 404 {string} string "Not Found" // @Router /api/v1beta/registry/{name} [put] -func (*RegistryRoutes) updateRegistry(w http.ResponseWriter, r *http.Request) { +func (rr *RegistryRoutes) updateRegistry(w http.ResponseWriter, r *http.Request) { name := chi.URLParam(r, "name") // Only "default" registry can be updated currently @@ -228,7 +245,9 @@ func (*RegistryRoutes) updateRegistry(w http.ResponseWriter, r *http.Request) { // Handle reset to default (no URL or LocalPath specified) if req.URL == nil && req.LocalPath == nil { - if err := config.UnsetRegistry(); err != nil { + // Use the config provider to unset the registry + provider := rr.configProvider + if err := provider.UnsetRegistry(); err != nil { logger.Errorf("Failed to unset registry: %v", err) http.Error(w, "Failed to reset registry configuration", http.StatusInternalServerError) return @@ -242,7 +261,8 @@ func (*RegistryRoutes) updateRegistry(w http.ResponseWriter, r *http.Request) { allowPrivateIP = *req.AllowPrivateIP } - if err := config.SetRegistryURL(*req.URL, allowPrivateIP); err != nil { + // Use the config provider to update the registry URL + if err := rr.configProvider.SetRegistryURL(*req.URL, allowPrivateIP); err != nil { logger.Errorf("Failed to set registry URL: %v", err) http.Error(w, fmt.Sprintf("Failed to set registry URL: %v", err), http.StatusBadRequest) return @@ -251,7 +271,10 @@ func (*RegistryRoutes) updateRegistry(w http.ResponseWriter, r *http.Request) { message = fmt.Sprintf("Successfully set registry URL: %s", *req.URL) } else if req.LocalPath != nil { // Handle local path update - if err := config.SetRegistryFile(*req.LocalPath); err != nil { + // Use the config provider to update the registry file + provider := rr.configProvider + + if err := provider.SetRegistryFile(*req.LocalPath); err != nil { logger.Errorf("Failed to set registry file: %v", err) http.Error(w, fmt.Sprintf("Failed to set registry file: %v", err), http.StatusBadRequest) return @@ -309,7 +332,7 @@ func (*RegistryRoutes) removeRegistry(w http.ResponseWriter, r *http.Request) { // @Success 200 {object} listServersResponse // @Failure 404 {string} string "Not Found" // @Router /api/v1beta/registry/{name}/servers [get] -func (*RegistryRoutes) listServers(w http.ResponseWriter, r *http.Request) { +func (rr *RegistryRoutes) listServers(w http.ResponseWriter, r *http.Request) { registryName := chi.URLParam(r, "name") // Only "default" registry is supported currently @@ -318,7 +341,7 @@ func (*RegistryRoutes) listServers(w http.ResponseWriter, r *http.Request) { return } - provider, ok := getCurrentProvider(w) + provider, ok := rr.getCurrentProvider(w) if !ok { return } @@ -366,7 +389,7 @@ func (*RegistryRoutes) listServers(w http.ResponseWriter, r *http.Request) { // @Success 200 {object} getServerResponse // @Failure 404 {string} string "Not Found" // @Router /api/v1beta/registry/{name}/servers/{serverName} [get] -func (*RegistryRoutes) getServer(w http.ResponseWriter, r *http.Request) { +func (rr *RegistryRoutes) getServer(w http.ResponseWriter, r *http.Request) { registryName := chi.URLParam(r, "name") serverName := chi.URLParam(r, "serverName") @@ -376,7 +399,7 @@ func (*RegistryRoutes) getServer(w http.ResponseWriter, r *http.Request) { return } - provider, ok := getCurrentProvider(w) + provider, ok := rr.getCurrentProvider(w) if !ok { return } diff --git a/pkg/api/v1/registry_test.go b/pkg/api/v1/registry_test.go index ff6ae62db..3b0b9ce86 100644 --- a/pkg/api/v1/registry_test.go +++ b/pkg/api/v1/registry_test.go @@ -51,8 +51,10 @@ func TestRegistryRouter(t *testing.T) { logger.Initialize() - router := RegistryRouter() - assert.NotNil(t, router) + // Create a test config provider to avoid using the singleton + provider, _ := CreateTestConfigProvider(t, nil) + routes := NewRegistryRoutesWithProvider(provider) + assert.NotNil(t, routes) } //nolint:paralleltest // Cannot use t.Parallel() with t.Setenv() in Go 1.24+ @@ -114,8 +116,6 @@ func TestRegistryAPI_PutEndpoint(t *testing.T) { logger.Initialize() - routes := &RegistryRoutes{} - tests := []struct { name string requestBody string @@ -152,6 +152,20 @@ func TestRegistryAPI_PutEndpoint(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() + // Create a temporary config for this test + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "toolhive", "config.yaml") + + // Ensure the directory exists + err := os.MkdirAll(filepath.Dir(configPath), 0755) + require.NoError(t, err) + + // Create a test config provider + configProvider := config.NewPathProvider(configPath) + + // Create routes with the test config provider + routes := NewRegistryRoutesWithProvider(configProvider) + req := httptest.NewRequest("PUT", "/default", strings.NewReader(tt.requestBody)) req.Header.Set("Content-Type", "application/json") rctx := chi.NewRouteContext() diff --git a/pkg/api/v1/secrets.go b/pkg/api/v1/secrets.go index b491f52cc..a19f34fa9 100644 --- a/pkg/api/v1/secrets.go +++ b/pkg/api/v1/secrets.go @@ -20,12 +20,31 @@ const ( ) // SecretsRoutes defines the routes for the secrets API. -type SecretsRoutes struct{} +type SecretsRoutes struct { + configProvider config.Provider +} + +// NewSecretsRoutes creates a new SecretsRoutes with the default config provider +func NewSecretsRoutes() *SecretsRoutes { + return &SecretsRoutes{ + configProvider: config.NewDefaultProvider(), + } +} + +// NewSecretsRoutesWithProvider creates a new SecretsRoutes with a custom config provider +func NewSecretsRoutesWithProvider(provider config.Provider) *SecretsRoutes { + return &SecretsRoutes{ + configProvider: provider, + } +} // SecretsRouter creates a new router for the secrets API. func SecretsRouter() http.Handler { - routes := SecretsRoutes{} + routes := NewSecretsRoutes() + return secretsRouterWithRoutes(routes) +} +func secretsRouterWithRoutes(routes *SecretsRoutes) http.Handler { r := chi.NewRouter() // Setup secrets provider @@ -59,7 +78,7 @@ func SecretsRouter() http.Handler { // @Failure 400 {string} string "Bad Request" // @Failure 500 {string} string "Internal Server Error" // @Router /api/v1beta/secrets [post] -func (*SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Request) { +func (s *SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Request) { var req setupSecretsRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { logger.Errorf("Failed to decode request body: %v", err) @@ -87,7 +106,7 @@ func (*SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Reques } // Check current secrets provider configuration for appropriate messaging - cfg := config.GetConfig() + cfg := s.configProvider.GetConfig() isReconfiguration := false isInitialSetup := !cfg.Secrets.SetupCompleted if cfg.Secrets.SetupCompleted { @@ -157,7 +176,7 @@ func (*SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Reques } // Update the secrets provider type and mark setup as completed - err := config.UpdateConfig(func(c *config.Config) { + err := s.configProvider.UpdateConfig(func(c *config.Config) { c.Secrets.ProviderType = string(providerType) c.Secrets.SetupCompleted = true }) @@ -199,7 +218,7 @@ func (*SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Reques // @Failure 500 {string} string "Internal Server Error" // @Router /api/v1beta/secrets/default [get] func (s *SecretsRoutes) getSecretsProvider(w http.ResponseWriter, _ *http.Request) { - cfg := config.GetConfig() + cfg := s.configProvider.GetConfig() // Check if secrets provider is setup if !cfg.Secrets.SetupCompleted { @@ -500,8 +519,8 @@ func (s *SecretsRoutes) deleteSecret(w http.ResponseWriter, r *http.Request) { } // getSecretsManager is a helper function to get the secrets manager -func (*SecretsRoutes) getSecretsManager() (secrets.Provider, error) { - cfg := config.GetConfig() +func (s *SecretsRoutes) getSecretsManager() (secrets.Provider, error) { + cfg := s.configProvider.GetConfig() // Check if secrets setup has been completed if !cfg.Secrets.SetupCompleted { diff --git a/pkg/api/v1/secrets_test.go b/pkg/api/v1/secrets_test.go index afc51de24..851e6ab15 100644 --- a/pkg/api/v1/secrets_test.go +++ b/pkg/api/v1/secrets_test.go @@ -6,6 +6,8 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "os" + "path/filepath" "strings" "testing" @@ -13,13 +15,21 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/stacklok/toolhive/pkg/config" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/secrets" ) func TestSecretsRouter(t *testing.T) { t.Parallel() - router := SecretsRouter() + + // Create a test config provider to avoid using the singleton + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.yaml") + provider := config.NewPathProvider(configPath) + + routes := NewSecretsRoutesWithProvider(provider) + router := secretsRouterWithRoutes(routes) assert.NotNil(t, router) } @@ -46,6 +56,17 @@ func TestSetupSecretsProvider_ValidRequests(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() + // Create a temporary config directory for this test + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "toolhive", "config.yaml") + + // Ensure the directory exists + err := os.MkdirAll(filepath.Dir(configPath), 0755) + require.NoError(t, err) + + // Create a test config provider + configProvider := config.NewPathProvider(configPath) + body, err := json.Marshal(tt.requestBody) require.NoError(t, err) @@ -53,7 +74,7 @@ func TestSetupSecretsProvider_ValidRequests(t *testing.T) { req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() - routes := &SecretsRoutes{} + routes := NewSecretsRoutesWithProvider(configProvider) routes.setupSecretsProvider(w, req) assert.Equal(t, tt.expectedCode, w.Code) @@ -101,8 +122,18 @@ func TestSetupSecretsProvider_InvalidRequests(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() + // Create a temporary config directory for this test + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "toolhive", "config.yaml") + + // Ensure the directory exists + err := os.MkdirAll(filepath.Dir(configPath), 0755) + require.NoError(t, err) + + // Create a test config provider + configProvider := config.NewPathProvider(configPath) + var body []byte - var err error if str, ok := tt.requestBody.(string); ok { body = []byte(str) } else { @@ -114,7 +145,7 @@ func TestSetupSecretsProvider_InvalidRequests(t *testing.T) { req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() - routes := &SecretsRoutes{} + routes := NewSecretsRoutesWithProvider(configProvider) routes.setupSecretsProvider(w, req) assert.Equal(t, tt.expectedCode, w.Code) @@ -164,8 +195,18 @@ func TestCreateSecret_InvalidRequests(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() + // Create a temporary config directory for this test + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "toolhive", "config.yaml") + + // Ensure the directory exists + err := os.MkdirAll(filepath.Dir(configPath), 0755) + require.NoError(t, err) + + // Create a test config provider + configProvider := config.NewPathProvider(configPath) + var body []byte - var err error if str, ok := tt.requestBody.(string); ok { body = []byte(str) } else { @@ -177,7 +218,7 @@ func TestCreateSecret_InvalidRequests(t *testing.T) { req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() - routes := &SecretsRoutes{} + routes := NewSecretsRoutesWithProvider(configProvider) routes.createSecret(w, req) assert.Equal(t, tt.expectedCode, w.Code) @@ -229,8 +270,18 @@ func TestUpdateSecret_InvalidRequests(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() + // Create a temporary config directory for this test + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "toolhive", "config.yaml") + + // Ensure the directory exists + err := os.MkdirAll(filepath.Dir(configPath), 0755) + require.NoError(t, err) + + // Create a test config provider + configProvider := config.NewPathProvider(configPath) + var body []byte - var err error if str, ok := tt.requestBody.(string); ok { body = []byte(str) } else { @@ -249,7 +300,7 @@ func TestUpdateSecret_InvalidRequests(t *testing.T) { w := httptest.NewRecorder() - routes := &SecretsRoutes{} + routes := NewSecretsRoutesWithProvider(configProvider) routes.updateSecret(w, req) assert.Equal(t, tt.expectedCode, w.Code) @@ -281,6 +332,17 @@ func TestDeleteSecret_InvalidRequests(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() + // Create a temporary config directory for this test + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "toolhive", "config.yaml") + + // Ensure the directory exists + err := os.MkdirAll(filepath.Dir(configPath), 0755) + require.NoError(t, err) + + // Create a test config provider + configProvider := config.NewPathProvider(configPath) + url := "/default/keys/" + tt.secretKey req := httptest.NewRequest(http.MethodDelete, url, nil) @@ -291,7 +353,7 @@ func TestDeleteSecret_InvalidRequests(t *testing.T) { w := httptest.NewRecorder() - routes := &SecretsRoutes{} + routes := NewSecretsRoutesWithProvider(configProvider) routes.deleteSecret(w, req) assert.Equal(t, tt.expectedCode, w.Code) @@ -405,12 +467,24 @@ func TestErrorHandling(t *testing.T) { t.Run("malformed json request", func(t *testing.T) { t.Parallel() + + // Create a temporary config directory for this test + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "toolhive", "config.yaml") + + // Ensure the directory exists + err := os.MkdirAll(filepath.Dir(configPath), 0755) + require.NoError(t, err) + + // Create a test config provider + configProvider := config.NewPathProvider(configPath) + malformedJSON := `{"provider_type": "encrypted", "invalid": json}` req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(malformedJSON)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() - routes := &SecretsRoutes{} + routes := NewSecretsRoutesWithProvider(configProvider) routes.setupSecretsProvider(w, req) assert.Equal(t, http.StatusBadRequest, w.Code) @@ -419,11 +493,23 @@ func TestErrorHandling(t *testing.T) { t.Run("empty request body", func(t *testing.T) { t.Parallel() + + // Create a temporary config directory for this test + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "toolhive", "config.yaml") + + // Ensure the directory exists + err := os.MkdirAll(filepath.Dir(configPath), 0755) + require.NoError(t, err) + + // Create a test config provider + configProvider := config.NewPathProvider(configPath) + req := httptest.NewRequest(http.MethodPost, "/default/keys", strings.NewReader("")) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() - routes := &SecretsRoutes{} + routes := NewSecretsRoutesWithProvider(configProvider) routes.createSecret(w, req) assert.Equal(t, http.StatusBadRequest, w.Code) @@ -431,11 +517,23 @@ func TestErrorHandling(t *testing.T) { t.Run("missing content type header", func(t *testing.T) { t.Parallel() + + // Create a temporary config directory for this test + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "toolhive", "config.yaml") + + // Ensure the directory exists + err := os.MkdirAll(filepath.Dir(configPath), 0755) + require.NoError(t, err) + + // Create a test config provider + configProvider := config.NewPathProvider(configPath) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{"provider_type": "none"}`)) // Deliberately not setting Content-Type header w := httptest.NewRecorder() - routes := &SecretsRoutes{} + routes := NewSecretsRoutesWithProvider(configProvider) routes.setupSecretsProvider(w, req) // Should still work as the handler doesn't strictly require content-type @@ -449,7 +547,20 @@ func TestRouterIntegration(t *testing.T) { t.Run("router setup test", func(t *testing.T) { t.Parallel() - router := SecretsRouter() + + // Create a temporary config directory for this test + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "toolhive", "config.yaml") + + // Ensure the directory exists + err := os.MkdirAll(filepath.Dir(configPath), 0755) + require.NoError(t, err) + + // Create a test config provider + configProvider := config.NewPathProvider(configPath) + + routes := NewSecretsRoutesWithProvider(configProvider) + router := secretsRouterWithRoutes(routes) // Test POST / endpoint setupReq := setupSecretsRequest{ diff --git a/pkg/client/config_test.go b/pkg/client/config_test.go index 104922808..6786d12e0 100644 --- a/pkg/client/config_test.go +++ b/pkg/client/config_test.go @@ -322,7 +322,7 @@ func TestSuccessfulClientConfigOperations(t *testing.T) { foundTypes[cf.ClientType] = true } - for _, expectedClient := range supportedClientIntegrations { + for _, expectedClient := range mockClientConfigs { assert.True(t, foundTypes[expectedClient.ClientType], "Should find config for client type %s", expectedClient.ClientType) } @@ -406,7 +406,8 @@ func TestSuccessfulClientConfigOperations(t *testing.T) { testURL := "http://localhost:9999/sse#test-server" for _, cf := range configs { - err := Upsert(cf, testServer, testURL, types.TransportTypeSSE.String()) + // Use the manager's Upsert method instead of the global function to avoid using the singleton config + err := manager.Upsert(cf, testServer, testURL, types.TransportTypeSSE.String()) require.NoError(t, err, "Should be able to add MCP server to %s config", cf.ClientType) // Read the file and verify the server was added diff --git a/pkg/client/discovery_test.go b/pkg/client/discovery_test.go index 96ad22cdd..95a214432 100644 --- a/pkg/client/discovery_test.go +++ b/pkg/client/discovery_test.go @@ -8,9 +8,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" "github.com/stacklok/toolhive/pkg/config" "github.com/stacklok/toolhive/pkg/groups" + "github.com/stacklok/toolhive/pkg/groups/mocks" "github.com/stacklok/toolhive/pkg/logger" ) @@ -201,27 +203,22 @@ func TestGetClientStatus_WithGroups(t *testing.T) { _, err = os.Create(filepath.Join(tempHome, ".claude.json")) require.NoError(t, err) - // Create a real groups manager for testing - groupManager, err := groups.NewManager() - require.NoError(t, err) + // Create a mock groups manager instead of a real one to avoid modifying host configuration + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockGroupManager := mocks.NewMockManager(ctrl) - // Clean up any existing test groups + // Set up mock expectations ctx := context.Background() - existingGroups, _ := groupManager.List(ctx) - for _, group := range existingGroups { - if group.Name == "test-dev-group" || group.Name == "test-prod-group" { - groupManager.Delete(ctx, group.Name) - } + mockGroups := []*groups.Group{ + { + Name: "test-dev-group", + RegisteredClients: []string{string(ClaudeCode), string(Cursor)}, + }, } - // Create test groups with registered clients - err = groupManager.Create(ctx, "test-dev-group") - require.NoError(t, err) - defer groupManager.Delete(ctx, "test-dev-group") - - // Register clients with groups - err = groupManager.RegisterClients(ctx, []string{"test-dev-group"}, []string{string(ClaudeCode), string(Cursor)}) - require.NoError(t, err) + mockGroupManager.EXPECT().List(ctx).Return(mockGroups, nil).AnyTimes() // Now test GetClientStatus using ClientManager with dependency injection // Use explicit client integrations for this test to avoid race conditions with global variable @@ -242,7 +239,16 @@ func TestGetClientStatus_WithGroups(t *testing.T) { }, } - manager := NewTestClientManager(tempHome, groupManager, clientIntegrations, config.NewDefaultProvider()) + // Create a test config provider instead of using the default one + testConfig := &config.Config{ + Clients: config.Clients{ + RegisteredClients: []string{}, // Empty to test group-based registration + }, + } + configProvider, cleanup := createTestConfigProvider(t, testConfig) + defer cleanup() + + manager := NewTestClientManager(tempHome, mockGroupManager, clientIntegrations, configProvider) statuses, err := manager.GetClientStatus(ctx) require.NoError(t, err) require.NotNil(t, statuses) diff --git a/pkg/client/manager.go b/pkg/client/manager.go index 79802f2ff..0d11bdaad 100644 --- a/pkg/client/manager.go +++ b/pkg/client/manager.go @@ -45,8 +45,9 @@ type Manager interface { } type defaultManager struct { - runtime rt.Runtime - groupManager groups.Manager + runtime rt.Runtime + groupManager groups.Manager + configProvider config.Provider } // NewManager creates a new client manager instance. @@ -62,13 +63,40 @@ func NewManager(ctx context.Context) (Manager, error) { } return &defaultManager{ - runtime: runtime, - groupManager: groupManager, + runtime: runtime, + groupManager: groupManager, + configProvider: config.NewDefaultProvider(), }, nil } +// NewManagerWithProvider creates a new client manager instance with a custom config provider. +// This is useful for testing to avoid using the singleton config. +func NewManagerWithProvider(ctx context.Context, configProvider config.Provider) (Manager, error) { + runtime, err := ct.NewFactory().Create(ctx) + if err != nil { + return nil, err + } + + groupManager, err := groups.NewManager() + if err != nil { + return nil, err + } + + return &defaultManager{ + runtime: runtime, + groupManager: groupManager, + configProvider: configProvider, + }, nil +} + +// SetConfigProvider sets a custom config provider for testing purposes. +// This allows tests to inject a test config provider to avoid modifying the real config file. +func (m *defaultManager) SetConfigProvider(provider config.Provider) { + m.configProvider = provider +} + func (m *defaultManager) ListClients(ctx context.Context) ([]RegisteredClient, error) { - cfg := config.GetConfig() + cfg := m.configProvider.GetConfig() // Get all groups allGroups, err := m.groupManager.List(ctx) @@ -288,7 +316,7 @@ func (m *defaultManager) getTargetClients(ctx context.Context, serverName, group } // Server has no group - use backward compatible behavior (update all registered clients) - appConfig := config.GetConfig() + appConfig := m.configProvider.GetConfig() targetClients := appConfig.Clients.RegisteredClients logger.Infof( "Server %s has no group, updating %d globally registered client(s) for backward compatibility", diff --git a/pkg/client/migration.go b/pkg/client/migration.go index c90929a30..45684b608 100644 --- a/pkg/client/migration.go +++ b/pkg/client/migration.go @@ -16,7 +16,8 @@ var migrationOnce sync.Once // This is called once at application startup func CheckAndPerformAutoDiscoveryMigration() { migrationOnce.Do(func() { - appConfig := config.GetConfig() + cfgprv := config.NewDefaultProvider() + appConfig := cfgprv.GetConfig() // Check if auto-discovery flag is set to true, use of deprecated object is expected here if appConfig.Clients.AutoDiscovery { @@ -43,7 +44,8 @@ func performAutoDiscoveryMigration() { } // Get current config to see what's already registered - appConfig := config.GetConfig() + cfgprv := config.NewDefaultProvider() + appConfig := cfgprv.GetConfig() var clientsToRegister []string var alreadyRegistered = appConfig.Clients.RegisteredClients diff --git a/pkg/config/config.go b/pkg/config/config.go index 00f7fb4b1..11dcf5255 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -174,9 +174,9 @@ func LoadOrCreateConfigWithPath(configPath string) (*Config, error) { // Create a new config with default values. config = createNewConfigWithDefaults() - // Persist the new default to disk. + // Persist the new default to disk using the specific path logger.Debugf("initializing configuration file at %s", configPath) - err = config.save() + err = config.saveToPath(configPath) if err != nil { return nil, fmt.Errorf("failed to write default config: %w", err) } diff --git a/pkg/config/interface.go b/pkg/config/interface.go index 7d63c5593..ba3a3c466 100644 --- a/pkg/config/interface.go +++ b/pkg/config/interface.go @@ -5,6 +5,12 @@ type Provider interface { GetConfig() *Config UpdateConfig(updateFn func(*Config)) error LoadOrCreateConfig() (*Config, error) + + // Registry operations + SetRegistryURL(registryURL string, allowPrivateRegistryIp bool) error + SetRegistryFile(registryPath string) error + UnsetRegistry() error + GetRegistryConfig() (url, localPath string, allowPrivateIP bool, registryType string) } // DefaultProvider implements Provider using the default XDG config path @@ -17,7 +23,7 @@ func NewDefaultProvider() *DefaultProvider { // GetConfig returns the singleton config (for backward compatibility) func (*DefaultProvider) GetConfig() *Config { - return GetConfig() + return getSingletonConfig() } // UpdateConfig updates the config using the default path @@ -30,6 +36,26 @@ func (*DefaultProvider) LoadOrCreateConfig() (*Config, error) { return LoadOrCreateConfig() } +// SetRegistryURL validates and sets a registry URL +func (d *DefaultProvider) SetRegistryURL(registryURL string, allowPrivateRegistryIp bool) error { + return setRegistryURL(d, registryURL, allowPrivateRegistryIp) +} + +// SetRegistryFile validates and sets a local registry file +func (d *DefaultProvider) SetRegistryFile(registryPath string) error { + return setRegistryFile(d, registryPath) +} + +// UnsetRegistry resets registry configuration to defaults +func (d *DefaultProvider) UnsetRegistry() error { + return unsetRegistry(d) +} + +// GetRegistryConfig returns current registry configuration +func (d *DefaultProvider) GetRegistryConfig() (url, localPath string, allowPrivateIP bool, registryType string) { + return getRegistryConfig(d) +} + // PathProvider implements Provider using a specific config path type PathProvider struct { configPath string @@ -60,3 +86,23 @@ func (p *PathProvider) UpdateConfig(updateFn func(*Config)) error { func (p *PathProvider) LoadOrCreateConfig() (*Config, error) { return LoadOrCreateConfigWithPath(p.configPath) } + +// SetRegistryURL validates and sets a registry URL +func (p *PathProvider) SetRegistryURL(registryURL string, allowPrivateRegistryIp bool) error { + return setRegistryURL(p, registryURL, allowPrivateRegistryIp) +} + +// SetRegistryFile validates and sets a local registry file +func (p *PathProvider) SetRegistryFile(registryPath string) error { + return setRegistryFile(p, registryPath) +} + +// UnsetRegistry resets registry configuration to defaults +func (p *PathProvider) UnsetRegistry() error { + return unsetRegistry(p) +} + +// GetRegistryConfig returns current registry configuration +func (p *PathProvider) GetRegistryConfig() (url, localPath string, allowPrivateIP bool, registryType string) { + return getRegistryConfig(p) +} diff --git a/pkg/config/registry.go b/pkg/config/registry.go index cec377b65..ea6502a41 100644 --- a/pkg/config/registry.go +++ b/pkg/config/registry.go @@ -34,8 +34,8 @@ func DetectRegistryType(input string) (registryType string, cleanPath string) { return RegistryTypeFile, filepath.Clean(input) } -// SetRegistryURL validates and sets a registry URL -func SetRegistryURL(registryURL string, allowPrivateRegistryIp bool) error { +// setRegistryURL validates and sets a registry URL using the provided provider +func setRegistryURL(provider Provider, registryURL string, allowPrivateRegistryIp bool) error { parsedURL, err := neturl.Parse(registryURL) if err != nil { return fmt.Errorf("invalid registry URL: %w", err) @@ -65,7 +65,7 @@ func SetRegistryURL(registryURL string, allowPrivateRegistryIp bool) error { } // Update the configuration - err = UpdateConfig(func(c *Config) { + err = provider.UpdateConfig(func(c *Config) { c.RegistryUrl = registryURL c.LocalRegistryPath = "" // Clear local path when setting URL c.AllowPrivateRegistryIp = allowPrivateRegistryIp @@ -77,8 +77,8 @@ func SetRegistryURL(registryURL string, allowPrivateRegistryIp bool) error { return nil } -// SetRegistryFile validates and sets a local registry file -func SetRegistryFile(registryPath string) error { +// setRegistryFile validates and sets a local registry file using the provided provider +func setRegistryFile(provider Provider, registryPath string) error { // Validate that the file exists and is readable if _, err := os.Stat(registryPath); err != nil { return fmt.Errorf("local registry file not found or not accessible: %w", err) @@ -109,7 +109,7 @@ func SetRegistryFile(registryPath string) error { } // Update the configuration - err = UpdateConfig(func(c *Config) { + err = provider.UpdateConfig(func(c *Config) { c.LocalRegistryPath = absPath c.RegistryUrl = "" // Clear URL when setting local path }) @@ -120,9 +120,9 @@ func SetRegistryFile(registryPath string) error { return nil } -// UnsetRegistry resets registry configuration to defaults -func UnsetRegistry() error { - err := UpdateConfig(func(c *Config) { +// unsetRegistry resets registry configuration to defaults using the provided provider +func unsetRegistry(provider Provider) error { + err := provider.UpdateConfig(func(c *Config) { c.RegistryUrl = "" c.LocalRegistryPath = "" c.AllowPrivateRegistryIp = false @@ -133,9 +133,9 @@ func UnsetRegistry() error { return nil } -// GetRegistryConfig returns current registry configuration -func GetRegistryConfig() (url, localPath string, allowPrivateIP bool, registryType string) { - cfg := GetConfig() +// getRegistryConfig returns current registry configuration using the provided provider +func getRegistryConfig(provider Provider) (url, localPath string, allowPrivateIP bool, registryType string) { + cfg := provider.GetConfig() if cfg.RegistryUrl != "" { return cfg.RegistryUrl, "", cfg.AllowPrivateRegistryIp, RegistryTypeURL diff --git a/pkg/config/singleton.go b/pkg/config/singleton.go index ef983e969..0b1bace68 100644 --- a/pkg/config/singleton.go +++ b/pkg/config/singleton.go @@ -7,25 +7,47 @@ import ( "github.com/stacklok/toolhive/pkg/logger" ) -// Singleton value - should only be written to by the GetConfig function. +// Singleton value - should only be written to by the getSingletonConfig function. var appConfig *Config -var lock = &sync.Mutex{} +var lock = &sync.RWMutex{} -// GetConfig is a Singleton that returns the application configuration. -func GetConfig() *Config { +// SetSingletonConfig allows tests to pre-initialize the singleton with test data +// This prevents the singleton from loading the real config file during tests +func SetSingletonConfig(cfg *Config) { + lock.Lock() + defer lock.Unlock() + appConfig = cfg +} + +// ResetSingleton clears the singleton - useful for test cleanup +func ResetSingleton() { + lock.Lock() + defer lock.Unlock() + appConfig = nil +} + +// getSingletonConfig is a Singleton that returns the application configuration. +// This is only used internally by the DefaultProvider +func getSingletonConfig() *Config { + // First check with read lock for performance + lock.RLock() + if appConfig != nil { + defer lock.RUnlock() + return appConfig + } + lock.RUnlock() + + // If config is nil, acquire write lock and double-check + lock.Lock() + defer lock.Unlock() if appConfig == nil { - lock.Lock() - defer lock.Unlock() - if appConfig == nil { - appConfig, err := LoadOrCreateConfig() - if err != nil { - logger.Errorf("error loading configuration: %v", err) - os.Exit(1) - } - - return appConfig + config, err := LoadOrCreateConfig() + if err != nil { + logger.Errorf("error loading configuration: %v", err) + os.Exit(1) } + appConfig = config } return appConfig } diff --git a/pkg/migration/default_group.go b/pkg/migration/default_group.go index 706493f3f..312fa61c4 100644 --- a/pkg/migration/default_group.go +++ b/pkg/migration/default_group.go @@ -15,6 +15,7 @@ import ( type DefaultGroupMigrator struct { groupManager groups.Manager workloadsManager workloads.Manager + configProvider config.Provider } // Migrate performs the complete default group migration @@ -124,7 +125,7 @@ func (m *DefaultGroupMigrator) migrateWorkloadsToDefaultGroup(ctx context.Contex // migrateClientConfigs migrates client configurations from global config to default group func (m *DefaultGroupMigrator) migrateClientConfigs(ctx context.Context) error { - appConfig := config.GetConfig() + appConfig := m.configProvider.GetConfig() // If there are no registered clients, nothing to migrate if len(appConfig.Clients.RegisteredClients) == 0 { diff --git a/pkg/migration/migration.go b/pkg/migration/migration.go index 153f015bf..620fba09f 100644 --- a/pkg/migration/migration.go +++ b/pkg/migration/migration.go @@ -5,6 +5,7 @@ import ( "context" "sync" + "github.com/stacklok/toolhive/pkg/config" "github.com/stacklok/toolhive/pkg/logger" ) @@ -24,6 +25,8 @@ func CheckAndPerformDefaultGroupMigration() { // performDefaultGroupMigration migrates all existing workloads to the default group func performDefaultGroupMigration() error { - migrator := &DefaultGroupMigrator{} + migrator := &DefaultGroupMigrator{ + configProvider: config.NewDefaultProvider(), + } return migrator.Migrate(context.Background()) } diff --git a/pkg/registry/factory.go b/pkg/registry/factory.go index e4e1fbf66..89b06c29d 100644 --- a/pkg/registry/factory.go +++ b/pkg/registry/factory.go @@ -10,6 +10,11 @@ var ( defaultProvider Provider defaultProviderOnce sync.Once defaultProviderErr error + // defaultProviderMu protects the ResetDefaultProvider operation + // to prevent race conditions when resetting the sync.Once. + // The mutex is NOT needed for GetDefaultProviderWithConfig since + // sync.Once already provides thread-safety for initialization. + defaultProviderMu sync.Mutex ) // NewRegistryProvider creates a new registry provider based on the configuration @@ -26,8 +31,14 @@ func NewRegistryProvider(cfg *config.Config) Provider { // GetDefaultProvider returns the default registry provider instance // This maintains backward compatibility with the existing singleton pattern func GetDefaultProvider() (Provider, error) { + return GetDefaultProviderWithConfig(config.NewDefaultProvider()) +} + +// GetDefaultProviderWithConfig returns a registry provider using the given config provider +// This allows tests to inject their own config provider +func GetDefaultProviderWithConfig(configProvider config.Provider) (Provider, error) { defaultProviderOnce.Do(func() { - cfg, err := config.LoadOrCreateConfig() + cfg, err := configProvider.LoadOrCreateConfig() if err != nil { defaultProviderErr = err return @@ -39,8 +50,15 @@ func GetDefaultProvider() (Provider, error) { } // ResetDefaultProvider clears the cached default provider instance -// This allows the provider to be recreated with updated configuration +// This allows the provider to be recreated with updated configuration. +// This function is thread-safe and can be called concurrently. +// The mutex is required here because we're modifying the sync.Once itself, +// which is not a thread-safe operation. func ResetDefaultProvider() { + defaultProviderMu.Lock() + defer defaultProviderMu.Unlock() + + // Reset the sync.Once to allow re-initialization defaultProviderOnce = sync.Once{} defaultProvider = nil defaultProviderErr = nil diff --git a/pkg/registry/provider_test.go b/pkg/registry/provider_test.go index 218b93e8a..8e9e597b0 100644 --- a/pkg/registry/provider_test.go +++ b/pkg/registry/provider_test.go @@ -5,6 +5,8 @@ import ( "path/filepath" "testing" + "github.com/stretchr/testify/require" + "github.com/stacklok/toolhive/pkg/config" ) @@ -215,10 +217,24 @@ func getTypeName(v interface{}) string { func TestGetRegistry(t *testing.T) { t.Parallel() - provider, err := GetDefaultProvider() - if err != nil { - t.Fatalf("Failed to get registry provider: %v", err) - } + + // Create a temporary config for testing + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "toolhive", "config.yaml") + + // Ensure the directory exists + err := os.MkdirAll(filepath.Dir(configPath), 0755) + require.NoError(t, err) + + // Create a test config provider + configProvider := config.NewPathProvider(configPath) + + // Create a test config + cfg, err := configProvider.LoadOrCreateConfig() + require.NoError(t, err) + + // Create provider with test config + provider := NewRegistryProvider(cfg) reg, err := provider.GetRegistry() if err != nil { t.Fatalf("Failed to get registry: %v", err) @@ -244,11 +260,26 @@ func TestGetRegistry(t *testing.T) { func TestGetServer(t *testing.T) { t.Parallel() + + // Create a temporary config for testing + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "toolhive", "config.yaml") + + // Ensure the directory exists + err := os.MkdirAll(filepath.Dir(configPath), 0755) + require.NoError(t, err) + + // Create a test config provider + configProvider := config.NewPathProvider(configPath) + + // Create a test config + cfg, err := configProvider.LoadOrCreateConfig() + require.NoError(t, err) + + // Create provider with test config + provider := NewRegistryProvider(cfg) + // Test getting an existing server - provider, err := GetDefaultProvider() - if err != nil { - t.Fatalf("Failed to get registry provider: %v", err) - } server, err := provider.GetServer("osv") if err != nil { t.Fatalf("Failed to get server: %v", err) @@ -281,11 +312,26 @@ func TestGetServer(t *testing.T) { func TestSearchServers(t *testing.T) { t.Parallel() + + // Create a temporary config for testing + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "toolhive", "config.yaml") + + // Ensure the directory exists + err := os.MkdirAll(filepath.Dir(configPath), 0755) + require.NoError(t, err) + + // Create a test config provider + configProvider := config.NewPathProvider(configPath) + + // Create a test config + cfg, err := configProvider.LoadOrCreateConfig() + require.NoError(t, err) + + // Create provider with test config + provider := NewRegistryProvider(cfg) + // Test searching for servers - provider, err := GetDefaultProvider() - if err != nil { - t.Fatalf("Failed to get registry provider: %v", err) - } servers, err := provider.SearchServers("search") if err != nil { t.Fatalf("Failed to search servers: %v", err) @@ -308,7 +354,25 @@ func TestSearchServers(t *testing.T) { func TestListServers(t *testing.T) { t.Parallel() - provider, err := GetDefaultProvider() + + // Create a temporary config for testing + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "toolhive", "config.yaml") + + // Ensure the directory exists + err := os.MkdirAll(filepath.Dir(configPath), 0755) + require.NoError(t, err) + + // Create a test config provider + configProvider := config.NewPathProvider(configPath) + + // Reset the default provider to ensure clean state + ResetDefaultProvider() + t.Cleanup(func() { + ResetDefaultProvider() + }) + + provider, err := GetDefaultProviderWithConfig(configProvider) if err != nil { t.Fatalf("Failed to get registry provider: %v", err) } diff --git a/pkg/runner/env.go b/pkg/runner/env.go index f87c90f1d..ac140fcfa 100644 --- a/pkg/runner/env.go +++ b/pkg/runner/env.go @@ -65,11 +65,20 @@ func (*DetachedEnvVarValidator) Validate( // CLIEnvVarValidator implements the EnvVarValidator interface for // CLI usage. If any missing, mandatory variables are found, this code will // prompt the user to supply them through stdin. -type CLIEnvVarValidator struct{} +type CLIEnvVarValidator struct { + configProvider config.Provider +} + +// NewCLIEnvVarValidator creates a new CLI environment variable validator with the given config provider. +func NewCLIEnvVarValidator(configProvider config.Provider) *CLIEnvVarValidator { + return &CLIEnvVarValidator{ + configProvider: configProvider, + } +} // Validate checks that all required environment variables and secrets are provided // and returns the processed environment variables to be set. -func (*CLIEnvVarValidator) Validate( +func (v *CLIEnvVarValidator) Validate( ctx context.Context, metadata *registry.ImageMetadata, runConfig *RunConfig, @@ -94,7 +103,7 @@ func (*CLIEnvVarValidator) Validate( registryEnvVars := metadata.EnvVars // Initialize secrets manager if needed - secretsManager := initializeSecretsManagerIfNeeded(registryEnvVars) + secretsManager := v.initializeSecretsManagerIfNeeded(registryEnvVars) // Process each environment variable from the registry for _, envVar := range registryEnvVars { @@ -211,7 +220,7 @@ func addAsSecret( } // initializeSecretsManagerIfNeeded initializes the secrets manager if there are secret environment variables -func initializeSecretsManagerIfNeeded(registryEnvVars []*registry.EnvVar) secrets.Provider { +func (v *CLIEnvVarValidator) initializeSecretsManagerIfNeeded(registryEnvVars []*registry.EnvVar) secrets.Provider { // Check if we have any secret environment variables hasSecrets := false for _, envVar := range registryEnvVars { @@ -225,7 +234,7 @@ func initializeSecretsManagerIfNeeded(registryEnvVars []*registry.EnvVar) secret return nil } - secretsManager, err := getSecretsManager() + secretsManager, err := v.getSecretsManager() if err != nil { logger.Warnf("Warning: Failed to initialize secrets manager: %v", err) logger.Warnf("Secret environment variables will be stored as regular environment variables") @@ -237,8 +246,8 @@ func initializeSecretsManagerIfNeeded(registryEnvVars []*registry.EnvVar) secret // Duplicated from cmd/thv/app/app.go // It may be possible to de-duplicate this in future. -func getSecretsManager() (secrets.Provider, error) { - cfg := config.GetConfig() +func (v *CLIEnvVarValidator) getSecretsManager() (secrets.Provider, error) { + cfg := v.configProvider.GetConfig() // Check if secrets setup has been completed if !cfg.Secrets.SetupCompleted { diff --git a/pkg/runner/retriever/retriever.go b/pkg/runner/retriever/retriever.go index c57acc762..54fdda57b 100644 --- a/pkg/runner/retriever/retriever.go +++ b/pkg/runner/retriever/retriever.go @@ -164,7 +164,8 @@ func resolveCACertPath(flagValue string) string { } // Otherwise, check configuration - cfg := config.GetConfig() + configProvider := config.NewDefaultProvider() + cfg := configProvider.GetConfig() if cfg.CACertificatePath != "" { logger.Debugf("Using configured CA certificate: %s", cfg.CACertificatePath) return cfg.CACertificatePath diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 4a4884b4a..6c29e4189 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -140,7 +140,8 @@ func (r *Runner) Run(ctx context.Context) error { // Process secrets if provided if len(r.Config.Secrets) > 0 { - cfg := config.GetConfig() + cfgprovider := config.NewDefaultProvider() + cfg := cfgprovider.GetConfig() providerType, err := cfg.Secrets.GetProviderType() if err != nil { diff --git a/pkg/workloads/manager.go b/pkg/workloads/manager.go index 4deaf3498..ec89daded 100644 --- a/pkg/workloads/manager.go +++ b/pkg/workloads/manager.go @@ -66,8 +66,9 @@ type Manager interface { } type defaultManager struct { - runtime rt.Runtime - statuses statuses.StatusManager + runtime rt.Runtime + statuses statuses.StatusManager + configProvider config.Provider } // ErrWorkloadNotRunning is returned when a container cannot be found by name. @@ -91,8 +92,28 @@ func NewManager(ctx context.Context) (Manager, error) { } return &defaultManager{ - runtime: runtime, - statuses: statusManager, + runtime: runtime, + statuses: statusManager, + configProvider: config.NewDefaultProvider(), + }, nil +} + +// NewManagerWithProvider creates a new container manager instance with a custom config provider. +func NewManagerWithProvider(ctx context.Context, configProvider config.Provider) (Manager, error) { + runtime, err := ct.NewFactory().Create(ctx) + if err != nil { + return nil, err + } + + statusManager, err := statuses.NewStatusManager(runtime) + if err != nil { + return nil, fmt.Errorf("failed to create status manager: %w", err) + } + + return &defaultManager{ + runtime: runtime, + statuses: statusManager, + configProvider: configProvider, }, nil } @@ -104,8 +125,24 @@ func NewManagerFromRuntime(runtime rt.Runtime) (Manager, error) { } return &defaultManager{ - runtime: runtime, - statuses: statusManager, + runtime: runtime, + statuses: statusManager, + configProvider: config.NewDefaultProvider(), + }, nil +} + +// NewManagerFromRuntimeWithProvider creates a new container manager instance from an existing runtime with a +// custom config provider. +func NewManagerFromRuntimeWithProvider(runtime rt.Runtime, configProvider config.Provider) (Manager, error) { + statusManager, err := statuses.NewStatusManager(runtime) + if err != nil { + return nil, fmt.Errorf("failed to create status manager: %w", err) + } + + return &defaultManager{ + runtime: runtime, + statuses: statusManager, + configProvider: configProvider, }, nil } @@ -291,10 +328,10 @@ func (d *defaultManager) RunWorkload(ctx context.Context, runConfig *runner.RunC return err } -func validateSecretParameters(ctx context.Context, runConfig *runner.RunConfig) error { +func (d *defaultManager) validateSecretParameters(ctx context.Context, runConfig *runner.RunConfig) error { // If there are run secrets, validate them if len(runConfig.Secrets) > 0 { - cfg := config.GetConfig() + cfg := d.configProvider.GetConfig() providerType, err := cfg.Secrets.GetProviderType() if err != nil { @@ -316,7 +353,7 @@ func validateSecretParameters(ctx context.Context, runConfig *runner.RunConfig) func (d *defaultManager) RunWorkloadDetached(ctx context.Context, runConfig *runner.RunConfig) error { // before running, validate the parameters for the workload - err := validateSecretParameters(ctx, runConfig) + err := d.validateSecretParameters(ctx, runConfig) if err != nil { return fmt.Errorf("failed to validate workload parameters: %w", err) } @@ -356,7 +393,7 @@ func (d *defaultManager) RunWorkloadDetached(ctx context.Context, runConfig *run // NOTE: This breaks the abstraction slightly since this is only relevant for the CLI, but there // are checks inside `GetSecretsPassword` to ensure this does not get called in a detached process. // This will be addressed in a future re-think of the secrets manager interface. - if needSecretsPassword(runConfig.Secrets) { + if d.needSecretsPassword(runConfig.Secrets) { password, err := secrets.GetSecretsPassword("") if err != nil { return fmt.Errorf("failed to get secrets password: %v", err) @@ -863,14 +900,14 @@ func (d *defaultManager) loadRunnerFromState(ctx context.Context, baseName strin return runner.NewRunner(runConfig, d.statuses), nil } -func needSecretsPassword(secretOptions []string) bool { +func (d *defaultManager) needSecretsPassword(secretOptions []string) bool { // If the user did not ask for any secrets, then don't attempt to instantiate // the secrets manager. if len(secretOptions) == 0 { return false } // Ignore err - if the flag is not set, it's not needed. - providerType, _ := config.GetConfig().Secrets.GetProviderType() + providerType, _ := d.configProvider.GetConfig().Secrets.GetProviderType() return providerType == secrets.EncryptedType }