diff --git a/.github/dependabot.yaml b/.github/dependabot.yaml new file mode 100644 index 0000000..8789fde --- /dev/null +++ b/.github/dependabot.yaml @@ -0,0 +1,11 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "daily" + + - package-ecosystem: "gomod" + directory: "/" + schedule: + interval: "daily" diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..f137095 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,78 @@ +name: ci + +on: + push: + branches: + - main + pull_request: + branches: + - main + +permissions: + contents: read + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + + - name: Setup Go + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5 + with: + go-version: '1.25' + + - name: Run go vet + run: go vet ./... + + - name: Check formatting + run: | + unformatted=$(gofmt -l .) + if [ -n "$unformatted" ]; then + echo "The following files are not formatted:" + echo "$unformatted" + exit 1 + fi + + test-unit: + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + + - name: Setup Go + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5 + with: + go-version: '1.25' + + - name: Run unit tests + run: make test-unit + + test-integration: + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + + - name: Setup Go + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5 + with: + go-version: '1.25' + + - name: Run integration tests + run: make test-integration + + build: + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + + - name: Setup Go + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5 + with: + go-version: '1.25' + + - name: Build + run: make build diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml new file mode 100644 index 0000000..c174417 --- /dev/null +++ b/.github/workflows/release.yaml @@ -0,0 +1,29 @@ +name: release + +on: + push: + branches: + - main + +permissions: + contents: write + +jobs: + release: + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + - name: Bump Version + id: tag_version + uses: mathieudutour/github-tag-action@d28fa2ccfbd16e871a4bdf35e11b3ad1bd56c0c1 # v6.2 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + default_bump: minor + custom_release_rules: bug:patch:Fixes,chore:patch:Chores,docs:patch:Documentation,feat:minor:Features,refactor:minor:Refactors,test:patch:Tests,ci:patch:Development,dev:patch:Development + - name: Create Release + uses: ncipollo/release-action@b7eabc95ff50cbeeedec83973935c8f306dfcd0b # v1 + with: + tag: ${{ steps.tag_version.outputs.new_tag }} + name: ${{ steps.tag_version.outputs.new_tag }} + body: ${{ steps.tag_version.outputs.changelog }} diff --git a/.github/workflows/semantic-check.yaml b/.github/workflows/semantic-check.yaml new file mode 100644 index 0000000..6f97f9c --- /dev/null +++ b/.github/workflows/semantic-check.yaml @@ -0,0 +1,26 @@ +name: semantic-check +on: + pull_request_target: + types: + - opened + - edited + - synchronize + +permissions: + contents: read + pull-requests: read + +jobs: + main: + name: Semantic Commit Message Check + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + - uses: amannn/action-semantic-pull-request@e32d7e603df1aa1ba07e981f2a23455dee596825 # v5 + name: Check PR for Semantic Commit Message + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + requireScope: false + validateSingleCommit: true diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..8622f50 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 CruxStack + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..515931a --- /dev/null +++ b/Makefile @@ -0,0 +1,59 @@ +.PHONY: all build test test-unit test-integration lint fmt vet clean tidy help + +GO ?= go +GOFLAGS ?= +PACKAGES := $(shell $(GO) list ./... | grep -v '/examples/' | grep -v '/docs/' | grep -v '/integration') +INTEGRATION_PKG := ./integration/... + +all: fmt vet lint test-unit ## Run all checks and unit tests + +build: ## Build all packages + $(GO) build $(GOFLAGS) ./... + +test: test-unit ## Alias for test-unit + +test-unit: ## Run unit tests only + $(GO) test $(GOFLAGS) $(PACKAGES) + +test-unit-v: ## Run unit tests with verbose output + $(GO) test $(GOFLAGS) -v $(PACKAGES) + +test-integration: ## Run integration tests + $(GO) test $(GOFLAGS) -tags=integration $(INTEGRATION_PKG) + +test-integration-v: ## Run integration tests with verbose output + VERBOSE=1 $(GO) test $(GOFLAGS) -v -tags=integration $(INTEGRATION_PKG) + +test-all: test-unit test-integration ## Run all tests (unit + integration) + +test-v: ## Run unit tests with verbose output (alias for test-unit-v) + $(GO) test $(GOFLAGS) -v $(PACKAGES) + +test-race: ## Run tests with race detector + $(GO) test $(GOFLAGS) -race $(PACKAGES) + +test-cover: ## Run tests with coverage + $(GO) test $(GOFLAGS) -cover $(PACKAGES) + +test-cover-html: ## Run tests and generate HTML coverage report + $(GO) test $(GOFLAGS) -coverprofile=coverage.out $(PACKAGES) + $(GO) tool cover -html=coverage.out -o coverage.html + +lint: ## Run linter (requires golangci-lint) + @which golangci-lint > /dev/null || (echo "golangci-lint not installed, skipping" && exit 0) + golangci-lint run ./... + +fmt: ## Format code + $(GO) fmt ./... + +vet: ## Run go vet + $(GO) vet ./... + +tidy: ## Tidy go.mod + $(GO) mod tidy + +clean: ## Clean build artifacts + rm -f coverage.out coverage.html + +help: ## Show this help + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-15s\033[0m %s\n", $$1, $$2}' diff --git a/README.md b/README.md index f8a2443..a7d5290 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,210 @@ # github-app-setup-go -A Go library for creating and managing GitHub Apps using the [GitHub App Manifest flow](https://docs.github.com/en/apps/sharing-github-apps/registering-a-github-app-from-a-manifest). Provides a web-based installer, multiple credential storage backends, and utilities for configuration management in containerized environments. +A Go library for creating and managing GitHub Apps using the +[GitHub App Manifest flow](https://docs.github.com/en/apps/sharing-github-apps/registering-a-github-app-from-a-manifest). +Provides a web-based installer, multiple credential storage backends, and +utilities for configuration management in containerized environments. +## Features + +- **Web-based installer** - User-friendly UI for creating GitHub Apps with + pre-configured permissions +- **Multiple storage backends** - AWS SSM Parameter Store, `.env` files, or + individual files +- **Hot reload support** - Reload configuration via SIGHUP or programmatic + triggers +- **SSM ARN resolution** - Resolve AWS SSM Parameter Store ARNs in environment + variables (useful for Lambda) +- **Ready gate** - HTTP middleware that returns 503 until configuration is + loaded + +## Installation + +```bash +go get github.com/cruxstack/github-app-setup-go +``` + +## Packages + +| Package | Description | +|---------------|-----------------------------------------------------------| +| `installer` | HTTP handler implementing the GitHub App Manifest flow | +| `configstore` | Storage backends for GitHub App credentials | +| `configwait` | Startup wait logic, ready gate middleware, and reload | +| `ssmresolver` | Resolves SSM Parameter Store ARNs in environment vars | + +## Quick Start + +```go +package main + +import ( + "context" + "log" + "net/http" + + "github.com/cruxstack/github-app-setup-go/configstore" + "github.com/cruxstack/github-app-setup-go/configwait" + "github.com/cruxstack/github-app-setup-go/installer" +) + +func main() { + ctx := context.Background() + + // Create a storage backend (uses STORAGE_MODE env var, defaults to .env file) + store, err := configstore.NewFromEnv() + if err != nil { + log.Fatal(err) + } + + // Define the GitHub App manifest with required permissions + manifest := installer.Manifest{ + URL: "https://example.com", + Public: false, + DefaultPerms: map[string]string{ + "contents": "read", + "pull_requests": "write", + }, + DefaultEvents: []string{"pull_request", "push"}, + } + + // Create the installer handler + installerHandler, err := installer.New(installer.Config{ + Store: store, + Manifest: manifest, + AppDisplayName: "My GitHub App", + }) + if err != nil { + log.Fatal(err) + } + + // Set up routes + mux := http.NewServeMux() + mux.Handle("/setup", installerHandler) + mux.Handle("/callback", installerHandler) + + // Create a ready gate that allows /setup through before app is configured + gate := configwait.NewReadyGate(mux, []string{"/setup", "/callback", "/healthz"}) + + // Start the server + log.Println("Starting server on :8080") + log.Fatal(http.ListenAndServe(":8080", gate)) +} +``` + +## Configuration + +### Environment Variables + +#### Installer + +| Variable | Description | Default | +|--------------------------------|---------------------------------------------|----------------------| +| `GITHUB_URL` | GitHub base URL (for GHE Server) | `https://github.com` | +| `GITHUB_ORG` | Organization (empty = personal account) | - | +| `GITHUB_APP_INSTALLER_ENABLED` | Enable the installer UI (`true`, `1`, `yes`)| - | + +#### Storage + +| Variable | Description | Default | +|---------------------------|----------------------------------------------|-------------| +| `STORAGE_MODE` | Backend: `envfile`, `files`, or `aws-ssm` | `envfile` | +| `STORAGE_DIR` | Directory/path for local storage backends | `./.env` | +| `AWS_SSM_PARAMETER_PREFIX`| SSM parameter path prefix (for `aws-ssm`) | - | +| `AWS_SSM_KMS_KEY_ID` | Custom KMS key for SSM encryption | AWS managed | +| `AWS_SSM_TAGS` | JSON object of tags for SSM parameters | - | + +#### Config Wait + +| Variable | Description | Default | +|-----------------------------|--------------------------------------|---------| +| `CONFIG_WAIT_MAX_RETRIES` | Maximum retry attempts | `30` | +| `CONFIG_WAIT_RETRY_INTERVAL`| Duration between retries (e.g., `2s`)| `2s` | + +## Storage Backends + +### AWS SSM Parameter Store + +Stores credentials as encrypted SecureString parameters: + +```go +store, err := configstore.NewAWSSSMStore("/my-app/prod/", + configstore.WithKMSKey("alias/my-key"), + configstore.WithTags(map[string]string{ + "Environment": "production", + }), +) +``` + +Parameters are stored at paths like `/my-app/prod/GITHUB_APP_ID`, +`/my-app/prod/GITHUB_APP_PRIVATE_KEY`, etc. + +### Local .env File + +Saves credentials to a `.env` file, preserving existing content: + +```go +store := configstore.NewLocalEnvFileStore("./.env") +``` + +### Local Files + +Saves each credential as a separate file: + +```go +store := configstore.NewLocalFileStore("./secrets/") +// Creates: ./secrets/app-id, ./secrets/private-key.pem, etc. +``` + +## Hot Reload + +The library supports hot-reloading configuration via SIGHUP signals or +programmatic triggers: + +```go +// Create a reloader that calls your reload function +reloader := configwait.NewReloader(ctx, gate, func(ctx context.Context) error { + // Reload your configuration here + newHandler := buildHandler() + gate.SetHandler(newHandler) + gate.SetReady() + return nil +}) + +// Set as global reloader (allows installer to trigger reload after saving) +configwait.SetGlobalReloader(reloader) + +// Start listening for SIGHUP +reloader.Start() +``` + +## SSM ARN Resolution + +For Lambda deployments where secrets are passed as SSM ARNs: + +```go +// Resolve all environment variables that contain SSM ARNs +err := ssmresolver.ResolveEnvironmentWithRetry(ctx, ssmresolver.NewRetryConfigFromEnv()) + +// Or resolve a single value +resolver, _ := ssmresolver.New(ctx) +value, err := resolver.ResolveValue(ctx, os.Getenv("MY_SECRET")) +``` + +## Stored Credentials + +After a GitHub App is created, the following credentials are stored: + +| Key | Description | +|---------------------------|---------------------------------------| +| `GITHUB_APP_ID` | The numeric App ID | +| `GITHUB_APP_SLUG` | The app's URL slug | +| `GITHUB_APP_HTML_URL` | URL to the app's GitHub settings page | +| `GITHUB_WEBHOOK_SECRET` | Webhook signature secret | +| `GITHUB_CLIENT_ID` | OAuth client ID | +| `GITHUB_CLIENT_SECRET` | OAuth client secret | +| `GITHUB_APP_PRIVATE_KEY` | Private key (PEM format) | + +## License + +MIT License - Copyright 2025 CruxStack diff --git a/configstore/aws_ssm_store.go b/configstore/aws_ssm_store.go new file mode 100644 index 0000000..127e7f3 --- /dev/null +++ b/configstore/aws_ssm_store.go @@ -0,0 +1,225 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +package configstore + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/ssm" + "github.com/aws/aws-sdk-go-v2/service/ssm/types" +) + +// SSMClient defines the interface for AWS SSM operations. +type SSMClient interface { + PutParameter(ctx context.Context, params *ssm.PutParameterInput, + optFns ...func(*ssm.Options)) (*ssm.PutParameterOutput, error) + GetParameter(ctx context.Context, params *ssm.GetParameterInput, + optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) +} + +// AWSSSMStore saves credentials to AWS SSM Parameter Store with encryption. +type AWSSSMStore struct { + ParameterPrefix string + KMSKeyID string + Tags map[string]string + ssmClient SSMClient +} + +// SSMStoreOption is a functional option for configuring AWSSSMStore. +type SSMStoreOption func(*AWSSSMStore) + +// WithKMSKey sets a custom KMS key ID for parameter encryption. +func WithKMSKey(keyID string) SSMStoreOption { + return func(s *AWSSSMStore) { + s.KMSKeyID = keyID + } +} + +// WithTags adds AWS tags to all created parameters. +func WithTags(tags map[string]string) SSMStoreOption { + return func(s *AWSSSMStore) { + s.Tags = tags + } +} + +// WithSSMClient sets a custom SSM client. +func WithSSMClient(client SSMClient) SSMStoreOption { + return func(s *AWSSSMStore) { + s.ssmClient = client + } +} + +// NewAWSSSMStore creates a new AWS SSM Parameter Store backend. +// The prefix is normalized to always end with a slash. +func NewAWSSSMStore(prefix string, opts ...SSMStoreOption) (*AWSSSMStore, error) { + if prefix == "" { + return nil, fmt.Errorf("parameter prefix cannot be empty") + } + + if !strings.HasSuffix(prefix, "/") { + prefix = prefix + "/" + } + + store := &AWSSSMStore{ + ParameterPrefix: prefix, + } + + for _, opt := range opts { + opt(store) + } + + if store.ssmClient == nil { + cfg, err := config.LoadDefaultConfig(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to load AWS config: %w", err) + } + store.ssmClient = ssm.NewFromConfig(cfg) + } + + return store, nil +} + +// Save writes credentials to AWS SSM as encrypted SecureString parameters. +func (s *AWSSSMStore) Save(ctx context.Context, creds *AppCredentials) error { + parameters := map[string]string{ + EnvGitHubAppID: fmt.Sprintf("%d", creds.AppID), + EnvGitHubWebhookSecret: creds.WebhookSecret, + EnvGitHubClientID: creds.ClientID, + EnvGitHubClientSecret: creds.ClientSecret, + EnvGitHubAppPrivateKey: creds.PrivateKey, + } + + if creds.AppSlug != "" { + parameters[EnvGitHubAppSlug] = creds.AppSlug + } + if creds.HTMLURL != "" { + parameters[EnvGitHubAppHTMLURL] = creds.HTMLURL + } + + for key, value := range creds.CustomFields { + if value != "" { + parameters[key] = value + } + } + + for name, value := range parameters { + if err := s.putParameter(ctx, name, value); err != nil { + return fmt.Errorf("failed to save parameter %s: %w", name, err) + } + } + + return nil +} + +// putParameter creates or updates a single SSM parameter. +func (s *AWSSSMStore) putParameter(ctx context.Context, name, value string) error { + input := &ssm.PutParameterInput{ + Name: aws.String(s.ParameterPrefix + name), + Value: aws.String(value), + Type: types.ParameterTypeSecureString, + Overwrite: aws.Bool(true), + DataType: aws.String("text"), + } + + if s.KMSKeyID != "" { + input.KeyId = aws.String(s.KMSKeyID) + } + + if len(s.Tags) > 0 { + var tags []types.Tag + for key, value := range s.Tags { + tags = append(tags, types.Tag{ + Key: aws.String(key), + Value: aws.String(value), + }) + } + input.Tags = tags + } + + _, err := s.ssmClient.PutParameter(ctx, input) + if err != nil { + return err + } + + return nil +} + +// Status returns the current registration state by checking required SSM parameters. +func (s *AWSSSMStore) Status(ctx context.Context) (*InstallerStatus, error) { + status := &InstallerStatus{} + required := []string{ + EnvGitHubAppID, + EnvGitHubWebhookSecret, + EnvGitHubClientID, + EnvGitHubClientSecret, + EnvGitHubAppPrivateKey, + } + + values := make(map[string]string) + for _, key := range required { + value, err := s.getParameterValue(ctx, key) + if err != nil { + if isParameterNotFound(err) { + return status, nil + } + return nil, err + } + values[key] = value + } + + status.Registered = true + if id, err := strconv.ParseInt(strings.TrimSpace(values[EnvGitHubAppID]), 10, 64); err == nil { + status.AppID = id + } + + if slug, err := s.getParameterValue(ctx, EnvGitHubAppSlug); err == nil { + status.AppSlug = slug + } else if !isParameterNotFound(err) { + return nil, err + } + + if html, err := s.getParameterValue(ctx, EnvGitHubAppHTMLURL); err == nil { + status.HTMLURL = html + } else if !isParameterNotFound(err) { + return nil, err + } + + if flag, err := s.getParameterValue(ctx, EnvGitHubAppInstallerEnabled); err == nil { + status.InstallerDisabled = isFalseString(flag) + } else if !isParameterNotFound(err) { + return nil, err + } + + return status, nil +} + +// DisableInstaller sets a parameter to disable the installer. +func (s *AWSSSMStore) DisableInstaller(ctx context.Context) error { + return s.putParameter(ctx, EnvGitHubAppInstallerEnabled, "false") +} + +func (s *AWSSSMStore) getParameterValue(ctx context.Context, name string) (string, error) { + output, err := s.ssmClient.GetParameter(ctx, &ssm.GetParameterInput{ + Name: aws.String(s.ParameterPrefix + name), + WithDecryption: aws.Bool(true), + }) + if err != nil { + return "", err + } + if output.Parameter == nil || output.Parameter.Value == nil { + return "", fmt.Errorf("parameter %s missing value", name) + } + return aws.ToString(output.Parameter.Value), nil +} + +func isParameterNotFound(err error) bool { + var notFound *types.ParameterNotFound + return errors.As(err, ¬Found) +} diff --git a/configstore/aws_ssm_store_test.go b/configstore/aws_ssm_store_test.go new file mode 100644 index 0000000..4211ed9 --- /dev/null +++ b/configstore/aws_ssm_store_test.go @@ -0,0 +1,451 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +package configstore + +import ( + "context" + "fmt" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ssm" + "github.com/aws/aws-sdk-go-v2/service/ssm/types" +) + +// mockSSMClient implements SSMClient for testing +type mockSSMClient struct { + parameters map[string]string + putCalls []ssm.PutParameterInput + getCalls []ssm.GetParameterInput + putErr error + getErr error +} + +func newMockSSMClient() *mockSSMClient { + return &mockSSMClient{ + parameters: make(map[string]string), + } +} + +func (m *mockSSMClient) PutParameter(ctx context.Context, params *ssm.PutParameterInput, optFns ...func(*ssm.Options)) (*ssm.PutParameterOutput, error) { + m.putCalls = append(m.putCalls, *params) + if m.putErr != nil { + return nil, m.putErr + } + m.parameters[*params.Name] = *params.Value + return &ssm.PutParameterOutput{}, nil +} + +func (m *mockSSMClient) GetParameter(ctx context.Context, params *ssm.GetParameterInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) { + m.getCalls = append(m.getCalls, *params) + if m.getErr != nil { + return nil, m.getErr + } + value, ok := m.parameters[*params.Name] + if !ok { + return nil, &types.ParameterNotFound{} + } + return &ssm.GetParameterOutput{ + Parameter: &types.Parameter{ + Name: params.Name, + Value: aws.String(value), + }, + }, nil +} + +func TestNewAWSSSMStore(t *testing.T) { + t.Run("empty prefix returns error", func(t *testing.T) { + _, err := NewAWSSSMStore("") + if err == nil { + t.Error("NewAWSSSMStore(\"\") should return error") + } + }) + + t.Run("prefix without trailing slash is normalized", func(t *testing.T) { + mock := newMockSSMClient() + store, err := NewAWSSSMStore("/my/prefix", WithSSMClient(mock)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + if store.ParameterPrefix != "/my/prefix/" { + t.Errorf("ParameterPrefix = %q, want %q", store.ParameterPrefix, "/my/prefix/") + } + }) + + t.Run("prefix with trailing slash is preserved", func(t *testing.T) { + mock := newMockSSMClient() + store, err := NewAWSSSMStore("/my/prefix/", WithSSMClient(mock)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + if store.ParameterPrefix != "/my/prefix/" { + t.Errorf("ParameterPrefix = %q, want %q", store.ParameterPrefix, "/my/prefix/") + } + }) +} + +func TestAWSSSMStore_WithOptions(t *testing.T) { + t.Run("WithKMSKey sets KMS key ID", func(t *testing.T) { + mock := newMockSSMClient() + store, err := NewAWSSSMStore("/prefix", WithSSMClient(mock), WithKMSKey("alias/my-key")) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + if store.KMSKeyID != "alias/my-key" { + t.Errorf("KMSKeyID = %q, want %q", store.KMSKeyID, "alias/my-key") + } + }) + + t.Run("WithTags sets tags", func(t *testing.T) { + mock := newMockSSMClient() + tags := map[string]string{"env": "prod", "team": "platform"} + store, err := NewAWSSSMStore("/prefix", WithSSMClient(mock), WithTags(tags)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + if len(store.Tags) != 2 { + t.Errorf("Tags count = %d, want 2", len(store.Tags)) + } + if store.Tags["env"] != "prod" { + t.Errorf("Tags[\"env\"] = %q, want %q", store.Tags["env"], "prod") + } + }) +} + +func TestAWSSSMStore_Save(t *testing.T) { + mock := newMockSSMClient() + store, err := NewAWSSSMStore("/app/github/", WithSSMClient(mock)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + + creds := &AppCredentials{ + AppID: 12345, + AppSlug: "my-app", + ClientID: "Iv1.abc123", + ClientSecret: "secret123", + WebhookSecret: "whsec_123", + PrivateKey: "-----BEGIN RSA PRIVATE KEY-----\nkey\n-----END RSA PRIVATE KEY-----", + HTMLURL: "https://github.com/apps/my-app", + CustomFields: map[string]string{ + "STS_DOMAIN": "sts.example.com", + }, + } + + err = store.Save(context.Background(), creds) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Verify all expected parameters were saved + expectedParams := map[string]string{ + "/app/github/GITHUB_APP_ID": "12345", + "/app/github/GITHUB_APP_SLUG": "my-app", + "/app/github/GITHUB_CLIENT_ID": "Iv1.abc123", + "/app/github/GITHUB_CLIENT_SECRET": "secret123", + "/app/github/GITHUB_WEBHOOK_SECRET": "whsec_123", + "/app/github/GITHUB_APP_PRIVATE_KEY": "-----BEGIN RSA PRIVATE KEY-----\nkey\n-----END RSA PRIVATE KEY-----", + "/app/github/GITHUB_APP_HTML_URL": "https://github.com/apps/my-app", + "/app/github/STS_DOMAIN": "sts.example.com", + } + + for name, wantValue := range expectedParams { + gotValue, ok := mock.parameters[name] + if !ok { + t.Errorf("Parameter %q was not saved", name) + continue + } + if gotValue != wantValue { + t.Errorf("Parameter %q = %q, want %q", name, gotValue, wantValue) + } + } + + // Verify all parameters were saved as SecureString + for _, call := range mock.putCalls { + if call.Type != types.ParameterTypeSecureString { + t.Errorf("Parameter %q type = %v, want SecureString", *call.Name, call.Type) + } + if call.Overwrite == nil || !*call.Overwrite { + t.Errorf("Parameter %q Overwrite should be true", *call.Name) + } + } +} + +func TestAWSSSMStore_Save_WithKMSKey(t *testing.T) { + mock := newMockSSMClient() + store, err := NewAWSSSMStore("/prefix/", WithSSMClient(mock), WithKMSKey("alias/custom-key")) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + + creds := &AppCredentials{ + AppID: 1, + ClientID: "client", + ClientSecret: "secret", + WebhookSecret: "webhook", + PrivateKey: "key", + } + + err = store.Save(context.Background(), creds) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Verify KMS key was used + for _, call := range mock.putCalls { + if call.KeyId == nil || *call.KeyId != "alias/custom-key" { + t.Errorf("Parameter %q KeyId = %v, want alias/custom-key", *call.Name, call.KeyId) + } + } +} + +func TestAWSSSMStore_Save_WithTags(t *testing.T) { + mock := newMockSSMClient() + tags := map[string]string{"env": "test", "app": "myapp"} + store, err := NewAWSSSMStore("/prefix/", WithSSMClient(mock), WithTags(tags)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + + creds := &AppCredentials{ + AppID: 1, + ClientID: "client", + ClientSecret: "secret", + WebhookSecret: "webhook", + PrivateKey: "key", + } + + err = store.Save(context.Background(), creds) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Verify tags were applied to all parameters + for _, call := range mock.putCalls { + if len(call.Tags) != 2 { + t.Errorf("Parameter %q has %d tags, want 2", *call.Name, len(call.Tags)) + } + } +} + +func TestAWSSSMStore_Save_OmitsEmptyOptionalFields(t *testing.T) { + mock := newMockSSMClient() + store, err := NewAWSSSMStore("/prefix/", WithSSMClient(mock)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + + creds := &AppCredentials{ + AppID: 1, + ClientID: "client", + ClientSecret: "secret", + WebhookSecret: "webhook", + PrivateKey: "key", + // AppSlug and HTMLURL are empty + } + + err = store.Save(context.Background(), creds) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Optional parameters should not be saved + for _, name := range []string{"/prefix/GITHUB_APP_SLUG", "/prefix/GITHUB_APP_HTML_URL"} { + if _, ok := mock.parameters[name]; ok { + t.Errorf("Empty optional parameter %q should not be saved", name) + } + } +} + +func TestAWSSSMStore_Save_Error(t *testing.T) { + mock := newMockSSMClient() + mock.putErr = fmt.Errorf("access denied") + + store, err := NewAWSSSMStore("/prefix/", WithSSMClient(mock)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + + creds := &AppCredentials{ + AppID: 1, + ClientID: "client", + ClientSecret: "secret", + WebhookSecret: "webhook", + PrivateKey: "key", + } + + err = store.Save(context.Background(), creds) + if err == nil { + t.Error("Save() should return error when PutParameter fails") + } +} + +func TestAWSSSMStore_Status_Registered(t *testing.T) { + mock := newMockSSMClient() + mock.parameters = map[string]string{ + "/prefix/GITHUB_APP_ID": "12345", + "/prefix/GITHUB_APP_SLUG": "test-app", + "/prefix/GITHUB_APP_HTML_URL": "https://github.com/apps/test-app", + "/prefix/GITHUB_CLIENT_ID": "client123", + "/prefix/GITHUB_CLIENT_SECRET": "secret123", + "/prefix/GITHUB_WEBHOOK_SECRET": "webhook123", + "/prefix/GITHUB_APP_PRIVATE_KEY": "-----BEGIN RSA-----\nkey\n-----END RSA-----", + } + + store, err := NewAWSSSMStore("/prefix/", WithSSMClient(mock)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if !status.Registered { + t.Error("Status.Registered = false, want true") + } + if status.AppID != 12345 { + t.Errorf("Status.AppID = %d, want 12345", status.AppID) + } + if status.AppSlug != "test-app" { + t.Errorf("Status.AppSlug = %q, want %q", status.AppSlug, "test-app") + } + if status.HTMLURL != "https://github.com/apps/test-app" { + t.Errorf("Status.HTMLURL = %q, want %q", status.HTMLURL, "https://github.com/apps/test-app") + } + if status.InstallerDisabled { + t.Error("Status.InstallerDisabled = true, want false") + } +} + +func TestAWSSSMStore_Status_NotRegistered(t *testing.T) { + mock := newMockSSMClient() + // No parameters exist + + store, err := NewAWSSSMStore("/prefix/", WithSSMClient(mock)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if status.Registered { + t.Error("Status.Registered = true, want false (no parameters)") + } +} + +func TestAWSSSMStore_Status_InstallerDisabled(t *testing.T) { + mock := newMockSSMClient() + mock.parameters = map[string]string{ + "/prefix/GITHUB_APP_ID": "12345", + "/prefix/GITHUB_CLIENT_ID": "client", + "/prefix/GITHUB_CLIENT_SECRET": "secret", + "/prefix/GITHUB_WEBHOOK_SECRET": "webhook", + "/prefix/GITHUB_APP_PRIVATE_KEY": "key", + "/prefix/GITHUB_APP_INSTALLER_ENABLED": "false", + } + + store, err := NewAWSSSMStore("/prefix/", WithSSMClient(mock)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if !status.InstallerDisabled { + t.Error("Status.InstallerDisabled = false, want true") + } +} + +func TestAWSSSMStore_Status_Error(t *testing.T) { + mock := newMockSSMClient() + mock.getErr = fmt.Errorf("access denied") + + store, err := NewAWSSSMStore("/prefix/", WithSSMClient(mock)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + + _, err = store.Status(context.Background()) + if err == nil { + t.Error("Status() should return error when GetParameter fails") + } +} + +func TestAWSSSMStore_DisableInstaller(t *testing.T) { + mock := newMockSSMClient() + store, err := NewAWSSSMStore("/prefix/", WithSSMClient(mock)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + + err = store.DisableInstaller(context.Background()) + if err != nil { + t.Fatalf("DisableInstaller() error = %v", err) + } + + // Verify the parameter was set + value, ok := mock.parameters["/prefix/GITHUB_APP_INSTALLER_ENABLED"] + if !ok { + t.Error("DisableInstaller() did not create parameter") + } + if value != "false" { + t.Errorf("Parameter value = %q, want %q", value, "false") + } +} + +func TestAWSSSMStore_DisableInstaller_Error(t *testing.T) { + mock := newMockSSMClient() + mock.putErr = fmt.Errorf("access denied") + + store, err := NewAWSSSMStore("/prefix/", WithSSMClient(mock)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + + err = store.DisableInstaller(context.Background()) + if err == nil { + t.Error("DisableInstaller() should return error when PutParameter fails") + } +} + +func TestIsParameterNotFound(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + { + name: "ParameterNotFound error", + err: &types.ParameterNotFound{}, + want: true, + }, + { + name: "other error", + err: fmt.Errorf("some other error"), + want: false, + }, + { + name: "nil error", + err: nil, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isParameterNotFound(tt.err) + if got != tt.want { + t.Errorf("isParameterNotFound() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/configstore/local_env_store.go b/configstore/local_env_store.go new file mode 100644 index 0000000..b94af76 --- /dev/null +++ b/configstore/local_env_store.go @@ -0,0 +1,245 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +package configstore + +import ( + "bufio" + "context" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" +) + +// LocalEnvFileStore saves credentials to a .env file. +type LocalEnvFileStore struct { + FilePath string +} + +// NewLocalEnvFileStore creates a store that saves credentials to the given path. +func NewLocalEnvFileStore(filepath string) *LocalEnvFileStore { + return &LocalEnvFileStore{FilePath: filepath} +} + +// Save writes credentials to .env format, preserving existing content. +// It also sets the environment variables in the current process so they +// are immediately available to the application. +func (s *LocalEnvFileStore) Save(ctx context.Context, creds *AppCredentials) error { + dir := filepath.Dir(s.FilePath) + if err := os.MkdirAll(dir, 0700); err != nil { + return fmt.Errorf("failed to create directory %s: %w", dir, err) + } + + existingValues, originalLines, err := parseEnvFile(s.FilePath) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to read existing .env file: %w", err) + } + if existingValues == nil { + existingValues = make(map[string]string) + } + + for key, value := range creds.CustomFields { + if value != "" { + existingValues[key] = value + } + } + + singleLinePEM := strings.ReplaceAll(creds.PrivateKey, "\n", "\\n") + + existingValues[EnvGitHubAppID] = fmt.Sprintf("%d", creds.AppID) + existingValues[EnvGitHubWebhookSecret] = creds.WebhookSecret + existingValues[EnvGitHubClientID] = creds.ClientID + existingValues[EnvGitHubClientSecret] = creds.ClientSecret + existingValues[EnvGitHubAppPrivateKey] = singleLinePEM + if creds.AppSlug != "" { + existingValues[EnvGitHubAppSlug] = creds.AppSlug + } + if creds.HTMLURL != "" { + existingValues[EnvGitHubAppHTMLURL] = creds.HTMLURL + } + + if err := writeEnvFile(s.FilePath, existingValues, originalLines); err != nil { + return fmt.Errorf("failed to write .env file: %w", err) + } + + // Set environment variables in the current process so they are + // immediately available for configuration reload. + os.Setenv(EnvGitHubAppID, fmt.Sprintf("%d", creds.AppID)) + os.Setenv(EnvGitHubWebhookSecret, creds.WebhookSecret) + os.Setenv(EnvGitHubClientID, creds.ClientID) + os.Setenv(EnvGitHubClientSecret, creds.ClientSecret) + os.Setenv(EnvGitHubAppPrivateKey, singleLinePEM) + if creds.AppSlug != "" { + os.Setenv(EnvGitHubAppSlug, creds.AppSlug) + } + if creds.HTMLURL != "" { + os.Setenv(EnvGitHubAppHTMLURL, creds.HTMLURL) + } + for key, value := range creds.CustomFields { + if value != "" { + os.Setenv(key, value) + } + } + + return nil +} + +func parseEnvFile(path string) (map[string]string, []string, error) { + file, err := os.Open(path) + if err != nil { + return nil, nil, err + } + defer file.Close() + + values := make(map[string]string) + var lines []string + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + lines = append(lines, line) + + trimmed := strings.TrimSpace(line) + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + + idx := strings.Index(line, "=") + if idx == -1 { + continue + } + + key := strings.TrimSpace(line[:idx]) + value := strings.TrimSpace(line[idx+1:]) + + if len(value) >= 2 { + if (strings.HasPrefix(value, "\"") && strings.HasSuffix(value, "\"")) || + (strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'")) { + value = value[1 : len(value)-1] + } + } + + values[key] = value + } + + if err := scanner.Err(); err != nil { + return nil, nil, err + } + + return values, lines, nil +} + +func writeEnvFile(path string, values map[string]string, originalLines []string) error { + var outputLines []string + writtenKeys := make(map[string]bool) + + for _, line := range originalLines { + trimmed := strings.TrimSpace(line) + + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + outputLines = append(outputLines, line) + continue + } + + idx := strings.Index(line, "=") + if idx == -1 { + outputLines = append(outputLines, line) + continue + } + + key := strings.TrimSpace(line[:idx]) + + if newValue, ok := values[key]; ok { + outputLines = append(outputLines, formatEnvLine(key, newValue)) + writtenKeys[key] = true + } else { + outputLines = append(outputLines, line) + } + } + + for key, value := range values { + if !writtenKeys[key] { + outputLines = append(outputLines, formatEnvLine(key, value)) + } + } + + content := strings.Join(outputLines, "\n") + if len(outputLines) > 0 { + content += "\n" + } + + return os.WriteFile(path, []byte(content), 0600) +} + +func formatEnvLine(key, value string) string { + needsQuotes := strings.ContainsAny(value, " \t\n\r\"'\\#") || strings.Contains(value, "\\n") + + if needsQuotes { + escaped := strings.ReplaceAll(value, "\"", "\\\"") + return fmt.Sprintf("%s=\"%s\"", key, escaped) + } + + return fmt.Sprintf("%s=%s", key, value) +} + +// Status returns the current registration state by checking the .env file. +func (s *LocalEnvFileStore) Status(ctx context.Context) (*InstallerStatus, error) { + values, _, err := parseEnvFile(s.FilePath) + if err != nil { + if os.IsNotExist(err) { + return &InstallerStatus{}, nil + } + return nil, err + } + + status := &InstallerStatus{ + AppSlug: values[EnvGitHubAppSlug], + HTMLURL: values[EnvGitHubAppHTMLURL], + } + + if idStr := strings.TrimSpace(values[EnvGitHubAppID]); idStr != "" { + if id, err := strconv.ParseInt(idStr, 10, 64); err == nil { + status.AppID = id + } + } + + status.Registered = hasAllValues(values, + EnvGitHubAppID, + EnvGitHubWebhookSecret, + EnvGitHubClientID, + EnvGitHubClientSecret, + EnvGitHubAppPrivateKey, + ) + + status.InstallerDisabled = isFalseString(values[EnvGitHubAppInstallerEnabled]) + + return status, nil +} + +// DisableInstaller sets GITHUB_APP_INSTALLER_ENABLED=false in the .env file. +func (s *LocalEnvFileStore) DisableInstaller(ctx context.Context) error { + dir := filepath.Dir(s.FilePath) + if err := os.MkdirAll(dir, 0700); err != nil { + return fmt.Errorf("failed to create directory %s: %w", dir, err) + } + + values, originalLines, err := parseEnvFile(s.FilePath) + if err != nil { + if !os.IsNotExist(err) { + return err + } + } + if values == nil { + values = make(map[string]string) + } + + values[EnvGitHubAppInstallerEnabled] = "false" + + if err := writeEnvFile(s.FilePath, values, originalLines); err != nil { + return fmt.Errorf("failed to persist installer flag: %w", err) + } + + return nil +} diff --git a/configstore/local_env_store_test.go b/configstore/local_env_store_test.go new file mode 100644 index 0000000..fb652cd --- /dev/null +++ b/configstore/local_env_store_test.go @@ -0,0 +1,590 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +package configstore + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestParseEnvFile(t *testing.T) { + tests := []struct { + name string + content string + wantValues map[string]string + wantLines int + }{ + { + name: "simple key-value pairs", + content: "FOO=bar\nBAZ=qux", + wantValues: map[string]string{ + "FOO": "bar", + "BAZ": "qux", + }, + wantLines: 2, + }, + { + name: "double-quoted values", + content: `KEY="value with spaces"`, + wantValues: map[string]string{ + "KEY": "value with spaces", + }, + wantLines: 1, + }, + { + name: "single-quoted values", + content: `KEY='value with spaces'`, + wantValues: map[string]string{ + "KEY": "value with spaces", + }, + wantLines: 1, + }, + { + name: "unquoted value with equals sign", + content: "KEY=value=with=equals", + wantValues: map[string]string{ + "KEY": "value=with=equals", + }, + wantLines: 1, + }, + { + name: "comments are ignored", + content: "# This is a comment\nKEY=value\n# Another comment", + wantValues: map[string]string{ + "KEY": "value", + }, + wantLines: 3, + }, + { + name: "empty lines preserved", + content: "FOO=bar\n\nBAZ=qux", + wantValues: map[string]string{ + "FOO": "bar", + "BAZ": "qux", + }, + wantLines: 3, + }, + { + name: "whitespace around equals", + content: "KEY = value", + wantValues: map[string]string{ + "KEY": "value", + }, + wantLines: 1, + }, + { + name: "PEM key with escaped newlines", + content: `PRIVATE_KEY="-----BEGIN RSA PRIVATE KEY-----\nMIIE...\n-----END RSA PRIVATE KEY-----"`, + wantValues: map[string]string{ + "PRIVATE_KEY": `-----BEGIN RSA PRIVATE KEY-----\nMIIE...\n-----END RSA PRIVATE KEY-----`, + }, + wantLines: 1, + }, + { + name: "empty file", + content: "", + wantValues: map[string]string{}, + wantLines: 0, + }, + { + name: "only comments", + content: "# comment 1\n# comment 2", + wantValues: map[string]string{}, + wantLines: 2, + }, + { + name: "line without equals ignored", + content: "INVALID_LINE\nVALID=true", + wantValues: map[string]string{ + "VALID": "true", + }, + wantLines: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + envPath := filepath.Join(tmpDir, ".env") + + if err := os.WriteFile(envPath, []byte(tt.content), 0600); err != nil { + t.Fatalf("Failed to write test file: %v", err) + } + + values, lines, err := parseEnvFile(envPath) + if err != nil { + t.Fatalf("parseEnvFile() error = %v", err) + } + + if len(values) != len(tt.wantValues) { + t.Errorf("parseEnvFile() got %d values, want %d", len(values), len(tt.wantValues)) + } + + for k, want := range tt.wantValues { + got, ok := values[k] + if !ok { + t.Errorf("parseEnvFile() missing key %q", k) + continue + } + if got != want { + t.Errorf("parseEnvFile()[%q] = %q, want %q", k, got, want) + } + } + + if len(lines) != tt.wantLines { + t.Errorf("parseEnvFile() got %d lines, want %d", len(lines), tt.wantLines) + } + }) + } +} + +func TestParseEnvFile_NotExists(t *testing.T) { + values, lines, err := parseEnvFile("/nonexistent/path/.env") + + if !os.IsNotExist(err) { + t.Errorf("parseEnvFile() error = %v, want os.IsNotExist", err) + } + if values != nil { + t.Errorf("parseEnvFile() values = %v, want nil", values) + } + if lines != nil { + t.Errorf("parseEnvFile() lines = %v, want nil", lines) + } +} + +func TestFormatEnvLine(t *testing.T) { + tests := []struct { + name string + key string + value string + want string + }{ + { + name: "simple value", + key: "KEY", + value: "value", + want: "KEY=value", + }, + { + name: "value with spaces needs quotes", + key: "KEY", + value: "value with spaces", + want: `KEY="value with spaces"`, + }, + { + name: "value with escaped newlines needs quotes", + key: "PEM", + value: `-----BEGIN RSA-----\nMIIE\n-----END RSA-----`, + want: `PEM="-----BEGIN RSA-----\nMIIE\n-----END RSA-----"`, + }, + { + name: "value with hash needs quotes", + key: "KEY", + value: "value#comment", + want: `KEY="value#comment"`, + }, + { + name: "value with double quote is escaped", + key: "KEY", + value: `value"quoted`, + want: `KEY="value\"quoted"`, + }, + { + name: "value with single quote needs quotes", + key: "KEY", + value: "it's", + want: `KEY="it's"`, + }, + { + name: "value with tab needs quotes", + key: "KEY", + value: "value\twith\ttabs", + want: `KEY="value with tabs"`, + }, + { + name: "empty value", + key: "KEY", + value: "", + want: "KEY=", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := formatEnvLine(tt.key, tt.value) + if got != tt.want { + t.Errorf("formatEnvLine(%q, %q) = %q, want %q", tt.key, tt.value, got, tt.want) + } + }) + } +} + +func TestWriteEnvFile_PreservesComments(t *testing.T) { + tmpDir := t.TempDir() + envPath := filepath.Join(tmpDir, ".env") + + originalContent := `# Database config +DB_HOST=localhost + +# App config +APP_NAME=test +` + if err := os.WriteFile(envPath, []byte(originalContent), 0600); err != nil { + t.Fatalf("Failed to write initial file: %v", err) + } + + values, lines, err := parseEnvFile(envPath) + if err != nil { + t.Fatalf("parseEnvFile() error = %v", err) + } + + // Update one value and add a new one + values["APP_NAME"] = "updated" + values["NEW_KEY"] = "new_value" + + if err := writeEnvFile(envPath, values, lines); err != nil { + t.Fatalf("writeEnvFile() error = %v", err) + } + + content, err := os.ReadFile(envPath) + if err != nil { + t.Fatalf("Failed to read file: %v", err) + } + + result := string(content) + + // Comments should be preserved + if !strings.Contains(result, "# Database config") { + t.Error("writeEnvFile() lost comment '# Database config'") + } + if !strings.Contains(result, "# App config") { + t.Error("writeEnvFile() lost comment '# App config'") + } + + // Updated value should be present + if !strings.Contains(result, "APP_NAME=updated") { + t.Error("writeEnvFile() did not update APP_NAME") + } + + // New key should be appended + if !strings.Contains(result, "NEW_KEY=new_value") { + t.Error("writeEnvFile() did not add NEW_KEY") + } + + // Original structure should be maintained + if !strings.Contains(result, "DB_HOST=localhost") { + t.Error("writeEnvFile() lost DB_HOST") + } +} + +func TestLocalEnvFileStore_Save(t *testing.T) { + tmpDir := t.TempDir() + envPath := filepath.Join(tmpDir, ".env") + + store := NewLocalEnvFileStore(envPath) + creds := &AppCredentials{ + AppID: 12345, + AppSlug: "my-app", + ClientID: "Iv1.abc123", + ClientSecret: "secret123", + WebhookSecret: "whsec_123", + PrivateKey: "-----BEGIN RSA PRIVATE KEY-----\nMIIE...\n-----END RSA PRIVATE KEY-----\n", + HTMLURL: "https://github.com/apps/my-app", + CustomFields: map[string]string{ + "STS_DOMAIN": "sts.example.com", + }, + } + + err := store.Save(context.Background(), creds) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Verify file was created with correct permissions + info, err := os.Stat(envPath) + if err != nil { + t.Fatalf("File not created: %v", err) + } + if info.Mode().Perm() != 0600 { + t.Errorf("File permissions = %o, want 0600", info.Mode().Perm()) + } + + // Parse and verify contents + values, _, err := parseEnvFile(envPath) + if err != nil { + t.Fatalf("parseEnvFile() error = %v", err) + } + + checks := map[string]string{ + EnvGitHubAppID: "12345", + EnvGitHubAppSlug: "my-app", + EnvGitHubClientID: "Iv1.abc123", + EnvGitHubClientSecret: "secret123", + EnvGitHubWebhookSecret: "whsec_123", + EnvGitHubAppHTMLURL: "https://github.com/apps/my-app", + "STS_DOMAIN": "sts.example.com", + } + + for key, want := range checks { + got := values[key] + if got != want { + t.Errorf("values[%q] = %q, want %q", key, got, want) + } + } + + // PEM key should have newlines escaped + pemValue := values[EnvGitHubAppPrivateKey] + if strings.Contains(pemValue, "\n") && !strings.Contains(pemValue, "\\n") { + t.Error("Private key should have literal \\n not actual newlines") + } +} + +func TestLocalEnvFileStore_Save_CreatesDirectory(t *testing.T) { + tmpDir := t.TempDir() + envPath := filepath.Join(tmpDir, "nested", "dir", ".env") + + store := NewLocalEnvFileStore(envPath) + creds := &AppCredentials{ + AppID: 1, + ClientID: "client", + ClientSecret: "secret", + WebhookSecret: "webhook", + PrivateKey: "key", + } + + err := store.Save(context.Background(), creds) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + if _, err := os.Stat(envPath); os.IsNotExist(err) { + t.Error("Save() did not create file in nested directory") + } +} + +func TestLocalEnvFileStore_Save_PreservesExistingValues(t *testing.T) { + tmpDir := t.TempDir() + envPath := filepath.Join(tmpDir, ".env") + + // Write initial content + initialContent := `# My app config +EXISTING_KEY=existing_value +PORT=8080 +` + if err := os.WriteFile(envPath, []byte(initialContent), 0600); err != nil { + t.Fatalf("Failed to write initial file: %v", err) + } + + store := NewLocalEnvFileStore(envPath) + creds := &AppCredentials{ + AppID: 1, + ClientID: "client", + ClientSecret: "secret", + WebhookSecret: "webhook", + PrivateKey: "key", + } + + err := store.Save(context.Background(), creds) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + values, _, err := parseEnvFile(envPath) + if err != nil { + t.Fatalf("parseEnvFile() error = %v", err) + } + + // Existing values should be preserved + if values["EXISTING_KEY"] != "existing_value" { + t.Error("Save() overwrote EXISTING_KEY") + } + if values["PORT"] != "8080" { + t.Error("Save() overwrote PORT") + } + + // New values should be present + if values[EnvGitHubAppID] != "1" { + t.Error("Save() did not write GITHUB_APP_ID") + } +} + +func TestLocalEnvFileStore_Status(t *testing.T) { + tmpDir := t.TempDir() + envPath := filepath.Join(tmpDir, ".env") + + content := `GITHUB_APP_ID=12345 +GITHUB_APP_SLUG=my-app +GITHUB_APP_HTML_URL=https://github.com/apps/my-app +GITHUB_APP_PRIVATE_KEY="-----BEGIN RSA-----\nMIIE\n-----END RSA-----" +GITHUB_WEBHOOK_SECRET=whsec_123 +GITHUB_CLIENT_ID=Iv1.abc +GITHUB_CLIENT_SECRET=secret +` + if err := os.WriteFile(envPath, []byte(content), 0600); err != nil { + t.Fatalf("Failed to write file: %v", err) + } + + store := NewLocalEnvFileStore(envPath) + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if !status.Registered { + t.Error("Status.Registered = false, want true") + } + if status.AppID != 12345 { + t.Errorf("Status.AppID = %d, want 12345", status.AppID) + } + if status.AppSlug != "my-app" { + t.Errorf("Status.AppSlug = %q, want %q", status.AppSlug, "my-app") + } + if status.HTMLURL != "https://github.com/apps/my-app" { + t.Errorf("Status.HTMLURL = %q, want %q", status.HTMLURL, "https://github.com/apps/my-app") + } + if status.InstallerDisabled { + t.Error("Status.InstallerDisabled = true, want false") + } +} + +func TestLocalEnvFileStore_Status_NotRegistered(t *testing.T) { + tmpDir := t.TempDir() + envPath := filepath.Join(tmpDir, ".env") + + // Missing required fields + content := `GITHUB_APP_ID=12345 +# Missing other required fields +` + if err := os.WriteFile(envPath, []byte(content), 0600); err != nil { + t.Fatalf("Failed to write file: %v", err) + } + + store := NewLocalEnvFileStore(envPath) + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if status.Registered { + t.Error("Status.Registered = true, want false (missing required fields)") + } +} + +func TestLocalEnvFileStore_Status_FileNotExists(t *testing.T) { + store := NewLocalEnvFileStore("/nonexistent/.env") + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v, want nil for nonexistent file", err) + } + + if status.Registered { + t.Error("Status.Registered = true, want false for nonexistent file") + } +} + +func TestLocalEnvFileStore_Status_InstallerDisabled(t *testing.T) { + tmpDir := t.TempDir() + envPath := filepath.Join(tmpDir, ".env") + + content := `GITHUB_APP_INSTALLER_ENABLED=false +` + if err := os.WriteFile(envPath, []byte(content), 0600); err != nil { + t.Fatalf("Failed to write file: %v", err) + } + + store := NewLocalEnvFileStore(envPath) + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if !status.InstallerDisabled { + t.Error("Status.InstallerDisabled = false, want true") + } +} + +func TestLocalEnvFileStore_DisableInstaller(t *testing.T) { + tmpDir := t.TempDir() + envPath := filepath.Join(tmpDir, ".env") + + store := NewLocalEnvFileStore(envPath) + err := store.DisableInstaller(context.Background()) + if err != nil { + t.Fatalf("DisableInstaller() error = %v", err) + } + + values, _, err := parseEnvFile(envPath) + if err != nil { + t.Fatalf("parseEnvFile() error = %v", err) + } + + if values[EnvGitHubAppInstallerEnabled] != "false" { + t.Errorf("DisableInstaller() did not set %s=false", EnvGitHubAppInstallerEnabled) + } +} + +func TestLocalEnvFileStore_RoundTrip(t *testing.T) { + tmpDir := t.TempDir() + envPath := filepath.Join(tmpDir, ".env") + + store := NewLocalEnvFileStore(envPath) + + // Original credentials with complex PEM key + original := &AppCredentials{ + AppID: 99999, + AppSlug: "test-app", + ClientID: "Iv1.complex123", + ClientSecret: "super-secret-value", + WebhookSecret: "whsec_complex", + PrivateKey: `-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEA0Z3VS5JJcds3xfn/ygWyF8PbnGy0AHB7MhgHW1FZ ++multiline+content+here +-----END RSA PRIVATE KEY----- +`, + HTMLURL: "https://github.com/apps/test-app", + CustomFields: map[string]string{ + "CUSTOM_DOMAIN": "custom.example.com", + "ANOTHER_FIELD": "another-value", + }, + } + + // Save + if err := store.Save(context.Background(), original); err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Verify via Status + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if !status.Registered { + t.Error("Status.Registered = false after Save()") + } + if status.AppID != original.AppID { + t.Errorf("Status.AppID = %d, want %d", status.AppID, original.AppID) + } + if status.AppSlug != original.AppSlug { + t.Errorf("Status.AppSlug = %q, want %q", status.AppSlug, original.AppSlug) + } + if status.HTMLURL != original.HTMLURL { + t.Errorf("Status.HTMLURL = %q, want %q", status.HTMLURL, original.HTMLURL) + } + + // Also verify custom fields were saved + values, _, err := parseEnvFile(envPath) + if err != nil { + t.Fatalf("parseEnvFile() error = %v", err) + } + if values["CUSTOM_DOMAIN"] != "custom.example.com" { + t.Error("Custom field CUSTOM_DOMAIN was not saved") + } + if values["ANOTHER_FIELD"] != "another-value" { + t.Error("Custom field ANOTHER_FIELD was not saved") + } +} diff --git a/configstore/local_file_store.go b/configstore/local_file_store.go new file mode 100644 index 0000000..ca5cc8a --- /dev/null +++ b/configstore/local_file_store.go @@ -0,0 +1,143 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +package configstore + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" +) + +// LocalFileStore saves credentials as individual files in a directory. +type LocalFileStore struct { + Dir string +} + +// NewLocalFileStore creates a store that saves credentials as files in dir. +func NewLocalFileStore(dir string) *LocalFileStore { + return &LocalFileStore{Dir: dir} +} + +// Save writes credentials to individual files in the store directory. +func (s *LocalFileStore) Save(ctx context.Context, creds *AppCredentials) error { + if err := os.MkdirAll(s.Dir, 0700); err != nil { + return fmt.Errorf("failed to create directory %s: %w", s.Dir, err) + } + + files := map[string]struct { + content string + mode os.FileMode + }{ + "app-id": {content: fmt.Sprintf("%d", creds.AppID), mode: 0644}, + "private-key.pem": {content: creds.PrivateKey, mode: 0600}, + "webhook-secret": {content: creds.WebhookSecret, mode: 0600}, + "client-id": {content: creds.ClientID, mode: 0644}, + "client-secret": {content: creds.ClientSecret, mode: 0600}, + } + + if creds.AppSlug != "" { + files["app-slug"] = struct { + content string + mode os.FileMode + }{content: creds.AppSlug, mode: 0644} + } + if creds.HTMLURL != "" { + files["app-html-url"] = struct { + content string + mode os.FileMode + }{content: creds.HTMLURL, mode: 0644} + } + + for key, value := range creds.CustomFields { + if value != "" { + filename := strings.ToLower(strings.ReplaceAll(key, "_", "-")) + files[filename] = struct { + content string + mode os.FileMode + }{content: value, mode: 0644} + } + } + + for name, file := range files { + path := filepath.Join(s.Dir, name) + if err := os.WriteFile(path, []byte(file.content), file.mode); err != nil { + return fmt.Errorf("failed to write %s: %w", path, err) + } + } + + return nil +} + +// Status returns the current registration state by checking required files. +func (s *LocalFileStore) Status(ctx context.Context) (*InstallerStatus, error) { + status := &InstallerStatus{} + + appIDBytes, err := os.ReadFile(filepath.Join(s.Dir, "app-id")) + if err != nil { + if os.IsNotExist(err) { + return status, nil + } + return nil, err + } + + if id, err := strconv.ParseInt(strings.TrimSpace(string(appIDBytes)), 10, 64); err == nil { + status.AppID = id + } + + required := []string{"client-id", "client-secret", "webhook-secret", "private-key.pem"} + for _, name := range required { + if _, err := os.Stat(filepath.Join(s.Dir, name)); err != nil { + if os.IsNotExist(err) { + return status, nil + } + return nil, err + } + } + status.Registered = true + + if slug, err := readTrimmedFile(filepath.Join(s.Dir, "app-slug")); err == nil { + status.AppSlug = slug + } else if !os.IsNotExist(err) { + return nil, err + } + + if html, err := readTrimmedFile(filepath.Join(s.Dir, "app-html-url")); err == nil { + status.HTMLURL = html + } else if !os.IsNotExist(err) { + return nil, err + } + + if _, err := os.Stat(filepath.Join(s.Dir, "installer-disabled")); err == nil { + status.InstallerDisabled = true + } else if !os.IsNotExist(err) { + return nil, err + } + + return status, nil +} + +// DisableInstaller creates a marker file to disable the installer. +func (s *LocalFileStore) DisableInstaller(ctx context.Context) error { + if err := os.MkdirAll(s.Dir, 0700); err != nil { + return fmt.Errorf("failed to create directory %s: %w", s.Dir, err) + } + + path := filepath.Join(s.Dir, "installer-disabled") + if err := os.WriteFile(path, []byte("disabled"), 0600); err != nil { + return fmt.Errorf("failed to write %s: %w", path, err) + } + + return nil +} + +func readTrimmedFile(path string) (string, error) { + data, err := os.ReadFile(path) + if err != nil { + return "", err + } + return strings.TrimSpace(string(data)), nil +} diff --git a/configstore/local_file_store_test.go b/configstore/local_file_store_test.go new file mode 100644 index 0000000..a56840a --- /dev/null +++ b/configstore/local_file_store_test.go @@ -0,0 +1,445 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +package configstore + +import ( + "context" + "os" + "path/filepath" + "testing" +) + +func TestLocalFileStore_Save(t *testing.T) { + tmpDir := t.TempDir() + store := NewLocalFileStore(tmpDir) + + creds := &AppCredentials{ + AppID: 12345, + AppSlug: "my-app", + ClientID: "Iv1.abc123", + ClientSecret: "secret123", + WebhookSecret: "whsec_123", + PrivateKey: "-----BEGIN RSA PRIVATE KEY-----\nMIIE...\n-----END RSA PRIVATE KEY-----\n", + HTMLURL: "https://github.com/apps/my-app", + CustomFields: map[string]string{ + "STS_DOMAIN": "sts.example.com", + "CUSTOM_VALUE": "custom123", + }, + } + + err := store.Save(context.Background(), creds) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Verify core files exist with correct content + checks := map[string]struct { + content string + mode os.FileMode + }{ + "app-id": {content: "12345", mode: 0644}, + "app-slug": {content: "my-app", mode: 0644}, + "client-id": {content: "Iv1.abc123", mode: 0644}, + "client-secret": {content: "secret123", mode: 0600}, + "webhook-secret": {content: "whsec_123", mode: 0600}, + "private-key.pem": {content: "-----BEGIN RSA PRIVATE KEY-----\nMIIE...\n-----END RSA PRIVATE KEY-----\n", mode: 0600}, + "app-html-url": {content: "https://github.com/apps/my-app", mode: 0644}, + "sts-domain": {content: "sts.example.com", mode: 0644}, + "custom-value": {content: "custom123", mode: 0644}, + } + + for name, want := range checks { + path := filepath.Join(tmpDir, name) + + // Check file exists + info, err := os.Stat(path) + if err != nil { + t.Errorf("File %q not created: %v", name, err) + continue + } + + // Check permissions + if info.Mode().Perm() != want.mode { + t.Errorf("File %q permissions = %o, want %o", name, info.Mode().Perm(), want.mode) + } + + // Check content + content, err := os.ReadFile(path) + if err != nil { + t.Errorf("Failed to read %q: %v", name, err) + continue + } + if string(content) != want.content { + t.Errorf("File %q content = %q, want %q", name, string(content), want.content) + } + } +} + +func TestLocalFileStore_Save_CreatesDirectory(t *testing.T) { + tmpDir := t.TempDir() + nestedDir := filepath.Join(tmpDir, "nested", "config", "dir") + + store := NewLocalFileStore(nestedDir) + creds := &AppCredentials{ + AppID: 1, + ClientID: "client", + ClientSecret: "secret", + WebhookSecret: "webhook", + PrivateKey: "key", + } + + err := store.Save(context.Background(), creds) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Directory should be created with 0700 permissions + info, err := os.Stat(nestedDir) + if err != nil { + t.Fatalf("Directory not created: %v", err) + } + if !info.IsDir() { + t.Error("Expected directory, got file") + } + if info.Mode().Perm() != 0700 { + t.Errorf("Directory permissions = %o, want 0700", info.Mode().Perm()) + } +} + +func TestLocalFileStore_Save_OmitsEmptyOptionalFields(t *testing.T) { + tmpDir := t.TempDir() + store := NewLocalFileStore(tmpDir) + + creds := &AppCredentials{ + AppID: 1, + ClientID: "client", + ClientSecret: "secret", + WebhookSecret: "webhook", + PrivateKey: "key", + // AppSlug and HTMLURL are empty + } + + err := store.Save(context.Background(), creds) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Optional files should not exist + for _, name := range []string{"app-slug", "app-html-url"} { + path := filepath.Join(tmpDir, name) + if _, err := os.Stat(path); !os.IsNotExist(err) { + t.Errorf("File %q should not exist for empty value", name) + } + } +} + +func TestLocalFileStore_Save_SkipsEmptyCustomFields(t *testing.T) { + tmpDir := t.TempDir() + store := NewLocalFileStore(tmpDir) + + creds := &AppCredentials{ + AppID: 1, + ClientID: "client", + ClientSecret: "secret", + WebhookSecret: "webhook", + PrivateKey: "key", + CustomFields: map[string]string{ + "EMPTY_FIELD": "", + "VALID_FIELD": "value", + }, + } + + err := store.Save(context.Background(), creds) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Empty custom field should not be created + if _, err := os.Stat(filepath.Join(tmpDir, "empty-field")); !os.IsNotExist(err) { + t.Error("Empty custom field should not create a file") + } + + // Valid custom field should exist + if _, err := os.Stat(filepath.Join(tmpDir, "valid-field")); err != nil { + t.Errorf("Valid custom field not created: %v", err) + } +} + +func TestLocalFileStore_Status_Registered(t *testing.T) { + tmpDir := t.TempDir() + + // Create all required files + files := map[string]string{ + "app-id": "12345", + "app-slug": "test-app", + "app-html-url": "https://github.com/apps/test-app", + "client-id": "client123", + "client-secret": "secret123", + "webhook-secret": "webhook123", + "private-key.pem": "-----BEGIN RSA-----\n...\n-----END RSA-----", + } + for name, content := range files { + if err := os.WriteFile(filepath.Join(tmpDir, name), []byte(content), 0644); err != nil { + t.Fatalf("Failed to create %s: %v", name, err) + } + } + + store := NewLocalFileStore(tmpDir) + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if !status.Registered { + t.Error("Status.Registered = false, want true") + } + if status.AppID != 12345 { + t.Errorf("Status.AppID = %d, want 12345", status.AppID) + } + if status.AppSlug != "test-app" { + t.Errorf("Status.AppSlug = %q, want %q", status.AppSlug, "test-app") + } + if status.HTMLURL != "https://github.com/apps/test-app" { + t.Errorf("Status.HTMLURL = %q, want %q", status.HTMLURL, "https://github.com/apps/test-app") + } + if status.InstallerDisabled { + t.Error("Status.InstallerDisabled = true, want false") + } +} + +func TestLocalFileStore_Status_NotRegistered_MissingAppID(t *testing.T) { + tmpDir := t.TempDir() + + // Directory exists but no app-id file + store := NewLocalFileStore(tmpDir) + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if status.Registered { + t.Error("Status.Registered = true, want false (no app-id)") + } +} + +func TestLocalFileStore_Status_NotRegistered_MissingRequiredFiles(t *testing.T) { + tmpDir := t.TempDir() + + // Create app-id but missing required files + if err := os.WriteFile(filepath.Join(tmpDir, "app-id"), []byte("12345"), 0644); err != nil { + t.Fatalf("Failed to create app-id: %v", err) + } + + store := NewLocalFileStore(tmpDir) + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if status.Registered { + t.Error("Status.Registered = true, want false (missing required files)") + } +} + +func TestLocalFileStore_Status_InstallerDisabled(t *testing.T) { + tmpDir := t.TempDir() + + // Create all required files plus installer-disabled marker + files := map[string]string{ + "app-id": "12345", + "client-id": "client", + "client-secret": "secret", + "webhook-secret": "webhook", + "private-key.pem": "key", + "installer-disabled": "disabled", + } + for name, content := range files { + if err := os.WriteFile(filepath.Join(tmpDir, name), []byte(content), 0644); err != nil { + t.Fatalf("Failed to create %s: %v", name, err) + } + } + + store := NewLocalFileStore(tmpDir) + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if !status.InstallerDisabled { + t.Error("Status.InstallerDisabled = false, want true") + } +} + +func TestLocalFileStore_Status_DirectoryNotExists(t *testing.T) { + store := NewLocalFileStore("/nonexistent/directory") + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v, want nil for nonexistent directory", err) + } + + if status.Registered { + t.Error("Status.Registered = true, want false for nonexistent directory") + } +} + +func TestLocalFileStore_Status_WhitespaceInAppID(t *testing.T) { + tmpDir := t.TempDir() + + // Create app-id with whitespace and newline + files := map[string]string{ + "app-id": " 12345\n", + "client-id": "client", + "client-secret": "secret", + "webhook-secret": "webhook", + "private-key.pem": "key", + } + for name, content := range files { + if err := os.WriteFile(filepath.Join(tmpDir, name), []byte(content), 0644); err != nil { + t.Fatalf("Failed to create %s: %v", name, err) + } + } + + store := NewLocalFileStore(tmpDir) + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if status.AppID != 12345 { + t.Errorf("Status.AppID = %d, want 12345 (whitespace should be trimmed)", status.AppID) + } +} + +func TestLocalFileStore_DisableInstaller(t *testing.T) { + tmpDir := t.TempDir() + store := NewLocalFileStore(tmpDir) + + err := store.DisableInstaller(context.Background()) + if err != nil { + t.Fatalf("DisableInstaller() error = %v", err) + } + + // Check marker file exists + markerPath := filepath.Join(tmpDir, "installer-disabled") + info, err := os.Stat(markerPath) + if err != nil { + t.Fatalf("Marker file not created: %v", err) + } + + // Check permissions (should be 0600 for security) + if info.Mode().Perm() != 0600 { + t.Errorf("Marker file permissions = %o, want 0600", info.Mode().Perm()) + } +} + +func TestLocalFileStore_DisableInstaller_CreatesDirectory(t *testing.T) { + tmpDir := t.TempDir() + nestedDir := filepath.Join(tmpDir, "nested", "dir") + store := NewLocalFileStore(nestedDir) + + err := store.DisableInstaller(context.Background()) + if err != nil { + t.Fatalf("DisableInstaller() error = %v", err) + } + + // Directory should be created + if _, err := os.Stat(nestedDir); err != nil { + t.Errorf("Directory not created: %v", err) + } + + // Marker file should exist + if _, err := os.Stat(filepath.Join(nestedDir, "installer-disabled")); err != nil { + t.Errorf("Marker file not created: %v", err) + } +} + +func TestLocalFileStore_RoundTrip(t *testing.T) { + tmpDir := t.TempDir() + store := NewLocalFileStore(tmpDir) + + original := &AppCredentials{ + AppID: 99999, + AppSlug: "roundtrip-app", + ClientID: "Iv1.roundtrip", + ClientSecret: "roundtrip-secret", + WebhookSecret: "whsec_roundtrip", + PrivateKey: "-----BEGIN RSA PRIVATE KEY-----\nroundtrip-key-content\n-----END RSA PRIVATE KEY-----\n", + HTMLURL: "https://github.com/apps/roundtrip-app", + CustomFields: map[string]string{ + "CUSTOM_FIELD": "custom-value", + }, + } + + // Save + if err := store.Save(context.Background(), original); err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Verify via Status + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if !status.Registered { + t.Error("Status.Registered = false after Save()") + } + if status.AppID != original.AppID { + t.Errorf("Status.AppID = %d, want %d", status.AppID, original.AppID) + } + if status.AppSlug != original.AppSlug { + t.Errorf("Status.AppSlug = %q, want %q", status.AppSlug, original.AppSlug) + } + if status.HTMLURL != original.HTMLURL { + t.Errorf("Status.HTMLURL = %q, want %q", status.HTMLURL, original.HTMLURL) + } + + // Verify custom field was saved + content, err := os.ReadFile(filepath.Join(tmpDir, "custom-field")) + if err != nil { + t.Fatalf("Custom field file not found: %v", err) + } + if string(content) != "custom-value" { + t.Errorf("Custom field content = %q, want %q", string(content), "custom-value") + } +} + +func TestReadTrimmedFile(t *testing.T) { + tmpDir := t.TempDir() + + tests := []struct { + name string + content string + want string + }{ + {"no whitespace", "value", "value"}, + {"trailing newline", "value\n", "value"}, + {"leading and trailing spaces", " value ", "value"}, + {"mixed whitespace", "\t value \n", "value"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + path := filepath.Join(tmpDir, "test-"+tt.name) + if err := os.WriteFile(path, []byte(tt.content), 0644); err != nil { + t.Fatalf("Failed to write file: %v", err) + } + + got, err := readTrimmedFile(path) + if err != nil { + t.Fatalf("readTrimmedFile() error = %v", err) + } + if got != tt.want { + t.Errorf("readTrimmedFile() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestReadTrimmedFile_NotExists(t *testing.T) { + _, err := readTrimmedFile("/nonexistent/file") + if !os.IsNotExist(err) { + t.Errorf("readTrimmedFile() error = %v, want os.IsNotExist", err) + } +} diff --git a/configstore/store.go b/configstore/store.go new file mode 100644 index 0000000..53b1427 --- /dev/null +++ b/configstore/store.go @@ -0,0 +1,162 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +// Package configstore provides storage backends for GitHub App credentials. +// It supports multiple storage backends including AWS SSM Parameter Store, +// local .env files, and individual files. +package configstore + +import ( + "context" + "encoding/json" + "fmt" + "os" + "strings" +) + +const ( + EnvGitHubAppID = "GITHUB_APP_ID" + EnvGitHubAppSlug = "GITHUB_APP_SLUG" + EnvGitHubAppHTMLURL = "GITHUB_APP_HTML_URL" + EnvGitHubAppPrivateKey = "GITHUB_APP_PRIVATE_KEY" + EnvGitHubWebhookSecret = "GITHUB_WEBHOOK_SECRET" + EnvGitHubClientID = "GITHUB_CLIENT_ID" + EnvGitHubClientSecret = "GITHUB_CLIENT_SECRET" +) + +const ( + EnvGitHubAppInstallerEnabled = "GITHUB_APP_INSTALLER_ENABLED" + EnvStorageMode = "STORAGE_MODE" + EnvStorageDir = "STORAGE_DIR" + EnvAWSSSMParameterPfx = "AWS_SSM_PARAMETER_PREFIX" + EnvAWSSSMKMSKeyID = "AWS_SSM_KMS_KEY_ID" + EnvAWSSSMTags = "AWS_SSM_TAGS" +) + +// Storage mode constants for STORAGE_MODE environment variable. +const ( + // StorageModeEnvFile saves credentials to a .env file (default mode). + StorageModeEnvFile = "envfile" + // StorageModeFiles saves credentials as individual files in a directory. + StorageModeFiles = "files" + // StorageModeAWSSSM saves credentials to AWS SSM Parameter Store. + StorageModeAWSSSM = "aws-ssm" +) + +// HookConfig contains webhook configuration returned from GitHub. +type HookConfig struct { + URL string `json:"url"` +} + +// AppCredentials holds credentials returned from GitHub App manifest creation. +type AppCredentials struct { + AppID int64 `json:"id"` + AppSlug string `json:"slug"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + WebhookSecret string `json:"webhook_secret"` + PrivateKey string `json:"pem"` + HTMLURL string `json:"html_url"` + HookConfig HookConfig `json:"hook_config"` + + // CustomFields stores additional app-specific values alongside credentials. + CustomFields map[string]string `json:"-"` +} + +// InstallerStatus describes the current GitHub App registration state. +type InstallerStatus struct { + Registered bool + InstallerDisabled bool + AppID int64 + AppSlug string + HTMLURL string +} + +// Store saves app credentials to various backends (local disk, AWS SSM, etc). +type Store interface { + Save(ctx context.Context, creds *AppCredentials) error + Status(ctx context.Context) (*InstallerStatus, error) + DisableInstaller(ctx context.Context) error +} + +// NewFromEnv creates a Store based on environment variable configuration. +// It reads STORAGE_MODE to determine the backend type: +// - "envfile" (default): saves to a .env file at STORAGE_DIR (default: ./.env) +// - "files": saves to individual files in STORAGE_DIR directory +// - "aws-ssm": saves to AWS SSM Parameter Store with AWS_SSM_PARAMETER_PREFIX +// +// Returns an error if configuration is invalid or store creation fails. +func NewFromEnv() (Store, error) { + mode := GetEnvDefault(EnvStorageMode, StorageModeEnvFile) + + switch mode { + case StorageModeFiles: + dir := GetEnvDefault(EnvStorageDir, "./.env") + return NewLocalFileStore(dir), nil + + case StorageModeEnvFile: + path := GetEnvDefault(EnvStorageDir, "./.env") + return NewLocalEnvFileStore(path), nil + + case StorageModeAWSSSM: + prefix := os.Getenv(EnvAWSSSMParameterPfx) + if prefix == "" { + return nil, fmt.Errorf("%s is required when using %s storage mode", EnvAWSSSMParameterPfx, StorageModeAWSSSM) + } + + var opts []SSMStoreOption + + if kmsKeyID := os.Getenv(EnvAWSSSMKMSKeyID); kmsKeyID != "" { + opts = append(opts, WithKMSKey(kmsKeyID)) + } + + if tagsJSON := os.Getenv(EnvAWSSSMTags); tagsJSON != "" { + var tags map[string]string + if err := json.Unmarshal([]byte(tagsJSON), &tags); err != nil { + return nil, fmt.Errorf("failed to parse %s as JSON: %w", EnvAWSSSMTags, err) + } + opts = append(opts, WithTags(tags)) + } + + return NewAWSSSMStore(prefix, opts...) + + default: + return nil, fmt.Errorf("unknown %s: %s (expected '%s', '%s', or '%s')", + EnvStorageMode, mode, StorageModeEnvFile, StorageModeFiles, StorageModeAWSSSM) + } +} + +// InstallerEnabled returns true if the installer is enabled via environment variable. +func InstallerEnabled() bool { + v := strings.ToLower(os.Getenv(EnvGitHubAppInstallerEnabled)) + return v == "true" || v == "1" || v == "yes" +} + +// GetEnvDefault returns an env var value, or defaultValue if not set or empty. +func GetEnvDefault(key, defaultValue string) string { + if v := os.Getenv(key); v != "" { + return v + } + return defaultValue +} + +func hasAllValues(values map[string]string, keys ...string) bool { + if len(values) == 0 { + return false + } + for _, key := range keys { + if strings.TrimSpace(values[key]) == "" { + return false + } + } + return true +} + +func isFalseString(v string) bool { + switch strings.ToLower(strings.TrimSpace(v)) { + case "false", "0", "no", "off": + return true + default: + return false + } +} diff --git a/configstore/store_test.go b/configstore/store_test.go new file mode 100644 index 0000000..e5c4373 --- /dev/null +++ b/configstore/store_test.go @@ -0,0 +1,312 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +package configstore + +import ( + "os" + "testing" +) + +func TestInstallerEnabled(t *testing.T) { + tests := []struct { + name string + value string + want bool + }{ + {"true lowercase", "true", true}, + {"TRUE uppercase", "TRUE", true}, + {"True mixed", "True", true}, + {"1", "1", true}, + {"yes", "yes", true}, + {"YES uppercase", "YES", true}, + {"false", "false", false}, + {"FALSE", "FALSE", false}, + {"0", "0", false}, + {"no", "no", false}, + {"empty string", "", false}, + {"random string", "enabled", false}, + {"on (not supported)", "on", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + os.Setenv(EnvGitHubAppInstallerEnabled, tt.value) + defer os.Unsetenv(EnvGitHubAppInstallerEnabled) + + got := InstallerEnabled() + if got != tt.want { + t.Errorf("InstallerEnabled() with %q = %v, want %v", tt.value, got, tt.want) + } + }) + } +} + +func TestHasAllValues(t *testing.T) { + tests := []struct { + name string + values map[string]string + keys []string + want bool + }{ + { + name: "all keys present with values", + values: map[string]string{"a": "1", "b": "2", "c": "3"}, + keys: []string{"a", "b"}, + want: true, + }, + { + name: "missing key", + values: map[string]string{"a": "1"}, + keys: []string{"a", "b"}, + want: false, + }, + { + name: "empty value fails", + values: map[string]string{"a": "1", "b": ""}, + keys: []string{"a", "b"}, + want: false, + }, + { + name: "whitespace-only value fails", + values: map[string]string{"a": "1", "b": " "}, + keys: []string{"a", "b"}, + want: false, + }, + { + name: "nil map", + values: nil, + keys: []string{"a"}, + want: false, + }, + { + name: "empty map", + values: map[string]string{}, + keys: []string{"a"}, + want: false, + }, + { + name: "no keys required", + values: map[string]string{"a": "1"}, + keys: []string{}, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := hasAllValues(tt.values, tt.keys...) + if got != tt.want { + t.Errorf("hasAllValues() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsFalseString(t *testing.T) { + tests := []struct { + value string + want bool + }{ + {"false", true}, + {"FALSE", true}, + {"False", true}, + {"0", true}, + {"no", true}, + {"NO", true}, + {"off", true}, + {"OFF", true}, + {" false ", true}, // with whitespace + {"true", false}, + {"1", false}, + {"yes", false}, + {"on", false}, + {"", false}, + {"random", false}, + } + + for _, tt := range tests { + t.Run(tt.value, func(t *testing.T) { + got := isFalseString(tt.value) + if got != tt.want { + t.Errorf("isFalseString(%q) = %v, want %v", tt.value, got, tt.want) + } + }) + } +} + +func TestGetEnvDefault(t *testing.T) { + tests := []struct { + name string + envKey string + envValue string + setEnv bool + defaultValue string + want string + }{ + { + name: "env set returns env value", + envKey: "TEST_VAR", + envValue: "custom_value", + setEnv: true, + defaultValue: "default", + want: "custom_value", + }, + { + name: "env not set returns default", + envKey: "TEST_VAR_UNSET", + envValue: "", + setEnv: false, + defaultValue: "default_value", + want: "default_value", + }, + { + name: "empty env returns default", + envKey: "TEST_VAR_EMPTY", + envValue: "", + setEnv: true, + defaultValue: "default", + want: "default", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + os.Unsetenv(tt.envKey) + if tt.setEnv { + os.Setenv(tt.envKey, tt.envValue) + defer os.Unsetenv(tt.envKey) + } + + got := GetEnvDefault(tt.envKey, tt.defaultValue) + if got != tt.want { + t.Errorf("GetEnvDefault(%q, %q) = %q, want %q", tt.envKey, tt.defaultValue, got, tt.want) + } + }) + } +} + +func TestNewFromEnv_StorageModes(t *testing.T) { + t.Run("default mode creates LocalEnvFileStore", func(t *testing.T) { + os.Unsetenv(EnvStorageMode) + defer os.Unsetenv(EnvStorageMode) + + store, err := NewFromEnv() + if err != nil { + t.Fatalf("NewFromEnv() error = %v", err) + } + + if _, ok := store.(*LocalEnvFileStore); !ok { + t.Errorf("NewFromEnv() returned %T, want *LocalEnvFileStore", store) + } + }) + + t.Run("envfile mode creates LocalEnvFileStore", func(t *testing.T) { + os.Setenv(EnvStorageMode, StorageModeEnvFile) + defer os.Unsetenv(EnvStorageMode) + + store, err := NewFromEnv() + if err != nil { + t.Fatalf("NewFromEnv() error = %v", err) + } + + if _, ok := store.(*LocalEnvFileStore); !ok { + t.Errorf("NewFromEnv() returned %T, want *LocalEnvFileStore", store) + } + }) + + t.Run("files mode creates LocalFileStore", func(t *testing.T) { + os.Setenv(EnvStorageMode, StorageModeFiles) + defer os.Unsetenv(EnvStorageMode) + + store, err := NewFromEnv() + if err != nil { + t.Fatalf("NewFromEnv() error = %v", err) + } + + if _, ok := store.(*LocalFileStore); !ok { + t.Errorf("NewFromEnv() returned %T, want *LocalFileStore", store) + } + }) + + t.Run("aws-ssm mode requires prefix", func(t *testing.T) { + os.Setenv(EnvStorageMode, StorageModeAWSSSM) + os.Unsetenv(EnvAWSSSMParameterPfx) + defer os.Unsetenv(EnvStorageMode) + + _, err := NewFromEnv() + if err == nil { + t.Error("NewFromEnv() with aws-ssm and no prefix should return error") + } + }) + + t.Run("unknown mode returns error", func(t *testing.T) { + os.Setenv(EnvStorageMode, "invalid-mode") + defer os.Unsetenv(EnvStorageMode) + + _, err := NewFromEnv() + if err == nil { + t.Error("NewFromEnv() with unknown mode should return error") + } + }) +} + +func TestNewFromEnv_CustomStorageDir(t *testing.T) { + t.Run("envfile mode uses STORAGE_DIR", func(t *testing.T) { + os.Setenv(EnvStorageMode, StorageModeEnvFile) + os.Setenv(EnvStorageDir, "/custom/path/.env") + defer os.Unsetenv(EnvStorageMode) + defer os.Unsetenv(EnvStorageDir) + + store, err := NewFromEnv() + if err != nil { + t.Fatalf("NewFromEnv() error = %v", err) + } + + envStore, ok := store.(*LocalEnvFileStore) + if !ok { + t.Fatalf("NewFromEnv() returned %T, want *LocalEnvFileStore", store) + } + + if envStore.FilePath != "/custom/path/.env" { + t.Errorf("FilePath = %q, want %q", envStore.FilePath, "/custom/path/.env") + } + }) + + t.Run("files mode uses STORAGE_DIR", func(t *testing.T) { + os.Setenv(EnvStorageMode, StorageModeFiles) + os.Setenv(EnvStorageDir, "/custom/dir") + defer os.Unsetenv(EnvStorageMode) + defer os.Unsetenv(EnvStorageDir) + + store, err := NewFromEnv() + if err != nil { + t.Fatalf("NewFromEnv() error = %v", err) + } + + fileStore, ok := store.(*LocalFileStore) + if !ok { + t.Fatalf("NewFromEnv() returned %T, want *LocalFileStore", store) + } + + if fileStore.Dir != "/custom/dir" { + t.Errorf("Dir = %q, want %q", fileStore.Dir, "/custom/dir") + } + }) +} + +func TestNewFromEnv_AWSSSMTags(t *testing.T) { + t.Run("invalid JSON tags returns error", func(t *testing.T) { + os.Setenv(EnvStorageMode, StorageModeAWSSSM) + os.Setenv(EnvAWSSSMParameterPfx, "/test/prefix") + os.Setenv(EnvAWSSSMTags, "not valid json") + defer os.Unsetenv(EnvStorageMode) + defer os.Unsetenv(EnvAWSSSMParameterPfx) + defer os.Unsetenv(EnvAWSSSMTags) + + _, err := NewFromEnv() + if err == nil { + t.Error("NewFromEnv() with invalid JSON tags should return error") + } + }) +} diff --git a/configwait/configwait.go b/configwait/configwait.go new file mode 100644 index 0000000..cb5e6ab --- /dev/null +++ b/configwait/configwait.go @@ -0,0 +1,203 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +// Package configwait provides utilities for waiting on configuration availability +// during startup. +package configwait + +import ( + "context" + "encoding/json" + "net/http" + "os" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/chainguard-dev/clog" +) + +const ( + EnvMaxRetries = "CONFIG_WAIT_MAX_RETRIES" + EnvRetryInterval = "CONFIG_WAIT_RETRY_INTERVAL" +) + +const ( + DefaultMaxRetries = 30 + DefaultRetryInterval = 2 * time.Second +) + +// Config configures the wait behavior. +type Config struct { + MaxRetries int + RetryInterval time.Duration +} + +// NewConfigFromEnv creates a Config from environment variables. +func NewConfigFromEnv() Config { + cfg := Config{ + MaxRetries: DefaultMaxRetries, + RetryInterval: DefaultRetryInterval, + } + + if v := os.Getenv(EnvMaxRetries); v != "" { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + cfg.MaxRetries = n + } + } + + if v := os.Getenv(EnvRetryInterval); v != "" { + if d, err := time.ParseDuration(v); err == nil && d > 0 { + cfg.RetryInterval = d + } + } + + return cfg +} + +// LoadFunc attempts to load configuration; returns nil on success. +type LoadFunc func(ctx context.Context) error + +// Wait blocks until load succeeds or max retries is reached. +func Wait(ctx context.Context, cfg Config, load LoadFunc) error { + log := clog.FromContext(ctx) + var lastErr error + + for attempt := 1; attempt <= cfg.MaxRetries; attempt++ { + if err := load(ctx); err != nil { + lastErr = err + log.Warnf("[configwait] attempt %d/%d failed: %v", attempt, cfg.MaxRetries, err) + + if attempt < cfg.MaxRetries { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(cfg.RetryInterval): + + } + } + } else { + if attempt > 1 { + log.Infof("[configwait] configuration loaded successfully after %d attempts", attempt) + } + return nil + } + } + + return lastErr +} + +// ReadyGate gates HTTP requests until the service is ready. +type ReadyGate struct { + inner http.Handler + allowedPaths []string + ready atomic.Bool + handler atomic.Value // stores http.Handler once ready + + mu sync.Mutex + handlerReady chan struct{} +} + +// NewReadyGate creates a ReadyGate wrapping the given handler. +// The allowedPaths are path prefixes always allowed through (e.g., "/setup"). +// The inner handler can be nil initially; call SetHandler() once ready. +func NewReadyGate(inner http.Handler, allowedPaths []string) *ReadyGate { + rg := &ReadyGate{ + inner: inner, + allowedPaths: allowedPaths, + handlerReady: make(chan struct{}), + } + if inner != nil { + rg.handler.Store(inner) + } + return rg +} + +// SetReady marks the service as ready to handle all requests. +func (rg *ReadyGate) SetReady() { + rg.ready.Store(true) +} + +// SetHandler sets the main handler to use once ready. +func (rg *ReadyGate) SetHandler(h http.Handler) { + rg.handler.Store(h) + rg.mu.Lock() + defer rg.mu.Unlock() + select { + case <-rg.handlerReady: + default: + close(rg.handlerReady) + } +} + +// IsReady returns true if the service is ready. +func (rg *ReadyGate) IsReady() bool { + return rg.ready.Load() +} + +// ServeHTTP implements http.Handler. +func (rg *ReadyGate) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if rg.isAllowedPath(r.URL.Path) { + h := rg.getHandler() + if h != nil { + h.ServeHTTP(w, r) + return + } + rg.serveUnavailable(w, r, "service starting up") + return + } + + if !rg.ready.Load() { + rg.serveUnavailable(w, r, "service not ready, configuration loading") + return + } + + h := rg.getHandler() + if h == nil { + rg.serveUnavailable(w, r, "service starting up") + return + } + h.ServeHTTP(w, r) +} + +// isAllowedPath checks if the path matches any allowed path prefix. +func (rg *ReadyGate) isAllowedPath(path string) bool { + for _, allowed := range rg.allowedPaths { + if allowed == "/" { + if path == "/" { + return true + } + continue + } + if strings.HasPrefix(path, allowed) { + return true + } + } + return false +} + +// getHandler returns the current handler. +func (rg *ReadyGate) getHandler() http.Handler { + h := rg.handler.Load() + if h == nil { + return nil + } + return h.(http.Handler) +} + +// serveUnavailable writes a 503 response. +func (rg *ReadyGate) serveUnavailable(w http.ResponseWriter, r *http.Request, message string) { + log := clog.FromContext(r.Context()) + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Retry-After", "5") + w.WriteHeader(http.StatusServiceUnavailable) + if err := json.NewEncoder(w).Encode(map[string]string{ + "error": "service_unavailable", + "message": message, + }); err != nil { + log.Errorf("[configwait] failed to write unavailable response: %v", err) + } +} diff --git a/configwait/configwait_test.go b/configwait/configwait_test.go new file mode 100644 index 0000000..852c948 --- /dev/null +++ b/configwait/configwait_test.go @@ -0,0 +1,398 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +package configwait + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" +) + +func TestWait_ImmediateSuccess(t *testing.T) { + ctx := context.Background() + cfg := Config{ + MaxRetries: 3, + RetryInterval: 10 * time.Millisecond, + } + + callCount := 0 + err := Wait(ctx, cfg, func(ctx context.Context) error { + callCount++ + return nil + }) + + if err != nil { + t.Errorf("Wait() error = %v, want nil", err) + } + if callCount != 1 { + t.Errorf("Load function called %d times, want 1", callCount) + } +} + +func TestWait_RetryThenSuccess(t *testing.T) { + ctx := context.Background() + cfg := Config{ + MaxRetries: 5, + RetryInterval: 10 * time.Millisecond, + } + + callCount := 0 + err := Wait(ctx, cfg, func(ctx context.Context) error { + callCount++ + if callCount < 3 { + return errors.New("not ready") + } + return nil + }) + + if err != nil { + t.Errorf("Wait() error = %v, want nil", err) + } + if callCount != 3 { + t.Errorf("Load function called %d times, want 3", callCount) + } +} + +func TestWait_MaxRetriesExceeded(t *testing.T) { + ctx := context.Background() + cfg := Config{ + MaxRetries: 3, + RetryInterval: 10 * time.Millisecond, + } + + callCount := 0 + expectedErr := errors.New("always fail") + err := Wait(ctx, cfg, func(ctx context.Context) error { + callCount++ + return expectedErr + }) + + if err != expectedErr { + t.Errorf("Wait() error = %v, want %v", err, expectedErr) + } + if callCount != 3 { + t.Errorf("Load function called %d times, want 3", callCount) + } +} + +func TestWait_ContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cfg := Config{ + MaxRetries: 100, + RetryInterval: 100 * time.Millisecond, + } + + callCount := 0 + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + err := Wait(ctx, cfg, func(ctx context.Context) error { + callCount++ + return errors.New("not ready") + }) + + if err != context.Canceled { + t.Errorf("Wait() error = %v, want %v", err, context.Canceled) + } + // Should have been cancelled after 1-2 attempts + if callCount > 2 { + t.Errorf("Load function called %d times, expected <= 2", callCount) + } +} + +func TestReadyGate_NotReadyReturns503(t *testing.T) { + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) + }) + + gate := NewReadyGate(inner, nil) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + + gate.ServeHTTP(rec, req) + + if rec.Code != http.StatusServiceUnavailable { + t.Errorf("Status = %d, want %d", rec.Code, http.StatusServiceUnavailable) + } +} + +func TestReadyGate_ReadyPassesThrough(t *testing.T) { + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) + }) + + gate := NewReadyGate(inner, nil) + gate.SetReady() + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + + gate.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK) + } + if rec.Body.String() != "ok" { + t.Errorf("Body = %q, want %q", rec.Body.String(), "ok") + } +} + +func TestReadyGate_AllowedPathsPassThrough(t *testing.T) { + var called atomic.Bool + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called.Store(true) + w.WriteHeader(http.StatusOK) + w.Write([]byte("setup page")) + }) + + gate := NewReadyGate(inner, []string{"/setup", "/healthz"}) + // Not ready, but /setup should pass through + + tests := []struct { + path string + wantStatus int + wantBody string + }{ + {"/setup", http.StatusOK, "setup page"}, + {"/setup/callback", http.StatusOK, "setup page"}, + {"/healthz", http.StatusOK, "setup page"}, + {"/other", http.StatusServiceUnavailable, ""}, + {"/webhook", http.StatusServiceUnavailable, ""}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + called.Store(false) + req := httptest.NewRequest(http.MethodGet, tt.path, nil) + rec := httptest.NewRecorder() + + gate.ServeHTTP(rec, req) + + if rec.Code != tt.wantStatus { + t.Errorf("Status = %d, want %d", rec.Code, tt.wantStatus) + } + + if tt.wantStatus == http.StatusOK && !called.Load() { + t.Error("Inner handler was not called for allowed path") + } + }) + } +} + +func TestReadyGate_IsReady(t *testing.T) { + gate := NewReadyGate(nil, nil) + + if gate.IsReady() { + t.Error("IsReady() = true, want false initially") + } + + gate.SetReady() + + if !gate.IsReady() { + t.Error("IsReady() = false, want true after SetReady()") + } +} + +func TestReadyGate_SetHandler(t *testing.T) { + gate := NewReadyGate(nil, nil) + + // Initially no handler + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + gate.ServeHTTP(rec, req) + + if rec.Code != http.StatusServiceUnavailable { + t.Errorf("Status = %d, want %d when not ready", rec.Code, http.StatusServiceUnavailable) + } + + // Set handler and mark ready + gate.SetHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("dynamic handler")) + })) + gate.SetReady() + + req = httptest.NewRequest(http.MethodGet, "/test", nil) + rec = httptest.NewRecorder() + gate.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("Status = %d, want %d after SetReady()", rec.Code, http.StatusOK) + } + if rec.Body.String() != "dynamic handler" { + t.Errorf("Body = %q, want %q", rec.Body.String(), "dynamic handler") + } +} + +func TestNewConfigFromEnv_Defaults(t *testing.T) { + cfg := NewConfigFromEnv() + + if cfg.MaxRetries != DefaultMaxRetries { + t.Errorf("MaxRetries = %d, want %d", cfg.MaxRetries, DefaultMaxRetries) + } + if cfg.RetryInterval != DefaultRetryInterval { + t.Errorf("RetryInterval = %v, want %v", cfg.RetryInterval, DefaultRetryInterval) + } +} + +func TestReloader_Trigger(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + gate := NewReadyGate(nil, nil) + + var reloadCount atomic.Int32 + reloadFunc := func(ctx context.Context) error { + reloadCount.Add(1) + return nil + } + + reloader := NewReloader(ctx, gate, reloadFunc) + reloader.Start() + + // Trigger a reload + reloader.Trigger() + + // Wait for reload to complete + time.Sleep(50 * time.Millisecond) + + if got := reloadCount.Load(); got != 1 { + t.Errorf("Reload count = %d, want 1", got) + } +} + +func TestReloader_MultipleTriggers(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + gate := NewReadyGate(nil, nil) + + var reloadCount atomic.Int32 + reloadFunc := func(ctx context.Context) error { + reloadCount.Add(1) + // Simulate some work + time.Sleep(20 * time.Millisecond) + return nil + } + + reloader := NewReloader(ctx, gate, reloadFunc) + reloader.Start() + + // Trigger multiple reloads rapidly + reloader.Trigger() + reloader.Trigger() + reloader.Trigger() + + // Wait for reloads to complete + time.Sleep(100 * time.Millisecond) + + // Only one or two reloads should have occurred due to deduplication + got := reloadCount.Load() + if got > 2 { + t.Errorf("Reload count = %d, want <= 2 (due to deduplication)", got) + } + if got < 1 { + t.Errorf("Reload count = %d, want >= 1", got) + } +} + +func TestReloader_ReloadError(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + gate := NewReadyGate(nil, nil) + + var reloadCount atomic.Int32 + reloadFunc := func(ctx context.Context) error { + reloadCount.Add(1) + return errors.New("reload failed") + } + + reloader := NewReloader(ctx, gate, reloadFunc) + reloader.Start() + + // Trigger a reload + reloader.Trigger() + + // Wait for reload to complete + time.Sleep(50 * time.Millisecond) + + // Reload should have been attempted + if got := reloadCount.Load(); got != 1 { + t.Errorf("Reload count = %d, want 1", got) + } +} + +func TestReloader_ContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + gate := NewReadyGate(nil, nil) + + var reloadCount atomic.Int32 + reloadFunc := func(ctx context.Context) error { + reloadCount.Add(1) + return nil + } + + reloader := NewReloader(ctx, gate, reloadFunc) + done := reloader.Start() + + // Cancel context + cancel() + + // Wait for reloader to stop + select { + case <-done: + // Good, reloader stopped + case <-time.After(100 * time.Millisecond): + t.Error("Reloader did not stop after context cancellation") + } +} + +func TestGlobalReloader(t *testing.T) { + // Clear any existing global reloader + SetGlobalReloader(nil) + + // TriggerReload should be a no-op when no global reloader is set + TriggerReload() // Should not panic + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + gate := NewReadyGate(nil, nil) + + var reloadCount atomic.Int32 + reloadFunc := func(ctx context.Context) error { + reloadCount.Add(1) + return nil + } + + reloader := NewReloader(ctx, gate, reloadFunc) + reloader.Start() + + // Set global reloader + SetGlobalReloader(reloader) + + // Now TriggerReload should work + TriggerReload() + + // Wait for reload to complete + time.Sleep(50 * time.Millisecond) + + if got := reloadCount.Load(); got != 1 { + t.Errorf("Reload count = %d, want 1", got) + } + + // Clean up + SetGlobalReloader(nil) +} diff --git a/configwait/reloader.go b/configwait/reloader.go new file mode 100644 index 0000000..a1541fd --- /dev/null +++ b/configwait/reloader.go @@ -0,0 +1,152 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +package configwait + +import ( + "context" + "os" + "os/signal" + "sync" + "syscall" + + "github.com/chainguard-dev/clog" +) + +// ReloadFunc is called when a reload is triggered. +type ReloadFunc func(ctx context.Context) error + +// Reloader manages configuration reloading via SIGHUP or programmatic triggers. +type Reloader struct { + gate *ReadyGate + reloadFunc ReloadFunc + ctx context.Context + + mu sync.Mutex + reloading bool + reloadCh chan struct{} +} + +// NewReloader creates a Reloader that calls reloadFunc when triggered. +func NewReloader(ctx context.Context, gate *ReadyGate, reloadFunc ReloadFunc) *Reloader { + return &Reloader{ + gate: gate, + reloadFunc: reloadFunc, + ctx: ctx, + reloadCh: make(chan struct{}, 1), + } +} + +// Start begins listening for SIGHUP signals and programmatic triggers. +// Returns a channel that closes when the reloader stops. +func (r *Reloader) Start() <-chan struct{} { + done := make(chan struct{}) + log := clog.FromContext(r.ctx) + + sighupCh := make(chan os.Signal, 1) + signal.Notify(sighupCh, syscall.SIGHUP) + + go func() { + defer close(done) + defer signal.Stop(sighupCh) + + for { + select { + case <-r.ctx.Done(): + return + case <-sighupCh: + log.Infof("[reloader] received SIGHUP, triggering reload") + r.doReload() + case <-r.reloadCh: + log.Infof("[reloader] programmatic reload triggered") + r.doReload() + } + } + }() + + return done +} + +// Trigger requests a configuration reload. Safe to call from any goroutine. +func (r *Reloader) Trigger() { + log := clog.FromContext(r.ctx) + + select { + case r.reloadCh <- struct{}{}: + default: + log.Infof("[reloader] reload already pending, ignoring trigger") + } +} + +// doReload performs the reload operation. +func (r *Reloader) doReload() { + log := clog.FromContext(r.ctx) + + r.mu.Lock() + if r.reloading { + r.mu.Unlock() + log.Infof("[reloader] reload already in progress, skipping") + return + } + r.reloading = true + r.mu.Unlock() + + defer func() { + r.mu.Lock() + r.reloading = false + r.mu.Unlock() + }() + + log.Infof("[reloader] starting configuration reload...") + + if err := r.reloadFunc(r.ctx); err != nil { + log.Errorf("[reloader] reload failed: %v", err) + return + } + + log.Infof("[reloader] configuration reloaded successfully") +} + +var ( + globalReloaderMu sync.RWMutex + globalReloader *Reloader + + reloadCounter int64 + reloadCounterMu sync.Mutex +) + +// SetGlobalReloader sets the global reloader instance. +func SetGlobalReloader(r *Reloader) { + globalReloaderMu.Lock() + defer globalReloaderMu.Unlock() + globalReloader = r +} + +// TriggerReload triggers a reload using the global reloader (no-op if unset). +func TriggerReload() { + reloadCounterMu.Lock() + reloadCounter++ + reloadCounterMu.Unlock() + + globalReloaderMu.RLock() + r := globalReloader + globalReloaderMu.RUnlock() + + if r != nil { + r.Trigger() + } +} + +// GetReloadCount returns the number of times TriggerReload has been called. +func GetReloadCount() int64 { + reloadCounterMu.Lock() + defer reloadCounterMu.Unlock() + return reloadCounter +} + +// ResetReloadCounter resets the reload counter to zero. +func ResetReloadCounter() { + reloadCounterMu.Lock() + defer reloadCounterMu.Unlock() + reloadCounter = 0 +} diff --git a/examples/simple/.env.example b/examples/simple/.env.example new file mode 100644 index 0000000..880339c --- /dev/null +++ b/examples/simple/.env.example @@ -0,0 +1,52 @@ +# Simple GitHub App Example - Configuration +# Copy this file to .env and customize as needed + +# ------------------------------------------------------------------------------ +# Logging +# ------------------------------------------------------------------------------ + +# Log output format: "text" (default) or "json" +LOG_FORMAT=text + +# ------------------------------------------------------------------------------ +# GitHub App Installer +# ------------------------------------------------------------------------------ + +# Enable the web-based installer UI at /setup +# Set to "false" after creating your GitHub App +GITHUB_APP_INSTALLER_ENABLED=true + +# GitHub base URL (change for GitHub Enterprise Server) +GITHUB_URL=https://github.com + +# Organization to create the app under (leave empty for personal account) +GITHUB_ORG= + +# ------------------------------------------------------------------------------ +# Server +# ------------------------------------------------------------------------------ + +# HTTP port to listen on +PORT=8080 + +# ------------------------------------------------------------------------------ +# Config Wait (startup behavior) +# ------------------------------------------------------------------------------ + +# Maximum retry attempts when waiting for configuration +CONFIG_WAIT_MAX_RETRIES=30 + +# Duration between retry attempts +CONFIG_WAIT_RETRY_INTERVAL=2s + +# ------------------------------------------------------------------------------ +# GitHub App Credentials (auto-populated by installer) +# These are filled in automatically when you create the app via /setup +# ------------------------------------------------------------------------------ + +# GITHUB_APP_ID= +# GITHUB_APP_SLUG= +# GITHUB_WEBHOOK_SECRET= +# GITHUB_CLIENT_ID= +# GITHUB_CLIENT_SECRET= +# GITHUB_APP_PRIVATE_KEY= diff --git a/examples/simple/Dockerfile b/examples/simple/Dockerfile new file mode 100644 index 0000000..0898fc9 --- /dev/null +++ b/examples/simple/Dockerfile @@ -0,0 +1,44 @@ +# syntax=docker/dockerfile:1 + +# ------------------------------------------------------------------ builder --- +FROM golang:1.25-alpine AS builder + +RUN apk add --no-cache git ca-certificates + +WORKDIR /build + +# Copy go.mod files first for better caching +COPY go.mod go.sum ./ +RUN go mod download + +# Copy source code +COPY . . + +# Build the example binary +WORKDIR /build/examples/simple +RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o /out/app . + +# ------------------------------------------------------------------ runtime --- +FROM alpine:3.21 + +RUN apk add --no-cache ca-certificates + +# Create non-root user +RUN addgroup -g 1000 app && \ + adduser -u 1000 -G app -s /bin/sh -D app + +# Create data directory for .env file storage +RUN mkdir -p /data && chown app:app /data + +COPY --from=builder /out/app /usr/local/bin/app + +USER app +WORKDIR /data + +EXPOSE 8080 + +ENV PORT=8080 +ENV STORAGE_MODE=envfile +ENV STORAGE_DIR=/data/.env + +CMD ["app"] diff --git a/examples/simple/README.md b/examples/simple/README.md new file mode 100644 index 0000000..0cf8995 --- /dev/null +++ b/examples/simple/README.md @@ -0,0 +1,163 @@ +# Simple GitHub App Example + +A minimal example demonstrating a GitHub App with webhook handling using Docker +and Docker Compose. + +## Features + +- Web-based GitHub App installer at `/setup` +- Webhook endpoint at `/webhook` that logs received events +- Configurable logging format (text or JSON via slog) +- Credentials stored in `.env` file (envfile storage backend) +- Health check endpoint at `/healthz` + +## Prerequisites + +- Docker Engine 24.0+ and Docker Compose v2.20+ +- [ngrok](https://ngrok.com/) or similar tunnel for webhook delivery +- GitHub account with permission to create GitHub Apps + +## Quick Start + +### 1. Start the Application + +```bash +cd examples/simple +docker compose up --build +``` + +The app starts with the installer enabled by default. + +### 2. Expose the Application + +GitHub needs to reach your webhook endpoint. Use ngrok or a similar tunnel: + +```bash +ngrok http 8080 +``` + +Copy the HTTPS forwarding URL (e.g., `https://abc123.ngrok-free.app`). + +### 3. Create the GitHub App + +1. Open your ngrok URL with `/setup` path in a browser: + ``` + https://abc123.ngrok-free.app/setup + ``` + +2. Enter your desired app name + +3. Update the webhook URL to your ngrok URL + `/webhook`: + ``` + https://abc123.ngrok-free.app/webhook + ``` + +4. Click "Create GitHub App" and authorize the app creation + +5. The installer saves credentials automatically and the app reloads + +### 4. Install the App + +After creation, click "Install App" to grant the app access to your +repositories. Select which repositories should send webhook events. + +### 5. Test Webhooks + +Push a commit or create a pull request in an installed repository. You should +see log output like: + +``` +level=INFO msg="received webhook" event=push action="" delivery_id=abc123 repository=owner/repo sender=username payload_size=1234 +``` + +## Configuration + +Copy `.env.example` to `.env` to customize settings: + +```bash +cp .env.example .env +``` + +### Environment Variables + +| Variable | Description | Default | +|--------------------------------|------------------------------|---------------------| +| `LOG_FORMAT` | Log format: `text` or `json` | `text` | +| `PORT` | HTTP port | `8080` | +| `GITHUB_APP_INSTALLER_ENABLED` | Enable installer UI | `true` | +| `GITHUB_URL` | GitHub base URL | `https://github.com`| +| `GITHUB_ORG` | Organization for app | - | + +### Disabling the Installer + +After setup, disable the installer for security: + +1. Click "Disable Setup & Continue" in the success page, or +2. Set `GITHUB_APP_INSTALLER_ENABLED=false` in `.env` and restart + +## Endpoints + +| Path | Description | +|------------|--------------------------------| +| `/setup` | GitHub App installer (enabled) | +| `/webhook` | Webhook receiver | +| `/healthz` | Health check | + +## Logs + +The application uses Go's `slog` package with configurable output format. + +**Text format** (default): +``` +level=INFO msg="received webhook" event=push repository=owner/repo +``` + +**JSON format** (`LOG_FORMAT=json`): +```json +{"level":"INFO","msg":"received webhook","event":"push","repository":"owner/repo"} +``` + +## Architecture + +``` + +-------------------+ + | ngrok | + | (public HTTPS) | + +---------+---------+ + | + v ++-------------------------------+-------------------------------+ +| App Container | +| | +| /setup --> Installer (creates GitHub App) | +| /webhook --> Webhook Handler (logs events) | +| /healthz --> Health Check | +| | +| Storage: /data/.env (persisted via Docker volume) | ++---------------------------------------------------------------+ +``` + +## Troubleshooting + +### App not receiving webhooks + +1. Verify ngrok is running and the URL hasn't changed +2. Check webhook URL in GitHub App settings matches your ngrok URL + `/webhook` +3. View webhook deliveries in GitHub App settings for error details + +### Configuration not loading + +Check logs for retry messages: +```bash +docker compose logs -f +``` + +The app retries configuration loading for 60 seconds by default. + +### Resetting the app + +To start fresh, remove the Docker volume: +```bash +docker compose down -v +docker compose up --build +``` diff --git a/examples/simple/compose.yaml b/examples/simple/compose.yaml new file mode 100644 index 0000000..6578b9c --- /dev/null +++ b/examples/simple/compose.yaml @@ -0,0 +1,33 @@ +services: + app: + build: + context: ../.. + dockerfile: examples/simple/Dockerfile + ports: + - "${PORT:-8080}:8080" + volumes: + # persist credentials across restarts + - app-data:/data + environment: + - LOG_FORMAT=${LOG_FORMAT:-text} + - GITHUB_APP_INSTALLER_ENABLED=${GITHUB_APP_INSTALLER_ENABLED:-true} + - GITHUB_URL=${GITHUB_URL:-https://github.com} + - GITHUB_ORG=${GITHUB_ORG:-} + - STORAGE_MODE=envfile + - STORAGE_DIR=/data/.env + - CONFIG_WAIT_MAX_RETRIES=${CONFIG_WAIT_MAX_RETRIES:-30} + - CONFIG_WAIT_RETRY_INTERVAL=${CONFIG_WAIT_RETRY_INTERVAL:-2s} + env_file: + # load .env file if it exists (for pre-configured credentials) + - path: .env + required: false + restart: unless-stopped + healthcheck: + test: ["CMD", "wget", "-q", "--spider", "http://localhost:8080/healthz"] + interval: 10s + timeout: 5s + retries: 3 + start_period: 5s + +volumes: + app-data: diff --git a/examples/simple/main.go b/examples/simple/main.go new file mode 100644 index 0000000..05bd9f9 --- /dev/null +++ b/examples/simple/main.go @@ -0,0 +1,273 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +// Example demonstrating a GitHub App with webhook handling. +package main + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "os" + "os/signal" + "strings" + "time" + + "github.com/cruxstack/github-app-setup-go/configstore" + "github.com/cruxstack/github-app-setup-go/configwait" + "github.com/cruxstack/github-app-setup-go/installer" +) + +const ( + defaultPort = 8080 + defaultReadHeaderTimeout = 10 * time.Second + defaultShutdownTimeout = 30 * time.Second +) + +func main() { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + + log := setupLogger() + ctx = withLogger(ctx, log) + + port := defaultPort + if p := os.Getenv("PORT"); p != "" { + fmt.Sscanf(p, "%d", &port) + } + + allowedPaths := []string{"/healthz"} + installerEnabled := configstore.InstallerEnabled() + if installerEnabled { + allowedPaths = append(allowedPaths, "/setup", "/callback", "/") + } + + gate := configwait.NewReadyGate(nil, allowedPaths) + mux := http.NewServeMux() + + mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { + if gate.IsReady() { + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) + } else { + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte("not ready")) + } + }) + + if installerEnabled { + store, err := configstore.NewFromEnv() + if err != nil { + log.Error("failed to create config store", "error", err) + os.Exit(1) + } + + manifest := installer.Manifest{ + URL: "https://github.com/cruxstack/github-app-setup-go", + Public: false, + DefaultPerms: map[string]string{ + "contents": "read", + "pull_requests": "read", + }, + DefaultEvents: []string{ + "push", + "pull_request", + }, + } + + installerCfg := installer.NewConfigFromEnv() + installerCfg.Store = store + installerCfg.Manifest = manifest + installerCfg.AppDisplayName = "Simple Webhook App" + + installerHandler, err := installer.New(installerCfg) + if err != nil { + log.Error("failed to create installer handler", "error", err) + os.Exit(1) + } + + mux.Handle("/setup", installerHandler) + mux.Handle("/setup/", installerHandler) + mux.Handle("/callback", installerHandler) + mux.Handle("/", installerHandler) + + log.Info("installer enabled, visit /setup to create GitHub App") + } + + gate.SetHandler(mux) + + srv := &http.Server{ + Addr: fmt.Sprintf(":%d", port), + ReadHeaderTimeout: defaultReadHeaderTimeout, + Handler: gate, + } + + log.Info("starting HTTP server", "port", port, "installer_enabled", installerEnabled) + + go func() { + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Error("server error", "error", err) + os.Exit(1) + } + }() + + go func() { + waitCfg := configwait.NewConfigFromEnv() + + err := configwait.Wait(ctx, waitCfg, func(ctx context.Context) error { + return loadConfig(ctx, log, mux) + }) + + if err != nil { + log.Error("failed to load configuration after retries", "error", err) + os.Exit(1) + } + + log.Info("configuration loaded, service is ready") + gate.SetReady() + + reloader := configwait.NewReloader(ctx, gate, func(ctx context.Context) error { + return loadConfig(ctx, log, mux) + }) + configwait.SetGlobalReloader(reloader) + reloader.Start() + + log.Info("configuration reloader started (send SIGHUP to reload)") + }() + + <-ctx.Done() + log.Info("shutting down server...") + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), defaultShutdownTimeout) + defer shutdownCancel() + + if err := srv.Shutdown(shutdownCtx); err != nil { + log.Error("server shutdown error", "error", err) + os.Exit(1) + } +} + +// loadConfig loads configuration and sets up the webhook handler. +func loadConfig(_ context.Context, log *slog.Logger, mux *http.ServeMux) error { + webhookSecret := os.Getenv(configstore.EnvGitHubWebhookSecret) + if webhookSecret == "" { + return fmt.Errorf("%s is not set", configstore.EnvGitHubWebhookSecret) + } + + appID := os.Getenv(configstore.EnvGitHubAppID) + if appID == "" { + return fmt.Errorf("%s is not set", configstore.EnvGitHubAppID) + } + + log.Info("loaded GitHub App configuration", "app_id", appID) + + mux.HandleFunc("/webhook", webhookHandler(log, webhookSecret)) + + return nil +} + +// webhookHandler returns an HTTP handler that processes GitHub webhooks. +func webhookHandler(log *slog.Logger, secret string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + log.Error("failed to read webhook body", "error", err) + http.Error(w, "failed to read body", http.StatusBadRequest) + return + } + defer r.Body.Close() + + signature := r.Header.Get("X-Hub-Signature-256") + if !validateSignature(body, signature, secret) { + log.Warn("webhook signature validation failed", + "remote_addr", r.RemoteAddr, + "has_signature", signature != "", + ) + http.Error(w, "invalid signature", http.StatusUnauthorized) + return + } + + eventType := r.Header.Get("X-GitHub-Event") + deliveryID := r.Header.Get("X-GitHub-Delivery") + + var payload struct { + Action string `json:"action"` + Repository struct { + FullName string `json:"full_name"` + } `json:"repository"` + Sender struct { + Login string `json:"login"` + } `json:"sender"` + } + if err := json.Unmarshal(body, &payload); err != nil { + log.Warn("failed to parse webhook payload", "error", err) + } + + log.Info("received webhook", + "event", eventType, + "action", payload.Action, + "delivery_id", deliveryID, + "repository", payload.Repository.FullName, + "sender", payload.Sender.Login, + "payload_size", len(body), + ) + + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status":"ok"}`)) + } +} + +// validateSignature validates the GitHub webhook signature. +func validateSignature(payload []byte, signature, secret string) bool { + if signature == "" || secret == "" { + return false + } + + if !strings.HasPrefix(signature, "sha256=") { + return false + } + sig := strings.TrimPrefix(signature, "sha256=") + + mac := hmac.New(sha256.New, []byte(secret)) + mac.Write(payload) + expected := hex.EncodeToString(mac.Sum(nil)) + + return hmac.Equal([]byte(sig), []byte(expected)) +} + +// setupLogger creates a slog.Logger based on LOG_FORMAT environment variable. +func setupLogger() *slog.Logger { + format := strings.ToLower(os.Getenv("LOG_FORMAT")) + + var handler slog.Handler + opts := &slog.HandlerOptions{ + Level: slog.LevelInfo, + } + + switch format { + case "json": + handler = slog.NewJSONHandler(os.Stderr, opts) + default: + handler = slog.NewTextHandler(os.Stderr, opts) + } + + return slog.New(handler) +} + +type loggerKey struct{} + +// withLogger adds a logger to the context. +func withLogger(ctx context.Context, log *slog.Logger) context.Context { + return context.WithValue(ctx, loggerKey{}, log) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..e3aa7f5 --- /dev/null +++ b/go.mod @@ -0,0 +1,26 @@ +module github.com/cruxstack/github-app-setup-go + +go 1.25 + +require ( + github.com/aws/aws-sdk-go-v2 v1.41.0 + github.com/aws/aws-sdk-go-v2/config v1.32.5 + github.com/aws/aws-sdk-go-v2/service/ssm v1.67.7 + github.com/chainguard-dev/clog v1.8.0 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + github.com/aws/aws-sdk-go-v2/credentials v1.19.5 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.4 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.7 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 // indirect + github.com/aws/smithy-go v1.24.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..8e9ab60 --- /dev/null +++ b/go.sum @@ -0,0 +1,36 @@ +github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= +github.com/aws/aws-sdk-go-v2 v1.41.0/go.mod h1:MayyLB8y+buD9hZqkCW3kX1AKq07Y5pXxtgB+rRFhz0= +github.com/aws/aws-sdk-go-v2/config v1.32.5 h1:pz3duhAfUgnxbtVhIK39PGF/AHYyrzGEyRD9Og0QrE8= +github.com/aws/aws-sdk-go-v2/config v1.32.5/go.mod h1:xmDjzSUs/d0BB7ClzYPAZMmgQdrodNjPPhd6bGASwoE= +github.com/aws/aws-sdk-go-v2/credentials v1.19.5 h1:xMo63RlqP3ZZydpJDMBsH9uJ10hgHYfQFIk1cHDXrR4= +github.com/aws/aws-sdk-go-v2/credentials v1.19.5/go.mod h1:hhbH6oRcou+LpXfA/0vPElh/e0M3aFeOblE1sssAAEk= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 h1:80+uETIWS1BqjnN9uJ0dBUaETh+P1XwFy5vwHwK5r9k= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16/go.mod h1:wOOsYuxYuB/7FlnVtzeBYRcjSRtQpAW0hCP7tIULMwo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 h1:rgGwPzb82iBYSvHMHXc8h9mRoOUBZIGFgKb9qniaZZc= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16/go.mod h1:L/UxsGeKpGoIj6DxfhOWHWQ/kGKcd4I1VncE4++IyKA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16 h1:1jtGzuV7c82xnqOVfx2F0xmJcOw5374L7N6juGW6x6U= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16/go.mod h1:M2E5OQf+XLe+SZGmmpaI2yy+J326aFf6/+54PoxSANc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 h1:0ryTNEdJbzUCEWkVXEXoqlXV72J5keC1GvILMOuD00E= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4/go.mod h1:HQ4qwNZh32C3CBeO6iJLQlgtMzqeG17ziAA/3KDJFow= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16 h1:oHjJHeUy0ImIV0bsrX0X91GkV5nJAyv1l1CC9lnO0TI= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16/go.mod h1:iRSNGgOYmiYwSCXxXaKb9HfOEj40+oTKn8pTxMlYkRM= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.4 h1:HpI7aMmJ+mm1wkSHIA2t5EaFFv5EFYXePW30p1EIrbQ= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.4/go.mod h1:C5RdGMYGlfM0gYq/tifqgn4EbyX99V15P2V3R+VHbQU= +github.com/aws/aws-sdk-go-v2/service/ssm v1.67.7 h1:0q42w8/mywPCzQD1IoWIBUCYfBJc5+fLwtZNpHffBSM= +github.com/aws/aws-sdk-go-v2/service/ssm v1.67.7/go.mod h1:urlU9nfKJEfi0+8T9luB3f3Y0UnomH/yxI7tTrfH9es= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.7 h1:eYnlt6QxnFINKzwxP5/Ucs1vkG7VT3Iezmvfgc2waUw= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.7/go.mod h1:+fWt2UHSb4kS7Pu8y+BMBvJF0EWx+4H0hzNwtDNRTrg= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12 h1:AHDr0DaHIAo8c9t1emrzAlVDFp+iMMKnPdYy6XO4MCE= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12/go.mod h1:GQ73XawFFiWxyWXMHWfhiomvP3tXtdNar/fi8z18sx0= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 h1:SciGFVNZ4mHdm7gpD1dgZYnCuVdX1s+lFTg4+4DOy70= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.5/go.mod h1:iW40X4QBmUxdP+fZNOpfmkdMZqsovezbAeO+Ubiv2pk= +github.com/aws/smithy-go v1.24.0 h1:LpilSUItNPFr1eY85RYgTIg5eIEPtvFbskaFcmmIUnk= +github.com/aws/smithy-go v1.24.0/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/chainguard-dev/clog v1.8.0 h1:frlTMEdg3XQR+ioQ6O9i92uigY8GTUcWKpuCFkhcCHA= +github.com/chainguard-dev/clog v1.8.0/go.mod h1:5MQOZi+Iu7fV7GcJG8ag8rCB5elEOpqRMKEASgnGVdo= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/installer/installer.go b/installer/installer.go new file mode 100644 index 0000000..140c2c3 --- /dev/null +++ b/installer/installer.go @@ -0,0 +1,461 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +// Package installer provides a web-based installer for creating GitHub Apps +// using the GitHub App Manifest flow. +package installer + +import ( + "bytes" + "context" + "embed" + "encoding/json" + "fmt" + "html/template" + "io" + "net/http" + "os" + "strings" + "time" + + "github.com/chainguard-dev/clog" + + "github.com/cruxstack/github-app-setup-go/configstore" + "github.com/cruxstack/github-app-setup-go/configwait" +) + +//go:embed templates/* +var templateFS embed.FS + +var indexTemplate = template.Must(template.ParseFS(templateFS, "templates/index.html")) +var successTemplate = template.Must(template.ParseFS(templateFS, "templates/success.html")) + +const ( + httpClientTimeout = 30 * time.Second + EnvGitHubURL = "GITHUB_URL" + EnvGitHubOrg = "GITHUB_ORG" + disableSetupPath = "/setup/disable" +) + +// CredentialsSavedFunc is called after credentials are saved. +type CredentialsSavedFunc func(ctx context.Context, creds *configstore.AppCredentials) error + +// Config holds the installer configuration. +type Config struct { + Store configstore.Store + Manifest Manifest + AppDisplayName string + GitHubURL string + GitHubOrg string + RedirectURL string + WebhookURL string + OnCredentialsSaved CredentialsSavedFunc +} + +// NewConfigFromEnv creates a Config from environment variables. +func NewConfigFromEnv() Config { + return Config{ + GitHubURL: configstore.GetEnvDefault(EnvGitHubURL, "https://github.com"), + GitHubOrg: os.Getenv(EnvGitHubOrg), + } +} + +// Handler handles the GitHub App manifest installation flow. +type Handler struct { + config Config +} + +type indexTemplateData struct { + AppDisplayName string + GitHubURL string + GitHubOrg string + FormActionURL string + ManifestJSON template.JS + WebhookURL string + NeedsWebhook bool + DefaultAppName string +} + +type successTemplateData struct { + AppDisplayName string + AppID int64 + AppSlug string + HTMLURL string + InstallURL string + DisableActionURL string + InstallerDisabled bool +} + +// New creates a new installer Handler with the given configuration. +func New(cfg Config) (*Handler, error) { + if cfg.Store == nil { + return nil, fmt.Errorf("store is required") + } + if cfg.GitHubURL == "" { + cfg.GitHubURL = "https://github.com" + } + if cfg.AppDisplayName == "" { + cfg.AppDisplayName = "GitHub App" + } + return &Handler{config: cfg}, nil +} + +// ServeHTTP implements http.Handler. +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + + switch { + case (r.Method == http.MethodGet || r.Method == http.MethodHead) && (path == "/" || path == ""): + h.handleRoot(w, r) + case (r.Method == http.MethodGet || r.Method == http.MethodHead) && (path == "/setup" || path == "/setup/"): + h.handleIndex(w, r) + case (r.Method == http.MethodGet || r.Method == http.MethodHead) && path == "/callback": + h.handleCallback(w, r) + + case r.Method == http.MethodPost && (path == disableSetupPath || path == disableSetupPath+"/"): + h.handleDisable(w, r) + default: + http.NotFound(w, r) + } +} + +// handleRoot redirects to /setup if enabled, otherwise returns 404. +func (h *Handler) handleRoot(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + log := clog.FromContext(ctx) + + status, err := h.config.Store.Status(ctx) + if err != nil { + log.Errorf("[installer] failed to read installer status: %v", err) + http.Error(w, "Failed to load installer status", http.StatusInternalServerError) + return + } + + if status != nil && status.InstallerDisabled { + http.NotFound(w, r) + return + } + + http.Redirect(w, r, "/setup", http.StatusFound) +} + +// handleIndex serves the main page. +func (h *Handler) handleIndex(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + log := clog.FromContext(ctx) + + status, err := h.config.Store.Status(ctx) + + if err != nil { + log.Errorf("[installer] failed to read installer status: %v", err) + http.Error(w, "Failed to load installer status", http.StatusInternalServerError) + return + } + if status != nil && status.Registered { + data := h.successDataFromStatus(status) + h.renderSuccess(w, r, data) + return + } + + redirectURL := h.config.RedirectURL + if redirectURL == "" { + redirectURL = getBaseURL(ctx, r) + log.Infof("[installer] auto-detected redirect url: url=%s host=%s forwarded_host=%s", + redirectURL, r.Host, r.Header.Get("X-Forwarded-Host")) + } + + webhookURL := h.config.WebhookURL + if webhookURL == "" { + webhookURL = r.FormValue("webhook_url") + if webhookURL == "" { + webhookURL = getBaseURL(ctx, r) + "/webhook" + log.Infof("[installer] auto-detected webhook url: url=%s", webhookURL) + } + } + + manifest := h.config.Manifest.Clone() + if manifest == nil { + manifest = &Manifest{} + } + manifest.RedirectURL = redirectURL + "/callback" + manifest.HookAttributes.URL = webhookURL + manifest.HookAttributes.Active = webhookURL != "" + + log.Infof("[installer] manifest redirect_url: %s", manifest.RedirectURL) + manifestJSON, err := json.Marshal(manifest) + if err != nil { + http.Error(w, "Failed to generate manifest", http.StatusInternalServerError) + return + } + + var formActionURL string + if h.config.GitHubOrg != "" { + formActionURL = fmt.Sprintf("%s/organizations/%s/settings/apps/new", + h.config.GitHubURL, h.config.GitHubOrg) + } else { + formActionURL = fmt.Sprintf("%s/settings/apps/new", h.config.GitHubURL) + } + + defaultAppName := manifest.Name + if defaultAppName == "" { + defaultAppName = strings.ToLower(strings.ReplaceAll(h.config.AppDisplayName, " ", "-")) + } + + data := indexTemplateData{ + AppDisplayName: h.config.AppDisplayName, + GitHubURL: h.config.GitHubURL, + GitHubOrg: h.config.GitHubOrg, + FormActionURL: formActionURL, + ManifestJSON: template.JS(manifestJSON), + WebhookURL: webhookURL, + NeedsWebhook: h.config.WebhookURL == "", + DefaultAppName: defaultAppName, + } + + var buf bytes.Buffer + if err := indexTemplate.Execute(&buf, data); err != nil { + log.Errorf("[installer] failed to render index template: %v", err) + http.Error(w, "Failed to render page", http.StatusInternalServerError) + return + } + setSecurityHeaders(w) + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + if _, err := buf.WriteTo(w); err != nil { + log.Errorf("[installer] failed to write response: %v", err) + } +} + +// handleCallback handles the GitHub redirect after app creation. +func (h *Handler) handleCallback(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + log := clog.FromContext(ctx) + + code := r.URL.Query().Get("code") + if code == "" { + http.Error(w, "Missing code parameter", http.StatusBadRequest) + return + } + if !isValidOAuthCode(code) { + http.Error(w, "Invalid code parameter", http.StatusBadRequest) + return + } + + var customDomain string + if cookie, err := r.Cookie("custom_domain"); err == nil { + customDomain = cookie.Value + http.SetCookie(w, &http.Cookie{ + Name: "custom_domain", + Value: "", + Path: "/", + MaxAge: -1, + }) + } + + creds, err := h.exchangeCode(ctx, code) + if err != nil { + log.Errorf("[installer] failed to exchange code: %v", err) + http.Error(w, "Failed to exchange code", http.StatusInternalServerError) + return + } + + if creds.CustomFields == nil { + creds.CustomFields = make(map[string]string) + } + + if customDomain != "" { + creds.CustomFields["CUSTOM_DOMAIN"] = customDomain + } + + if h.config.OnCredentialsSaved != nil { + if err := h.config.OnCredentialsSaved(ctx, creds); err != nil { + log.Errorf("[installer] OnCredentialsSaved callback failed: %v", err) + } + } + + if err := h.config.Store.Save(ctx, creds); err != nil { + log.Errorf("[installer] failed to save credentials: %v", err) + http.Error(w, "Failed to save credentials", http.StatusInternalServerError) + return + } + + log.Infof("[installer] successfully created github app: slug=%s app_id=%d", creds.AppSlug, creds.AppID) + + log.Infof("[installer] triggering configuration reload") + configwait.TriggerReload() + + data := h.successDataFromCreds(creds) + h.renderSuccess(w, r, data) +} + +// exchangeCode exchanges the temporary code for app credentials. +func (h *Handler) exchangeCode(ctx context.Context, code string) (*configstore.AppCredentials, error) { + url := fmt.Sprintf("%s/api/v3/app-manifests/%s/conversions", h.config.GitHubURL, code) + + if h.config.GitHubURL == "https://github.com" { + url = fmt.Sprintf("https://api.github.com/app-manifests/%s/conversions", code) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Accept", "application/vnd.github+json") + + client := &http.Client{ + Timeout: httpClientTimeout, + } + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to call GitHub API: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusCreated { + return nil, fmt.Errorf("GitHub API returned %d: %s", resp.StatusCode, string(body)) + } + + var creds configstore.AppCredentials + if err := json.Unmarshal(body, &creds); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return &creds, nil +} + +func (h *Handler) handleDisable(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + log := clog.FromContext(ctx) + + status, err := h.config.Store.Status(ctx) + if err != nil { + log.Errorf("[installer] failed to check status: %v", err) + http.Error(w, "Failed to check installer status", http.StatusInternalServerError) + return + } + if status == nil || !status.Registered { + http.Error(w, "Cannot disable installer before app is registered", http.StatusBadRequest) + return + } + + if err := h.config.Store.DisableInstaller(ctx); err != nil { + log.Errorf("[installer] failed to disable installer: %v", err) + http.Error(w, "Failed to disable installer", http.StatusInternalServerError) + return + } + + log.Infof("[installer] installer disabled via setup UI") + http.Redirect(w, r, "/healthz", http.StatusSeeOther) +} + +func (h *Handler) successDataFromCreds(creds *configstore.AppCredentials) successTemplateData { + data := successTemplateData{ + AppDisplayName: h.config.AppDisplayName, + AppID: creds.AppID, + AppSlug: creds.AppSlug, + HTMLURL: creds.HTMLURL, + DisableActionURL: disableSetupPath, + } + data.InstallURL = h.installURLFor(creds.AppSlug, creds.HTMLURL) + return data +} + +func (h *Handler) successDataFromStatus(status *configstore.InstallerStatus) successTemplateData { + if status == nil { + return successTemplateData{AppDisplayName: h.config.AppDisplayName} + } + data := successTemplateData{ + AppDisplayName: h.config.AppDisplayName, + AppID: status.AppID, + AppSlug: status.AppSlug, + HTMLURL: status.HTMLURL, + InstallerDisabled: status.InstallerDisabled, + DisableActionURL: disableSetupPath, + } + data.InstallURL = h.installURLFor(status.AppSlug, status.HTMLURL) + return data +} + +func (h *Handler) installURLFor(slug, htmlURL string) string { + if slug != "" { + githubURL := h.config.GitHubURL + if githubURL == "" { + githubURL = "https://github.com" + } + return fmt.Sprintf("%s/apps/%s/installations/new", githubURL, slug) + } + if htmlURL != "" { + trimmed := strings.TrimRight(htmlURL, "/") + return trimmed + "/installations/new" + } + return "" +} + +func (h *Handler) renderSuccess(w http.ResponseWriter, r *http.Request, data successTemplateData) { + ctx := r.Context() + log := clog.FromContext(ctx) + + var buf bytes.Buffer + if err := successTemplate.Execute(&buf, data); err != nil { + log.Errorf("[installer] failed to render success template: %v", err) + http.Error(w, "Failed to render page", http.StatusInternalServerError) + return + } + setSecurityHeaders(w) + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + if _, err := buf.WriteTo(w); err != nil { + log.Errorf("[installer] failed to write response: %v", err) + } +} + +// setSecurityHeaders sets common security headers. +func setSecurityHeaders(w http.ResponseWriter) { + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("X-Frame-Options", "DENY") + w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") +} + +// isValidOAuthCode validates the OAuth code format from GitHub. +func isValidOAuthCode(code string) bool { + if len(code) < 10 || len(code) > 100 { + return false + } + for _, c := range code { + if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')) { + return false + } + } + return true +} + +// getBaseURL derives the base URL from the request headers. +func getBaseURL(ctx context.Context, r *http.Request) string { + log := clog.FromContext(ctx) + + host := r.Header.Get("X-Forwarded-Host") + if host == "" { + host = r.Host + } + + scheme := r.Header.Get("X-Forwarded-Proto") + if scheme == "" { + scheme = "https" + if host == "localhost" || strings.HasPrefix(host, "localhost:") || + host == "127.0.0.1" || strings.HasPrefix(host, "127.0.0.1:") { + scheme = "http" + } + } else if scheme == "http" && !strings.HasPrefix(host, "localhost") && !strings.HasPrefix(host, "127.0.0.1") { + scheme = "https" + } + + baseURL := scheme + "://" + host + log.Debugf("[installer] getBaseURL: scheme=%s host=%s r.Host=%s X-Forwarded-Proto=%s X-Forwarded-Host=%s result=%s", + scheme, host, r.Host, r.Header.Get("X-Forwarded-Proto"), r.Header.Get("X-Forwarded-Host"), baseURL) + return baseURL +} diff --git a/installer/installer_test.go b/installer/installer_test.go new file mode 100644 index 0000000..c2c7d43 --- /dev/null +++ b/installer/installer_test.go @@ -0,0 +1,503 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +package installer + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/cruxstack/github-app-setup-go/configstore" +) + +func TestGetBaseURL(t *testing.T) { + tests := []struct { + name string + host string + xForwardedHost string + xForwardedProto string + want string + }{ + { + name: "localhost defaults to http", + host: "localhost:8080", + want: "http://localhost:8080", + }, + { + name: "localhost without port", + host: "localhost", + want: "http://localhost", + }, + { + name: "127.0.0.1 defaults to http", + host: "127.0.0.1:8080", + want: "http://127.0.0.1:8080", + }, + { + name: "non-localhost defaults to https", + host: "example.com", + want: "https://example.com", + }, + { + name: "non-localhost with port defaults to https", + host: "example.com:8443", + want: "https://example.com:8443", + }, + { + name: "X-Forwarded-Host takes precedence", + host: "internal-lb:8080", + xForwardedHost: "api.example.com", + want: "https://api.example.com", + }, + { + name: "X-Forwarded-Proto http allowed for localhost", + host: "localhost:3000", + xForwardedProto: "http", + want: "http://localhost:3000", + }, + { + name: "X-Forwarded-Proto http upgraded to https for non-localhost", + host: "example.com", + xForwardedProto: "http", + want: "https://example.com", + }, + { + name: "X-Forwarded-Proto https respected", + host: "example.com", + xForwardedProto: "https", + want: "https://example.com", + }, + { + name: "both forwarded headers", + host: "internal:8080", + xForwardedHost: "app.example.com", + xForwardedProto: "https", + want: "https://app.example.com", + }, + { + name: "X-Forwarded-Proto http with localhost forwarded host", + host: "production:8080", + xForwardedHost: "localhost:3000", + xForwardedProto: "http", + want: "http://localhost:3000", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Host = tt.host + + if tt.xForwardedHost != "" { + req.Header.Set("X-Forwarded-Host", tt.xForwardedHost) + } + if tt.xForwardedProto != "" { + req.Header.Set("X-Forwarded-Proto", tt.xForwardedProto) + } + + got := getBaseURL(context.Background(), req) + if got != tt.want { + t.Errorf("getBaseURL() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestInstallURLFor(t *testing.T) { + tests := []struct { + name string + githubURL string + slug string + htmlURL string + want string + }{ + { + name: "slug provided uses GitHub URL", + githubURL: "https://github.com", + slug: "my-app", + htmlURL: "", + want: "https://github.com/apps/my-app/installations/new", + }, + { + name: "slug with GHE URL", + githubURL: "https://github.mycompany.com", + slug: "internal-app", + htmlURL: "", + want: "https://github.mycompany.com/apps/internal-app/installations/new", + }, + { + name: "no slug uses htmlURL", + githubURL: "https://github.com", + slug: "", + htmlURL: "https://github.com/apps/my-app", + want: "https://github.com/apps/my-app/installations/new", + }, + { + name: "htmlURL with trailing slash", + githubURL: "https://github.com", + slug: "", + htmlURL: "https://github.com/apps/my-app/", + want: "https://github.com/apps/my-app/installations/new", + }, + { + name: "neither slug nor htmlURL", + githubURL: "https://github.com", + slug: "", + htmlURL: "", + want: "", + }, + { + name: "slug takes precedence over htmlURL", + githubURL: "https://github.com", + slug: "preferred-app", + htmlURL: "https://github.com/apps/other-app", + want: "https://github.com/apps/preferred-app/installations/new", + }, + { + name: "empty githubURL defaults to github.com", + githubURL: "", + slug: "my-app", + htmlURL: "", + want: "https://github.com/apps/my-app/installations/new", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &Handler{ + config: Config{ + GitHubURL: tt.githubURL, + }, + } + + got := h.installURLFor(tt.slug, tt.htmlURL) + if got != tt.want { + t.Errorf("installURLFor(%q, %q) = %q, want %q", tt.slug, tt.htmlURL, got, tt.want) + } + }) + } +} + +func TestManifestClone(t *testing.T) { + t.Run("nil manifest returns nil", func(t *testing.T) { + var m *Manifest = nil + got := m.Clone() + if got != nil { + t.Errorf("Clone() = %v, want nil", got) + } + }) + + t.Run("clones all scalar fields", func(t *testing.T) { + original := &Manifest{ + Name: "test-app", + URL: "https://example.com", + RedirectURL: "https://example.com/callback", + Public: true, + HookAttributes: HookAttributes{ + URL: "https://example.com/webhook", + Active: true, + }, + } + + clone := original.Clone() + + if clone.Name != original.Name { + t.Errorf("Clone().Name = %q, want %q", clone.Name, original.Name) + } + if clone.URL != original.URL { + t.Errorf("Clone().URL = %q, want %q", clone.URL, original.URL) + } + if clone.RedirectURL != original.RedirectURL { + t.Errorf("Clone().RedirectURL = %q, want %q", clone.RedirectURL, original.RedirectURL) + } + if clone.Public != original.Public { + t.Errorf("Clone().Public = %v, want %v", clone.Public, original.Public) + } + if clone.HookAttributes.URL != original.HookAttributes.URL { + t.Errorf("Clone().HookAttributes.URL = %q, want %q", clone.HookAttributes.URL, original.HookAttributes.URL) + } + if clone.HookAttributes.Active != original.HookAttributes.Active { + t.Errorf("Clone().HookAttributes.Active = %v, want %v", clone.HookAttributes.Active, original.HookAttributes.Active) + } + }) + + t.Run("DefaultPerms is deep copied", func(t *testing.T) { + original := &Manifest{ + DefaultPerms: map[string]string{ + "contents": "read", + "pull_requests": "write", + }, + } + + clone := original.Clone() + + // Verify values are the same + if clone.DefaultPerms["contents"] != "read" { + t.Error("Clone().DefaultPerms missing expected value") + } + + // Modify clone, original should be unchanged + clone.DefaultPerms["contents"] = "write" + clone.DefaultPerms["new_perm"] = "read" + + if original.DefaultPerms["contents"] != "read" { + t.Error("Modifying clone affected original DefaultPerms") + } + if _, exists := original.DefaultPerms["new_perm"]; exists { + t.Error("Adding to clone affected original DefaultPerms") + } + }) + + t.Run("DefaultEvents is deep copied", func(t *testing.T) { + original := &Manifest{ + DefaultEvents: []string{"push", "pull_request"}, + } + + clone := original.Clone() + + // Verify values are the same + if len(clone.DefaultEvents) != 2 { + t.Error("Clone().DefaultEvents has wrong length") + } + + // Modify clone, original should be unchanged + clone.DefaultEvents[0] = "modified" + clone.DefaultEvents = append(clone.DefaultEvents, "new_event") + + if original.DefaultEvents[0] != "push" { + t.Error("Modifying clone affected original DefaultEvents") + } + if len(original.DefaultEvents) != 2 { + t.Error("Appending to clone affected original DefaultEvents") + } + }) + + t.Run("nil maps and slices handled", func(t *testing.T) { + original := &Manifest{ + Name: "test", + DefaultPerms: nil, + DefaultEvents: nil, + } + + clone := original.Clone() + + if clone.DefaultPerms != nil { + t.Error("Clone() should have nil DefaultPerms when original is nil") + } + if clone.DefaultEvents != nil { + t.Error("Clone() should have nil DefaultEvents when original is nil") + } + }) +} + +func TestNew_Validation(t *testing.T) { + t.Run("nil store returns error", func(t *testing.T) { + _, err := New(Config{Store: nil}) + if err == nil { + t.Error("New() with nil store should return error") + } + }) + + t.Run("valid config succeeds", func(t *testing.T) { + store := &mockStore{} + h, err := New(Config{Store: store}) + if err != nil { + t.Errorf("New() error = %v, want nil", err) + } + if h == nil { + t.Error("New() returned nil handler") + } + }) + + t.Run("empty GitHubURL defaults to github.com", func(t *testing.T) { + store := &mockStore{} + h, _ := New(Config{Store: store, GitHubURL: ""}) + if h.config.GitHubURL != "https://github.com" { + t.Errorf("GitHubURL = %q, want %q", h.config.GitHubURL, "https://github.com") + } + }) + + t.Run("empty AppDisplayName defaults", func(t *testing.T) { + store := &mockStore{} + h, _ := New(Config{Store: store, AppDisplayName: ""}) + if h.config.AppDisplayName != "GitHub App" { + t.Errorf("AppDisplayName = %q, want %q", h.config.AppDisplayName, "GitHub App") + } + }) +} + +func TestHandler_ServeHTTP_Routing(t *testing.T) { + store := &mockStore{ + statusFunc: func(ctx context.Context) (*configstore.InstallerStatus, error) { + return &configstore.InstallerStatus{Registered: false}, nil + }, + } + + h, _ := New(Config{Store: store}) + + tests := []struct { + name string + method string + path string + wantStatus int + }{ + {"GET root redirects to setup", http.MethodGet, "/", http.StatusFound}, + {"GET /setup returns page", http.MethodGet, "/setup", http.StatusOK}, + {"GET /setup/ returns page", http.MethodGet, "/setup/", http.StatusOK}, + {"GET /callback without code returns 400", http.MethodGet, "/callback", http.StatusBadRequest}, + {"POST /setup/disable without registration returns 400", http.MethodPost, "/setup/disable", http.StatusBadRequest}, + {"unknown path returns 404", http.MethodGet, "/unknown", http.StatusNotFound}, + {"POST to GET-only path returns 404", http.MethodPost, "/setup", http.StatusNotFound}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(tt.method, tt.path, nil) + rec := httptest.NewRecorder() + + h.ServeHTTP(rec, req) + + if rec.Code != tt.wantStatus { + t.Errorf("ServeHTTP(%s %s) status = %d, want %d", tt.method, tt.path, rec.Code, tt.wantStatus) + } + }) + } +} + +func TestHandler_handleRoot_DisabledInstaller(t *testing.T) { + store := &mockStore{ + statusFunc: func(ctx context.Context) (*configstore.InstallerStatus, error) { + return &configstore.InstallerStatus{InstallerDisabled: true}, nil + }, + } + + h, _ := New(Config{Store: store}) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + h.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Errorf("handleRoot() with disabled installer status = %d, want %d", rec.Code, http.StatusNotFound) + } +} + +func TestHandler_handleIndex_AlreadyRegistered(t *testing.T) { + store := &mockStore{ + statusFunc: func(ctx context.Context) (*configstore.InstallerStatus, error) { + return &configstore.InstallerStatus{ + Registered: true, + AppID: 12345, + AppSlug: "my-app", + }, nil + }, + } + + h, _ := New(Config{Store: store}) + + req := httptest.NewRequest(http.MethodGet, "/setup", nil) + rec := httptest.NewRecorder() + + h.ServeHTTP(rec, req) + + // Should show success page + if rec.Code != http.StatusOK { + t.Errorf("handleIndex() with registered app status = %d, want %d", rec.Code, http.StatusOK) + } +} + +func TestIsValidOAuthCode(t *testing.T) { + tests := []struct { + name string + code string + want bool + }{ + {"valid alphanumeric code", "abc123DEF456xyz789", true}, + {"valid 20 character code", "abcdefghij1234567890", true}, + {"code too short", "abc", false}, + {"code too long", string(make([]byte, 101)), false}, + {"contains hyphen", "abc-123", false}, + {"contains underscore", "abc_123", false}, + {"contains space", "abc 123", false}, + {"contains special chars", "abc!@#123", false}, + {"empty string", "", false}, + {"minimum valid length", "abcdefghij", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isValidOAuthCode(tt.code) + if got != tt.want { + t.Errorf("isValidOAuthCode(%q) = %v, want %v", tt.code, got, tt.want) + } + }) + } +} + +func TestHandler_handleCallback_InvalidCode(t *testing.T) { + store := &mockStore{ + statusFunc: func(ctx context.Context) (*configstore.InstallerStatus, error) { + return &configstore.InstallerStatus{Registered: false}, nil + }, + } + + h, _ := New(Config{Store: store}) + + tests := []struct { + name string + code string + wantStatus int + }{ + {"missing code", "", http.StatusBadRequest}, + {"invalid code with special chars", "abc!@#123", http.StatusBadRequest}, + {"code too short", "abc", http.StatusBadRequest}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + url := "/callback" + if tt.code != "" { + url += "?code=" + tt.code + } + req := httptest.NewRequest(http.MethodGet, url, nil) + rec := httptest.NewRecorder() + + h.ServeHTTP(rec, req) + + if rec.Code != tt.wantStatus { + t.Errorf("handleCallback() with code %q status = %d, want %d", tt.code, rec.Code, tt.wantStatus) + } + }) + } +} + +// mockStore implements configstore.Store for testing +type mockStore struct { + saveFunc func(ctx context.Context, creds *configstore.AppCredentials) error + statusFunc func(ctx context.Context) (*configstore.InstallerStatus, error) + disableInstallerFunc func(ctx context.Context) error +} + +func (m *mockStore) Save(ctx context.Context, creds *configstore.AppCredentials) error { + if m.saveFunc != nil { + return m.saveFunc(ctx, creds) + } + return nil +} + +func (m *mockStore) Status(ctx context.Context) (*configstore.InstallerStatus, error) { + if m.statusFunc != nil { + return m.statusFunc(ctx) + } + return &configstore.InstallerStatus{}, nil +} + +func (m *mockStore) DisableInstaller(ctx context.Context) error { + if m.disableInstallerFunc != nil { + return m.disableInstallerFunc(ctx) + } + return nil +} diff --git a/installer/manifest.go b/installer/manifest.go new file mode 100644 index 0000000..9f28713 --- /dev/null +++ b/installer/manifest.go @@ -0,0 +1,50 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +package installer + +// Manifest represents a GitHub App manifest. +type Manifest struct { + Name string `json:"name,omitempty"` + URL string `json:"url"` + HookAttributes HookAttributes `json:"hook_attributes"` + RedirectURL string `json:"redirect_url"` + Public bool `json:"public"` + DefaultPerms map[string]string `json:"default_permissions"` + DefaultEvents []string `json:"default_events"` +} + +// HookAttributes configures the webhook for the GitHub App. +type HookAttributes struct { + URL string `json:"url"` + Active bool `json:"active"` +} + +// Clone returns a deep copy of the manifest. +func (m *Manifest) Clone() *Manifest { + if m == nil { + return nil + } + + clone := &Manifest{ + Name: m.Name, + URL: m.URL, + RedirectURL: m.RedirectURL, + Public: m.Public, + HookAttributes: m.HookAttributes, + } + + if m.DefaultPerms != nil { + clone.DefaultPerms = make(map[string]string, len(m.DefaultPerms)) + for k, v := range m.DefaultPerms { + clone.DefaultPerms[k] = v + } + } + + if m.DefaultEvents != nil { + clone.DefaultEvents = make([]string, len(m.DefaultEvents)) + copy(clone.DefaultEvents, m.DefaultEvents) + } + + return clone +} diff --git a/installer/templates/index.html b/installer/templates/index.html new file mode 100644 index 0000000..c79d0e7 --- /dev/null +++ b/installer/templates/index.html @@ -0,0 +1,178 @@ + + + + + + {{.AppDisplayName}} - GitHub App Installer + + + +
+

{{.AppDisplayName}} Installer

+

Create a GitHub App with all the required permissions for {{.AppDisplayName}}.

+ +
+ This will create a private GitHub App with pre-configured permissions. +
+ + {{if .GitHubOrg}} +
+ Creating app for organization: {{.GitHubOrg}} +
+ {{else}} +
+ Creating app for your personal account. Set GITHUB_ORG to create for an organization instead. +
+ {{end}} + +
+
+ + +

The name for your GitHub App (must be unique across GitHub)

+
+ + {{if .NeedsWebhook}} +
+ + +

The URL where GitHub will send webhook events

+
+ {{end}} + + + +
+
+ + + + diff --git a/installer/templates/success.html b/installer/templates/success.html new file mode 100644 index 0000000..4ca6b56 --- /dev/null +++ b/installer/templates/success.html @@ -0,0 +1,184 @@ + + + + + + {{.AppDisplayName}} - GitHub App Created + + + +
+

{{.AppDisplayName}} Created

+ +
+

Success!

+

Your GitHub App {{.AppSlug}} has been created and credentials have been saved.

+
+ + {{if .InstallURL}} +
+

Next Step: Install the App

+

+ The app has been created, but it still needs to be installed on your account or organization + to grant it access to repositories. +

+ Install App +
+ {{end}} + + {{if .DisableActionURL}} +
+

Disable Web Installer

+ {{if .InstallerDisabled}} +

Setup has already been disabled. Restart the service to stop exposing the installer.

+ {{else}} +

Once credentials are saved, disable this page so future visitors can't modify your configuration.

+
+ +
+ {{end}} +
+ {{end}} + + {{- $showAppID := and (ne .AppID 0) (not .InstallerDisabled) -}} + {{if or $showAppID .HTMLURL}} +
+
+ {{if $showAppID}} +
App ID
+
{{.AppID}}
+ {{end}} + {{if .HTMLURL}} +
App URL
+
{{.HTMLURL}}
+ {{end}} +
+
+ {{end}} + +

+ Credentials have been saved to the configured storage backend. + After installing the app, you can configure your services to use the saved credentials. +

+
+ + diff --git a/integration/README.md b/integration/README.md new file mode 100644 index 0000000..e26b598 --- /dev/null +++ b/integration/README.md @@ -0,0 +1,154 @@ +# Integration Tests + +End-to-end integration tests that validate the GitHub App installer flow using +local HTTPS mock servers. Tests run the actual installer code against mock +GitHub API servers with self-signed TLS certificates. + +## How It Works + +1. **Mock GitHub API** starts on localhost with a self-signed TLS certificate +2. **Installer handler** is configured to use the mock server URL +3. **Test scenarios** execute HTTP requests against the installer +4. **Requests to GitHub API** are captured and matched against expected calls +5. **Store state** is verified after each scenario completes +6. **Reload triggers** are tracked to verify the installer triggers reloads + +Key advantage: Tests run against production code paths with real HTTP +handling - no mocking of internal packages required. + +## Running Tests + +```bash +# Run integration tests only +make test-integration + +# Run integration tests with verbose output +make test-integration-v + +# Run all tests (unit + integration) +make test-all + +# Run a specific scenario by name +go test -tags=integration ./integration/... -run "successful_manifest" +``` + +## Test Scenarios + +Scenarios are defined in `testdata/scenarios.yaml`. Each scenario specifies: + +- **config**: Installer configuration overrides +- **mock_responses**: Canned GitHub API responses +- **preset_credentials**: Optional pre-seeded credentials +- **steps**: HTTP requests to execute +- **expected_store**: Expected store state after test +- **expected_calls**: Expected HTTP calls to mock GitHub +- **expect_reload**: Whether a reload should be triggered + +### Example Scenario + +```yaml +- name: "successful_manifest_exchange" + description: "Complete GitHub App manifest flow with valid code" + config: + app_display_name: "Test App" + mock_responses: + - method: POST + path: /api/v3/app-manifests/*/conversions + status: 201 + body: | + { + "id": 12345, + "slug": "test-app", + "client_id": "Iv1.abc123", + "pem": "-----BEGIN RSA PRIVATE KEY-----\n...\n-----END RSA PRIVATE KEY-----" + } + steps: + - action: request + method: GET + path: /callback?code=valid-code + expect_status: 200 + expect_body_contains: + - "test-app" + - "12345" + expected_store: + registered: true + app_id: 12345 + app_slug: "test-app" + expected_calls: + - method: POST + path: /api/v3/app-manifests/*/conversions + expect_reload: true +``` + +### Path Matching + +Mock responses and expected calls support wildcard matching: +- `*` matches any single path segment +- Example: `/api/v3/app-manifests/*/conversions` matches + `/api/v3/app-manifests/abc123/conversions` + +## Adding Tests + +1. Add a new scenario to `testdata/scenarios.yaml` +2. Define the mock responses needed +3. Specify the steps to execute +4. Define expected outcomes (store state, API calls, reload) +5. Run with `make test-integration` + +## Architecture + +``` ++-------------------+ +| Test Scenario | ++---------+---------+ + | + v ++---------+---------+ +| installer.Handler | +| .ServeHTTP() | ++---------+---------+ + | + v ++---------+---------+ +-------------------------+ +| exchangeCode() +------>| Mock GitHub HTTPS Server| ++---------+---------+ | (localhost) | + | +------------+------------+ + | | + | v + | +------------+------------+ + | | Record Request | + | +------------+------------+ + | | + | v + | +------------+------------+ + | | Return Mock Response | + | +-------------------------+ + | + v ++---------+-------------------+ +| configstore.LocalEnvFileStore| +| .Save() | ++---------+-------------------+ + | + v ++---------+---------+ +| configwait | +| .TriggerReload() | ++---------+---------+ + | + v ++---------+-------------------+ +| Verify: | +| - Store State | +| - Expected Calls | +| - Reload Counter | ++-----------------------------+ +``` + +## Notes + +- Each scenario runs with a fresh temp directory and store +- Self-signed TLS certificates are generated per test run +- The installer uses `/api/v3/app-manifests/*/conversions` for non-github.com + URLs +- Tests require the `integration` build tag: `-tags=integration` diff --git a/integration/integration_test.go b/integration/integration_test.go new file mode 100644 index 0000000..07ddf21 --- /dev/null +++ b/integration/integration_test.go @@ -0,0 +1,62 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +//go:build integration + +package integration + +import ( + "os" + "path/filepath" + "testing" +) + +func TestIntegrationScenarios(t *testing.T) { + // Find scenarios file relative to this test file + scenariosPath := filepath.Join("testdata", "scenarios.yaml") + if _, err := os.Stat(scenariosPath); os.IsNotExist(err) { + t.Fatalf("scenarios file not found: %s", scenariosPath) + } + + scenarios, err := LoadScenarios(scenariosPath) + if err != nil { + t.Fatalf("load scenarios: %v", err) + } + + if len(scenarios) == 0 { + t.Fatal("no scenarios found in scenarios.yaml") + } + + verbose := os.Getenv("VERBOSE") == "1" || os.Getenv("VERBOSE") == "true" + runner := NewScenarioRunner(t, verbose) + + for _, scenario := range scenarios { + runner.Run(scenario) + } +} + +// TestMatchPath validates the path matching logic used by the mock server. +func TestMatchPath(t *testing.T) { + tests := []struct { + path string + pattern string + want bool + }{ + {"/app-manifests/abc123/conversions", "/app-manifests/*/conversions", true}, + {"/app-manifests/xyz/conversions", "/app-manifests/*/conversions", true}, + {"/repos/owner/repo/pulls/123", "/repos/*/*/pulls/*", true}, + {"/repos/owner/repo/pulls", "/repos/*/*/pulls/*", false}, // different segment count + {"/exact/match", "/exact/match", true}, + {"/exact/mismatch", "/exact/match", false}, + {"/api/v3/app-manifests/code/conversions", "/api/v3/app-manifests/*/conversions", true}, + } + + for _, tt := range tests { + t.Run(tt.path+"_"+tt.pattern, func(t *testing.T) { + got := matchPath(tt.path, tt.pattern) + if got != tt.want { + t.Errorf("matchPath(%q, %q) = %v, want %v", tt.path, tt.pattern, got, tt.want) + } + }) + } +} diff --git a/integration/mock_github.go b/integration/mock_github.go new file mode 100644 index 0000000..f7c231d --- /dev/null +++ b/integration/mock_github.go @@ -0,0 +1,154 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +//go:build integration + +package integration + +import ( + "fmt" + "io" + "net/http" + "strings" + "sync" + "time" +) + +// RequestRecord captures details of an HTTP request made to the mock server. +type RequestRecord struct { + Timestamp time.Time + Method string + Path string + Query string + Headers http.Header + Body string +} + +// MockResponse defines a canned HTTP response for matching requests. +type MockResponse struct { + Method string `yaml:"method"` + Path string `yaml:"path"` + StatusCode int `yaml:"status"` + Headers map[string]string `yaml:"headers,omitempty"` + Body string `yaml:"body"` +} + +// MockGitHubServer simulates the GitHub API for integration testing. +type MockGitHubServer struct { + mu sync.Mutex + requests []RequestRecord + responses map[string]MockResponse + verbose bool +} + +// NewMockGitHubServer creates a new mock GitHub API server. +func NewMockGitHubServer(responses []MockResponse, verbose bool) *MockGitHubServer { + respMap := make(map[string]MockResponse) + for _, r := range responses { + key := fmt.Sprintf("%s:%s", r.Method, r.Path) + respMap[key] = r + } + return &MockGitHubServer{ + requests: make([]RequestRecord, 0), + responses: respMap, + verbose: verbose, + } +} + +// ServeHTTP implements http.Handler. +func (m *MockGitHubServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + r.Body.Close() + + rec := RequestRecord{ + Timestamp: time.Now(), + Method: r.Method, + Path: r.URL.Path, + Query: r.URL.RawQuery, + Headers: r.Header.Clone(), + Body: string(body), + } + + m.mu.Lock() + m.requests = append(m.requests, rec) + m.mu.Unlock() + + if m.verbose { + fmt.Printf(" [mock-github] %s %s\n", r.Method, r.URL.Path) + } + + // Try exact match first + key := fmt.Sprintf("%s:%s", r.Method, r.URL.Path) + if resp, ok := m.responses[key]; ok { + m.writeResponse(w, resp) + return + } + + // Try wildcard matching + for respKey, resp := range m.responses { + parts := strings.SplitN(respKey, ":", 2) + if len(parts) == 2 { + method, pattern := parts[0], parts[1] + if method == r.Method && matchPath(r.URL.Path, pattern) { + m.writeResponse(w, resp) + return + } + } + } + + // No match found + if m.verbose { + fmt.Printf(" [mock-github] no mock response for: %s %s\n", r.Method, r.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusNotFound) + w.Write([]byte(`{"message":"Not Found"}`)) +} + +func (m *MockGitHubServer) writeResponse(w http.ResponseWriter, resp MockResponse) { + for k, v := range resp.Headers { + w.Header().Set(k, v) + } + if w.Header().Get("Content-Type") == "" { + w.Header().Set("Content-Type", "application/json") + } + w.WriteHeader(resp.StatusCode) + w.Write([]byte(resp.Body)) +} + +// GetRequests returns all captured requests. +func (m *MockGitHubServer) GetRequests() []RequestRecord { + m.mu.Lock() + defer m.mu.Unlock() + reqs := make([]RequestRecord, len(m.requests)) + copy(reqs, m.requests) + return reqs +} + +// Reset clears all recorded requests. +func (m *MockGitHubServer) Reset() { + m.mu.Lock() + defer m.mu.Unlock() + m.requests = make([]RequestRecord, 0) +} + +// matchPath checks if a path matches a pattern with wildcard support. +func matchPath(path, pattern string) bool { + pathParts := strings.Split(strings.Trim(path, "/"), "/") + patternParts := strings.Split(strings.Trim(pattern, "/"), "/") + + if len(pathParts) != len(patternParts) { + return false + } + + for i, patternPart := range patternParts { + if patternPart == "*" { + continue + } + if pathParts[i] != patternPart { + return false + } + } + + return true +} diff --git a/integration/scenario.go b/integration/scenario.go new file mode 100644 index 0000000..e43aeb4 --- /dev/null +++ b/integration/scenario.go @@ -0,0 +1,338 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +//go:build integration + +package integration + +import ( + "context" + "crypto/tls" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "gopkg.in/yaml.v3" + + "github.com/cruxstack/github-app-setup-go/configstore" + "github.com/cruxstack/github-app-setup-go/configwait" + "github.com/cruxstack/github-app-setup-go/installer" +) + +// Scenario defines a single integration test case. +type Scenario struct { + Name string `yaml:"name"` + Description string `yaml:"description,omitempty"` + + // Config overrides for the installer + Config ScenarioConfig `yaml:"config,omitempty"` + + // Mock responses from GitHub API + MockResponses []MockResponse `yaml:"mock_responses,omitempty"` + + // Preset credentials to seed the store before the test + PresetCredentials *PresetCredentials `yaml:"preset_credentials,omitempty"` + + // Test steps to execute + Steps []Step `yaml:"steps"` + + // Expected state after test completion + ExpectedStore *ExpectedStore `yaml:"expected_store,omitempty"` + + // Expected HTTP calls to mock GitHub + ExpectedCalls []ExpectedCall `yaml:"expected_calls,omitempty"` + + // Whether a reload should have been triggered + ExpectReload bool `yaml:"expect_reload,omitempty"` +} + +// ScenarioConfig holds installer configuration overrides. +type ScenarioConfig struct { + AppDisplayName string `yaml:"app_display_name,omitempty"` + GitHubOrg string `yaml:"github_org,omitempty"` + WebhookURL string `yaml:"webhook_url,omitempty"` +} + +// PresetCredentials allows seeding the store with existing credentials. +type PresetCredentials struct { + AppID int64 `yaml:"app_id"` + AppSlug string `yaml:"app_slug"` + ClientID string `yaml:"client_id,omitempty"` + ClientSecret string `yaml:"client_secret,omitempty"` + WebhookSecret string `yaml:"webhook_secret,omitempty"` + PrivateKey string `yaml:"private_key,omitempty"` + HTMLURL string `yaml:"html_url,omitempty"` +} + +// Step defines a single action in the test scenario. +type Step struct { + Action string `yaml:"action"` + Method string `yaml:"method,omitempty"` + Path string `yaml:"path,omitempty"` + ExpectStatus int `yaml:"expect_status,omitempty"` + ExpectBodyContains []string `yaml:"expect_body_contains,omitempty"` + ExpectRedirect string `yaml:"expect_redirect,omitempty"` +} + +// ExpectedStore defines the expected state of the store after the test. +type ExpectedStore struct { + Registered bool `yaml:"registered"` + InstallerDisabled bool `yaml:"installer_disabled,omitempty"` + AppID int64 `yaml:"app_id,omitempty"` + AppSlug string `yaml:"app_slug,omitempty"` +} + +// ExpectedCall defines an expected HTTP call to the mock server. +type ExpectedCall struct { + Method string `yaml:"method"` + Path string `yaml:"path"` +} + +// LoadScenarios reads scenarios from a YAML file. +func LoadScenarios(path string) ([]Scenario, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read scenarios file: %w", err) + } + + var scenarios []Scenario + if err := yaml.Unmarshal(data, &scenarios); err != nil { + return nil, fmt.Errorf("parse scenarios: %w", err) + } + + return scenarios, nil +} + +// ScenarioRunner executes integration test scenarios. +type ScenarioRunner struct { + t *testing.T + verbose bool +} + +// NewScenarioRunner creates a new scenario runner. +func NewScenarioRunner(t *testing.T, verbose bool) *ScenarioRunner { + return &ScenarioRunner{t: t, verbose: verbose} +} + +// Run executes a single scenario. +func (r *ScenarioRunner) Run(scenario Scenario) { + r.t.Run(scenario.Name, func(t *testing.T) { + if r.verbose { + t.Logf("Running scenario: %s", scenario.Name) + if scenario.Description != "" { + t.Logf(" Description: %s", scenario.Description) + } + } + + // Get shared TLS certificate (generated once at test startup) + tlsCert, certPool, err := getSharedTLSCert() + if err != nil { + t.Fatalf("get TLS cert: %v", err) + } + + // Create mock GitHub server + mockGitHub := NewMockGitHubServer(scenario.MockResponses, r.verbose) + githubServer := httptest.NewUnstartedServer(mockGitHub) + githubServer.TLS = &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + } + githubServer.StartTLS() + defer githubServer.Close() + + // Create temp directory for .env file + tempDir := t.TempDir() + envFilePath := filepath.Join(tempDir, ".env") + + // Create store + store := configstore.NewLocalEnvFileStore(envFilePath) + + // Preset credentials if specified + if scenario.PresetCredentials != nil { + creds := &configstore.AppCredentials{ + AppID: scenario.PresetCredentials.AppID, + AppSlug: scenario.PresetCredentials.AppSlug, + ClientID: scenario.PresetCredentials.ClientID, + ClientSecret: scenario.PresetCredentials.ClientSecret, + WebhookSecret: scenario.PresetCredentials.WebhookSecret, + PrivateKey: scenario.PresetCredentials.PrivateKey, + HTMLURL: scenario.PresetCredentials.HTMLURL, + } + // Fill in defaults for required fields + if creds.PrivateKey == "" { + rsaKey, err := getSharedRSAKeyPEM() + if err != nil { + t.Fatalf("get RSA key: %v", err) + } + creds.PrivateKey = rsaKey + } + if err := store.Save(context.Background(), creds); err != nil { + t.Fatalf("preset credentials: %v", err) + } + } + + // Reset reload counter + configwait.ResetReloadCounter() + + // Create installer handler + cfg := installer.Config{ + Store: store, + GitHubURL: githubServer.URL, + AppDisplayName: "GitHub App", + } + if scenario.Config.AppDisplayName != "" { + cfg.AppDisplayName = scenario.Config.AppDisplayName + } + if scenario.Config.GitHubOrg != "" { + cfg.GitHubOrg = scenario.Config.GitHubOrg + } + if scenario.Config.WebhookURL != "" { + cfg.WebhookURL = scenario.Config.WebhookURL + } + + handler, err := installer.New(cfg) + if err != nil { + t.Fatalf("create installer: %v", err) + } + + // Create test server for installer (also HTTPS) + installerServer := httptest.NewUnstartedServer(handler) + installerServer.TLS = &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + } + installerServer.StartTLS() + defer installerServer.Close() + + // Create HTTP client that trusts our self-signed cert + httpClient := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: certPool, + }, + }, + // Don't follow redirects automatically - we want to inspect them + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + Timeout: 10 * time.Second, + } + + originalTransport := http.DefaultTransport + http.DefaultTransport = &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: certPool, + }, + } + defer func() { http.DefaultTransport = originalTransport }() + + // Execute test steps + for i, step := range scenario.Steps { + if r.verbose { + t.Logf(" Step %d: %s %s", i+1, step.Method, step.Path) + } + + switch step.Action { + case "request": + r.executeRequestStep(t, httpClient, installerServer.URL, step) + default: + t.Fatalf("unknown action: %s", step.Action) + } + } + + time.Sleep(50 * time.Millisecond) + + // Verify expected store state + if scenario.ExpectedStore != nil { + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("get store status: %v", err) + } + + if status.Registered != scenario.ExpectedStore.Registered { + t.Errorf("store.Registered = %v, want %v", status.Registered, scenario.ExpectedStore.Registered) + } + if status.InstallerDisabled != scenario.ExpectedStore.InstallerDisabled { + t.Errorf("store.InstallerDisabled = %v, want %v", status.InstallerDisabled, scenario.ExpectedStore.InstallerDisabled) + } + if scenario.ExpectedStore.AppID != 0 && status.AppID != scenario.ExpectedStore.AppID { + t.Errorf("store.AppID = %d, want %d", status.AppID, scenario.ExpectedStore.AppID) + } + if scenario.ExpectedStore.AppSlug != "" && status.AppSlug != scenario.ExpectedStore.AppSlug { + t.Errorf("store.AppSlug = %q, want %q", status.AppSlug, scenario.ExpectedStore.AppSlug) + } + } + + // Verify expected HTTP calls to mock GitHub + if len(scenario.ExpectedCalls) > 0 { + requests := mockGitHub.GetRequests() + for _, expected := range scenario.ExpectedCalls { + found := false + for _, req := range requests { + if req.Method == expected.Method && matchPath(req.Path, expected.Path) { + found = true + break + } + } + if !found { + t.Errorf("expected call not found: %s %s", expected.Method, expected.Path) + t.Logf("actual calls:") + for _, req := range requests { + t.Logf(" %s %s", req.Method, req.Path) + } + } + } + } + + // Verify reload was triggered if expected + if scenario.ExpectReload { + count := configwait.GetReloadCount() + if count == 0 { + t.Errorf("expected reload to be triggered, but it was not") + } + } + }) +} + +func (r *ScenarioRunner) executeRequestStep(t *testing.T, client *http.Client, baseURL string, step Step) { + url := baseURL + step.Path + req, err := http.NewRequest(step.Method, url, nil) + if err != nil { + t.Fatalf("create request: %v", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("execute request: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read response body: %v", err) + } + + // Check status code + if step.ExpectStatus != 0 && resp.StatusCode != step.ExpectStatus { + t.Errorf("%s %s: status = %d, want %d\nBody: %s", step.Method, step.Path, resp.StatusCode, step.ExpectStatus, string(body)) + } + + // Check body contains expected strings + for _, expected := range step.ExpectBodyContains { + if !strings.Contains(string(body), expected) { + t.Errorf("%s %s: body does not contain %q\nBody: %s", step.Method, step.Path, expected, string(body)) + } + } + + // Check redirect location + if step.ExpectRedirect != "" { + location := resp.Header.Get("Location") + if location != step.ExpectRedirect { + t.Errorf("%s %s: redirect = %q, want %q", step.Method, step.Path, location, step.ExpectRedirect) + } + } +} diff --git a/integration/testdata/scenarios.yaml b/integration/testdata/scenarios.yaml new file mode 100644 index 0000000..f34f51e --- /dev/null +++ b/integration/testdata/scenarios.yaml @@ -0,0 +1,236 @@ +# Integration Test Scenarios for github-app-setup-go +# +# Each scenario tests a specific flow through the installer with mock GitHub API responses. +# Tests are executed in order, each with a fresh environment (temp directory, new store). + +# ============================================================================= +# Setup Page Tests +# ============================================================================= + +- name: "setup_page_renders" + description: "GET /setup returns the installation form when not registered" + config: + app_display_name: "My Test App" + steps: + - action: request + method: GET + path: /setup + expect_status: 200 + expect_body_contains: + - "My Test App" + - "Create GitHub App" + expected_store: + registered: false + +- name: "root_redirects_to_setup" + description: "GET / redirects to /setup when not registered" + steps: + - action: request + method: GET + path: / + expect_status: 302 + expect_redirect: /setup + +# ============================================================================= +# Manifest Exchange Tests +# ============================================================================= + +- name: "successful_manifest_exchange" + description: "Complete GitHub App manifest flow with valid code" + config: + app_display_name: "Test App" + mock_responses: + # Note: For non-github.com URLs, the API path includes /api/v3/ + - method: POST + path: /api/v3/app-manifests/validcode1234567890/conversions + status: 201 + body: | + { + "id": 12345, + "slug": "test-app", + "client_id": "Iv1.abc123def456", + "client_secret": "secret_abc123", + "webhook_secret": "whsec_xyz789", + "pem": "-----BEGIN RSA PRIVATE KEY-----\nMIIEpAIBAAKCAQEA0Z3VS5JJcds3xfn/ygWyF8PbnGy0AHB7MeXvzlsOgQpEwMKP\nrNYSEwpTMkpGNL8z0gVbPfJAwnqPbPLVJJWz0hpFNLfErL3wLqXW6FmmYKX7sGBO\nkgJQWoKzaMgSMDonhDmC9nLGnanV/WwnJ7hi7xwbLPJYR4NLnkGOKFVBFbg0BAZF\nvdGadVRZz0UuZkpZVMMPHdZPl1dCYrOnOGtCcGq0OyMmMklNMQBcB64CkZmQpGNR\nxaLwQPP7PjKXz8IxhqY1m1kNqR4wLBMFbPzXY6jUzK+p1TKMBMmklPxPiVUBpddV\nSfKnOK8pCIWyiQIDAQABAoIBAC9rsiMjXhwSk8lVPXBnMPNSRD2K9GFoNlMiYFIh\nrxrBD0bDHqLwUjYrFpZH8S3vkgMCn+4ZbPohUfFwxlOdpGQo0MAlkL3k84bvljDa\nn2hDYvPa1jHCktXvPxfhP9toE9cts01WnCGE4V+gQXNPNWN4VCV4O0NlKxsQ8aOo\nZMaKjzQPBc1SKeEALs9m5k3JgBJ9TpYlovHQ8HLTlSBNuMrjaQYGAIDAi+i7bE8X\nnLBEbghVUMgsVnJBaQIgJqTlEkvLnmLjviMYczsakPaoXfnPfLmPi7GjRrVPiAMp\nHYlMz8L2S9CRLqHv9qKMSQqVpNmVbYzPjYsSv2R9y4ECgYEA7o1qfQ9IR+kLkhM+\nXAM6Q+sTLjPVq+ONmy9vJwMGey0Dt4BjP3nPaWTg8YfLmSO8pW3KOuLV4ZPBsnoE\nkXSyPDF8pZmHOE32he3BnAM7LfPa1w9mnVYsLzRvg2zXl0dyuz4FlmJEMbvLOdkr\nbruFVM/S3FQOII9sFgkbXtKH2hkCgYEA4QrBuNCXvaFSyx6qaQL0Cj3T5OkNMb98\nlXkrPNvPS4Y3KsHrLoYr3ZN1p93MjFgf4Vf7Kv5W3skCKmMt8wQZQjkMCiiQ5zhi\njQS0aR5p7C0B4IujCJ+tqWVbXCBbpAtjrvQc5LPKymph6SXFWJ0Qk9prXXB+vpZc\n2L3o7F8TsEkCgYEAxk8DRnpBdR2NfNPJj5IvL3bDJQI3T7w2khFUn3JgO0L8tr2X\n6Yq4z9G6a8w5vK3kPlO5N1R8HqM3skRUhPLFnrKPmR5xPyK3S5hVf5i8tLs1oHif\nN2h4pM3sOoXmsjcH9H8i7rJNEfkJa5UA8mrnuBchpBdSdMv+B0fBnPCk64kCgYBP\nmTTi0xqyMTdCJvLN5bF5jsDqhp6ADn7DJ35dIwMJUDXHipNLkMXhMsJfH9wIjwDX\nGWfp8jCsgI0NKnLbPLU5C4xWP2xQp5rN9GWC9YlhVGQPHQEoDHTh5B7uPZKYo5S9\nA5v7d2vFMBXKPfDO2R8/LsZ8MGAO9Y2oEi3EFpUb4QKBgQDjXD0qCiCH89uMJS4h\nEUUSd19JqWtSJqSIuGcNYqHHHxplWlOPINOHD3t2WqLdLkLY7O2i8vCkOm+g41cI\n4a1MdCb8sYDm8HEi2LfPLfQ/SqKzx0nT7iDYGt8pZxJxYiEu7AqMb/HJwma+CQAU\nH5bQd0RoEhYWIKBXwT/k/wfmCw==\n-----END RSA PRIVATE KEY-----", + "html_url": "https://github.com/apps/test-app" + } + steps: + - action: request + method: GET + path: /callback?code=validcode1234567890 + expect_status: 200 + expect_body_contains: + - "test-app" + - "12345" + expected_store: + registered: true + app_id: 12345 + app_slug: "test-app" + expected_calls: + - method: POST + path: /api/v3/app-manifests/validcode1234567890/conversions + expect_reload: true + +- name: "invalid_code_returns_error" + description: "GitHub API returns 404 for invalid code" + mock_responses: + - method: POST + path: /api/v3/app-manifests/invalidcode1234567/conversions + status: 404 + body: '{"message": "Not Found"}' + steps: + - action: request + method: GET + path: /callback?code=invalidcode1234567 + expect_status: 500 + expect_body_contains: + - "Failed to exchange code" + expected_store: + registered: false + expected_calls: + - method: POST + path: /api/v3/app-manifests/invalidcode1234567/conversions + +- name: "missing_code_parameter" + description: "Callback without code parameter returns 400" + steps: + - action: request + method: GET + path: /callback + expect_status: 400 + expect_body_contains: + - "Missing code" + expected_store: + registered: false + +# ============================================================================= +# Already Registered Tests +# ============================================================================= + +- name: "already_registered_shows_success" + description: "Visiting /setup when already registered shows success page" + preset_credentials: + app_id: 99999 + app_slug: "existing-app" + client_id: "Iv1.existing123" + client_secret: "existing_secret" + webhook_secret: "existing_webhook" + html_url: "https://github.com/apps/existing-app" + steps: + - action: request + method: GET + path: /setup + expect_status: 200 + expect_body_contains: + - "existing-app" + - "99999" + expected_store: + registered: true + app_id: 99999 + app_slug: "existing-app" + expected_calls: [] + +# ============================================================================= +# Disable Installer Tests +# ============================================================================= + +- name: "disable_installer_after_registration" + description: "POST /setup/disable marks installer as disabled" + preset_credentials: + app_id: 12345 + app_slug: "my-app" + client_id: "Iv1.test" + client_secret: "test_secret" + webhook_secret: "test_webhook" + html_url: "https://github.com/apps/my-app" + steps: + - action: request + method: POST + path: /setup/disable + expect_status: 303 + expected_store: + registered: true + installer_disabled: true + +- name: "disabled_installer_returns_404" + description: "GET / returns 404 when installer is disabled" + preset_credentials: + app_id: 12345 + app_slug: "my-app" + client_id: "Iv1.test" + client_secret: "test_secret" + webhook_secret: "test_webhook" + html_url: "https://github.com/apps/my-app" + steps: + # First disable the installer + - action: request + method: POST + path: /setup/disable + expect_status: 303 + # Then verify root returns 404 + - action: request + method: GET + path: / + expect_status: 404 + expected_store: + registered: true + installer_disabled: true + +- name: "disable_without_registration_fails" + description: "POST /setup/disable fails when not registered" + steps: + - action: request + method: POST + path: /setup/disable + expect_status: 400 + expect_body_contains: + - "Cannot disable" + expected_store: + registered: false + +# ============================================================================= +# Error Handling Tests +# ============================================================================= + +- name: "github_api_server_error" + description: "GitHub API returns 500 server error" + mock_responses: + - method: POST + path: /api/v3/app-manifests/errorcode12345678/conversions + status: 500 + body: '{"message": "Internal Server Error"}' + steps: + - action: request + method: GET + path: /callback?code=errorcode12345678 + expect_status: 500 + expect_body_contains: + - "Failed to exchange code" + expected_store: + registered: false + +- name: "unknown_path_returns_404" + description: "Unknown paths return 404" + steps: + - action: request + method: GET + path: /unknown/path + expect_status: 404 + +# ============================================================================= +# Organization Scoped Tests +# ============================================================================= + +- name: "org_scoped_installer" + description: "Installer configured for organization shows correct form action" + config: + app_display_name: "Org App" + github_org: "my-organization" + steps: + - action: request + method: GET + path: /setup + expect_status: 200 + expect_body_contains: + - "my-organization" + - "Org App" diff --git a/integration/tls.go b/integration/tls.go new file mode 100644 index 0000000..0bedcc5 --- /dev/null +++ b/integration/tls.go @@ -0,0 +1,129 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +//go:build integration + +package integration + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "sync" + "time" +) + +// sharedCrypto holds cryptographic materials generated once at test startup. +var sharedCrypto struct { + once sync.Once + tlsCert tls.Certificate + certPool *x509.CertPool + rsaKeyPEM string + err error +} + +// initSharedCrypto generates all cryptographic materials once at test startup. +func initSharedCrypto() { + sharedCrypto.once.Do(func() { + // Generate TLS certificate + sharedCrypto.tlsCert, sharedCrypto.certPool, sharedCrypto.err = generateSelfSignedCert() + if sharedCrypto.err != nil { + return + } + + // Generate RSA key for GitHub App simulation + sharedCrypto.rsaKeyPEM, sharedCrypto.err = generateRSAKeyPEM() + }) +} + +// getSharedTLSCert returns the shared TLS certificate. +func getSharedTLSCert() (tls.Certificate, *x509.CertPool, error) { + initSharedCrypto() + return sharedCrypto.tlsCert, sharedCrypto.certPool, sharedCrypto.err +} + +// getSharedRSAKeyPEM returns the shared RSA private key PEM. +func getSharedRSAKeyPEM() (string, error) { + initSharedCrypto() + return sharedCrypto.rsaKeyPEM, sharedCrypto.err +} + +// generateSelfSignedCert creates a self-signed TLS certificate for localhost. +func generateSelfSignedCert() (tls.Certificate, *x509.CertPool, error) { + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return tls.Certificate{}, nil, fmt.Errorf("generate private key: %w", err) + } + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 64)) + if err != nil { + return tls.Certificate{}, nil, fmt.Errorf("generate serial number: %w", err) + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Integration Tests"}, + CommonName: "localhost", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(1 * time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")}, + DNSNames: []string{"localhost"}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + if err != nil { + return tls.Certificate{}, nil, fmt.Errorf("create certificate: %w", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: certDER, + }) + + keyDER, err := x509.MarshalECPrivateKey(privateKey) + if err != nil { + return tls.Certificate{}, nil, fmt.Errorf("marshal private key: %w", err) + } + keyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: keyDER, + }) + + tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + return tls.Certificate{}, nil, fmt.Errorf("create key pair: %w", err) + } + + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(certPEM) + + return tlsCert, certPool, nil +} + +// generateRSAKeyPEM generates a fresh RSA private key in PEM format. +func generateRSAKeyPEM() (string, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return "", fmt.Errorf("generate RSA key: %w", err) + } + + keyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + + return string(keyPEM), nil +} diff --git a/ssmresolver/resolver.go b/ssmresolver/resolver.go new file mode 100644 index 0000000..461a592 --- /dev/null +++ b/ssmresolver/resolver.go @@ -0,0 +1,194 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +// Package ssmresolver provides utilities for resolving AWS SSM Parameter Store +// ARNs in environment variables. +package ssmresolver + +import ( + "context" + "fmt" + "os" + "regexp" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/ssm" + "github.com/chainguard-dev/clog" +) + +const ( + EnvMaxRetries = "CONFIG_WAIT_MAX_RETRIES" + EnvRetryInterval = "CONFIG_WAIT_RETRY_INTERVAL" +) + +const ( + DefaultMaxRetries = 5 + DefaultRetryInterval = 1 * time.Second +) + +var ssmARNPattern = regexp.MustCompile(`^arn:aws:ssm:[^:]+:[^:]+:parameter/(.+)$`) + +// Client defines the interface for SSM operations. +type Client interface { + GetParameter(ctx context.Context, params *ssm.GetParameterInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) +} + +// Resolver handles SSM parameter resolution. +type Resolver struct { + client Client +} + +// New creates a Resolver with the default AWS configuration. +func New(ctx context.Context) (*Resolver, error) { + cfg, err := config.LoadDefaultConfig(ctx) + if err != nil { + return nil, fmt.Errorf("failed to load AWS config: %w", err) + } + return &Resolver{ + client: ssm.NewFromConfig(cfg), + }, nil +} + +// NewWithClient creates a Resolver with a custom SSM client. +func NewWithClient(client Client) *Resolver { + return &Resolver{client: client} +} + +// IsSSMARN checks if the given value is an SSM Parameter Store ARN. +func IsSSMARN(value string) bool { + return ssmARNPattern.MatchString(value) +} + +// ExtractParameterName extracts the parameter name from an SSM ARN. +func ExtractParameterName(arn string) (string, bool) { + matches := ssmARNPattern.FindStringSubmatch(arn) + if len(matches) != 2 { + return "", false + } + paramName := matches[1] + if !strings.HasPrefix(paramName, "/") { + paramName = "/" + paramName + } + return paramName, true +} + +// ResolveValue resolves an SSM ARN to its value, or returns it unchanged. +func (r *Resolver) ResolveValue(ctx context.Context, value string) (string, error) { + if !IsSSMARN(value) { + return value, nil + } + + paramName, ok := ExtractParameterName(value) + if !ok { + return "", fmt.Errorf("invalid SSM ARN format: %s", value) + } + + resp, err := r.client.GetParameter(ctx, &ssm.GetParameterInput{ + Name: ¶mName, + WithDecryption: ptr(true), + }) + if err != nil { + return "", fmt.Errorf("failed to get SSM parameter %s: %w", paramName, err) + } + + if resp.Parameter == nil || resp.Parameter.Value == nil { + return "", fmt.Errorf("SSM parameter %s has no value", paramName) + } + + return *resp.Parameter.Value, nil +} + +// ResolveEnvironment resolves any SSM ARN values in environment variables. +func (r *Resolver) ResolveEnvironment(ctx context.Context) error { + for _, env := range os.Environ() { + parts := strings.SplitN(env, "=", 2) + if len(parts) != 2 { + continue + } + key, value := parts[0], parts[1] + + if IsSSMARN(value) { + resolved, err := r.ResolveValue(ctx, value) + if err != nil { + return fmt.Errorf("failed to resolve %s: %w", key, err) + } + if err := os.Setenv(key, resolved); err != nil { + return fmt.Errorf("failed to set %s: %w", key, err) + } + } + } + return nil +} + +// ResolveEnvironmentWithDefaults creates a resolver and resolves all env vars. +func ResolveEnvironmentWithDefaults(ctx context.Context) error { + resolver, err := New(ctx) + if err != nil { + return err + } + return resolver.ResolveEnvironment(ctx) +} + +// RetryConfig configures retry behavior for SSM resolution. +type RetryConfig struct { + MaxRetries int + RetryInterval time.Duration +} + +// NewRetryConfigFromEnv creates a RetryConfig from environment variables. +func NewRetryConfigFromEnv() RetryConfig { + cfg := RetryConfig{ + MaxRetries: DefaultMaxRetries, + RetryInterval: DefaultRetryInterval, + } + + if v := os.Getenv(EnvMaxRetries); v != "" { + var n int + if _, err := fmt.Sscanf(v, "%d", &n); err == nil && n > 0 { + cfg.MaxRetries = n + } + } + + if v := os.Getenv(EnvRetryInterval); v != "" { + if d, err := time.ParseDuration(v); err == nil && d > 0 { + cfg.RetryInterval = d + } + } + + return cfg +} + +// ResolveEnvironmentWithRetry resolves all environment variables with retry logic. +func ResolveEnvironmentWithRetry(ctx context.Context, cfg RetryConfig) error { + log := clog.FromContext(ctx) + var lastErr error + + for attempt := 1; attempt <= cfg.MaxRetries; attempt++ { + err := ResolveEnvironmentWithDefaults(ctx) + if err == nil { + if attempt > 1 { + log.Infof("[ssmresolver] SSM parameters resolved successfully after %d attempts", attempt) + } + return nil + } + + lastErr = err + log.Warnf("[ssmresolver] attempt %d/%d failed: %v", attempt, cfg.MaxRetries, err) + + if attempt < cfg.MaxRetries { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(cfg.RetryInterval): + } + } + } + + return lastErr +} + +func ptr[T any](v T) *T { + return &v +} diff --git a/ssmresolver/resolver_test.go b/ssmresolver/resolver_test.go new file mode 100644 index 0000000..112ce34 --- /dev/null +++ b/ssmresolver/resolver_test.go @@ -0,0 +1,300 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +package ssmresolver + +import ( + "context" + "errors" + "testing" + + "github.com/aws/aws-sdk-go-v2/service/ssm" + "github.com/aws/aws-sdk-go-v2/service/ssm/types" +) + +func TestIsSSMARN(t *testing.T) { + tests := []struct { + name string + value string + want bool + }{ + // Valid ARNs + { + name: "valid ARN with simple path", + value: "arn:aws:ssm:us-east-1:123456789012:parameter/my-param", + want: true, + }, + { + name: "valid ARN with nested path", + value: "arn:aws:ssm:us-west-2:111122223333:parameter/octo-sts/prod/GITHUB_APP_ID", + want: true, + }, + { + name: "valid ARN with leading slash in path", + value: "arn:aws:ssm:eu-west-1:999888777666:parameter//app/secret", + want: true, + }, + + // Invalid ARNs + { + name: "empty string", + value: "", + want: false, + }, + { + name: "plain value", + value: "my-secret-value", + want: false, + }, + { + name: "wrong service", + value: "arn:aws:s3:us-east-1:123456789012:bucket/my-bucket", + want: false, + }, + { + name: "missing parameter prefix", + value: "arn:aws:ssm:us-east-1:123456789012:secret/my-secret", + want: false, + }, + { + name: "incomplete ARN", + value: "arn:aws:ssm:us-east-1:123456789012", + want: false, + }, + { + name: "ARN-like but malformed", + value: "arn:aws:ssm:parameter/test", + want: false, + }, + { + name: "URL that looks like ARN", + value: "https://arn:aws:ssm:us-east-1:123456789012:parameter/test", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsSSMARN(tt.value) + if got != tt.want { + t.Errorf("IsSSMARN(%q) = %v, want %v", tt.value, got, tt.want) + } + }) + } +} + +func TestExtractParameterName(t *testing.T) { + tests := []struct { + name string + arn string + wantName string + wantFound bool + }{ + { + name: "simple parameter name", + arn: "arn:aws:ssm:us-east-1:123456789012:parameter/my-param", + wantName: "/my-param", + wantFound: true, + }, + { + name: "nested path without leading slash", + arn: "arn:aws:ssm:us-west-2:111122223333:parameter/octo-sts/prod/GITHUB_APP_ID", + wantName: "/octo-sts/prod/GITHUB_APP_ID", + wantFound: true, + }, + { + name: "path already has leading slash is normalized", + arn: "arn:aws:ssm:us-east-1:123456789012:parameter//app/secret", + wantName: "/app/secret", // leading slash in path becomes single slash after normalization + wantFound: true, + }, + { + name: "invalid ARN returns empty", + arn: "not-an-arn", + wantName: "", + wantFound: false, + }, + { + name: "empty string", + arn: "", + wantName: "", + wantFound: false, + }, + { + name: "wrong service ARN", + arn: "arn:aws:s3:us-east-1:123456789012:bucket/my-bucket", + wantName: "", + wantFound: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotName, gotFound := ExtractParameterName(tt.arn) + if gotName != tt.wantName { + t.Errorf("ExtractParameterName(%q) name = %q, want %q", tt.arn, gotName, tt.wantName) + } + if gotFound != tt.wantFound { + t.Errorf("ExtractParameterName(%q) found = %v, want %v", tt.arn, gotFound, tt.wantFound) + } + }) + } +} + +// mockSSMClient implements the Client interface for testing +type mockSSMClient struct { + getParameterFunc func(ctx context.Context, params *ssm.GetParameterInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) +} + +func (m *mockSSMClient) GetParameter(ctx context.Context, params *ssm.GetParameterInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) { + return m.getParameterFunc(ctx, params, optFns...) +} + +func TestResolveValue_NonARN(t *testing.T) { + resolver := NewWithClient(&mockSSMClient{ + getParameterFunc: func(ctx context.Context, params *ssm.GetParameterInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) { + t.Fatal("GetParameter should not be called for non-ARN values") + return nil, nil + }, + }) + + tests := []struct { + name string + value string + }{ + {"plain string", "my-secret-value"}, + {"empty string", ""}, + {"number", "12345"}, + {"URL", "https://example.com"}, + {"JSON", `{"key": "value"}`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := resolver.ResolveValue(context.Background(), tt.value) + if err != nil { + t.Errorf("ResolveValue(%q) error = %v, want nil", tt.value, err) + } + if got != tt.value { + t.Errorf("ResolveValue(%q) = %q, want unchanged value", tt.value, got) + } + }) + } +} + +func TestResolveValue_ValidARN(t *testing.T) { + expectedValue := "resolved-secret-value" + var capturedParamName string + + resolver := NewWithClient(&mockSSMClient{ + getParameterFunc: func(ctx context.Context, params *ssm.GetParameterInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) { + capturedParamName = *params.Name + return &ssm.GetParameterOutput{ + Parameter: &types.Parameter{ + Value: &expectedValue, + }, + }, nil + }, + }) + + arn := "arn:aws:ssm:us-east-1:123456789012:parameter/my-app/secret" + got, err := resolver.ResolveValue(context.Background(), arn) + + if err != nil { + t.Errorf("ResolveValue() error = %v, want nil", err) + } + if got != expectedValue { + t.Errorf("ResolveValue() = %q, want %q", got, expectedValue) + } + if capturedParamName != "/my-app/secret" { + t.Errorf("GetParameter called with name = %q, want %q", capturedParamName, "/my-app/secret") + } +} + +func TestResolveValue_SSMError(t *testing.T) { + expectedErr := errors.New("SSM access denied") + + resolver := NewWithClient(&mockSSMClient{ + getParameterFunc: func(ctx context.Context, params *ssm.GetParameterInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) { + return nil, expectedErr + }, + }) + + arn := "arn:aws:ssm:us-east-1:123456789012:parameter/my-secret" + _, err := resolver.ResolveValue(context.Background(), arn) + + if err == nil { + t.Error("ResolveValue() expected error, got nil") + } + if !errors.Is(err, expectedErr) { + t.Errorf("ResolveValue() error should wrap %v", expectedErr) + } +} + +func TestResolveValue_NilParameter(t *testing.T) { + resolver := NewWithClient(&mockSSMClient{ + getParameterFunc: func(ctx context.Context, params *ssm.GetParameterInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) { + return &ssm.GetParameterOutput{ + Parameter: nil, + }, nil + }, + }) + + arn := "arn:aws:ssm:us-east-1:123456789012:parameter/my-secret" + _, err := resolver.ResolveValue(context.Background(), arn) + + if err == nil { + t.Error("ResolveValue() expected error for nil parameter, got nil") + } +} + +func TestResolveValue_NilValue(t *testing.T) { + resolver := NewWithClient(&mockSSMClient{ + getParameterFunc: func(ctx context.Context, params *ssm.GetParameterInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) { + return &ssm.GetParameterOutput{ + Parameter: &types.Parameter{ + Value: nil, + }, + }, nil + }, + }) + + arn := "arn:aws:ssm:us-east-1:123456789012:parameter/my-secret" + _, err := resolver.ResolveValue(context.Background(), arn) + + if err == nil { + t.Error("ResolveValue() expected error for nil value, got nil") + } +} + +func TestResolveValue_DecryptionEnabled(t *testing.T) { + var capturedWithDecryption *bool + + resolver := NewWithClient(&mockSSMClient{ + getParameterFunc: func(ctx context.Context, params *ssm.GetParameterInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) { + capturedWithDecryption = params.WithDecryption + value := "decrypted" + return &ssm.GetParameterOutput{ + Parameter: &types.Parameter{Value: &value}, + }, nil + }, + }) + + arn := "arn:aws:ssm:us-east-1:123456789012:parameter/encrypted-secret" + _, _ = resolver.ResolveValue(context.Background(), arn) + + if capturedWithDecryption == nil || !*capturedWithDecryption { + t.Error("ResolveValue() should request decryption for SSM parameters") + } +} + +func TestNewRetryConfigFromEnv_Defaults(t *testing.T) { + cfg := NewRetryConfigFromEnv() + + if cfg.MaxRetries != DefaultMaxRetries { + t.Errorf("MaxRetries = %d, want %d", cfg.MaxRetries, DefaultMaxRetries) + } + if cfg.RetryInterval != DefaultRetryInterval { + t.Errorf("RetryInterval = %v, want %v", cfg.RetryInterval, DefaultRetryInterval) + } +}