diff --git a/.github/dependabot.yaml b/.github/dependabot.yaml new file mode 100644 index 0000000..8789fde --- /dev/null +++ b/.github/dependabot.yaml @@ -0,0 +1,11 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "daily" + + - package-ecosystem: "gomod" + directory: "/" + schedule: + interval: "daily" diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..f137095 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,78 @@ +name: ci + +on: + push: + branches: + - main + pull_request: + branches: + - main + +permissions: + contents: read + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + + - name: Setup Go + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5 + with: + go-version: '1.25' + + - name: Run go vet + run: go vet ./... + + - name: Check formatting + run: | + unformatted=$(gofmt -l .) + if [ -n "$unformatted" ]; then + echo "The following files are not formatted:" + echo "$unformatted" + exit 1 + fi + + test-unit: + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + + - name: Setup Go + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5 + with: + go-version: '1.25' + + - name: Run unit tests + run: make test-unit + + test-integration: + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + + - name: Setup Go + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5 + with: + go-version: '1.25' + + - name: Run integration tests + run: make test-integration + + build: + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + + - name: Setup Go + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5 + with: + go-version: '1.25' + + - name: Build + run: make build diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml new file mode 100644 index 0000000..c174417 --- /dev/null +++ b/.github/workflows/release.yaml @@ -0,0 +1,29 @@ +name: release + +on: + push: + branches: + - main + +permissions: + contents: write + +jobs: + release: + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + - name: Bump Version + id: tag_version + uses: mathieudutour/github-tag-action@d28fa2ccfbd16e871a4bdf35e11b3ad1bd56c0c1 # v6.2 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + default_bump: minor + custom_release_rules: bug:patch:Fixes,chore:patch:Chores,docs:patch:Documentation,feat:minor:Features,refactor:minor:Refactors,test:patch:Tests,ci:patch:Development,dev:patch:Development + - name: Create Release + uses: ncipollo/release-action@b7eabc95ff50cbeeedec83973935c8f306dfcd0b # v1 + with: + tag: ${{ steps.tag_version.outputs.new_tag }} + name: ${{ steps.tag_version.outputs.new_tag }} + body: ${{ steps.tag_version.outputs.changelog }} diff --git a/.github/workflows/semantic-check.yaml b/.github/workflows/semantic-check.yaml new file mode 100644 index 0000000..6f97f9c --- /dev/null +++ b/.github/workflows/semantic-check.yaml @@ -0,0 +1,26 @@ +name: semantic-check +on: + pull_request_target: + types: + - opened + - edited + - synchronize + +permissions: + contents: read + pull-requests: read + +jobs: + main: + name: Semantic Commit Message Check + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + - uses: amannn/action-semantic-pull-request@e32d7e603df1aa1ba07e981f2a23455dee596825 # v5 + name: Check PR for Semantic Commit Message + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + requireScope: false + validateSingleCommit: true diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..8622f50 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 CruxStack + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..515931a --- /dev/null +++ b/Makefile @@ -0,0 +1,59 @@ +.PHONY: all build test test-unit test-integration lint fmt vet clean tidy help + +GO ?= go +GOFLAGS ?= +PACKAGES := $(shell $(GO) list ./... | grep -v '/examples/' | grep -v '/docs/' | grep -v '/integration') +INTEGRATION_PKG := ./integration/... + +all: fmt vet lint test-unit ## Run all checks and unit tests + +build: ## Build all packages + $(GO) build $(GOFLAGS) ./... + +test: test-unit ## Alias for test-unit + +test-unit: ## Run unit tests only + $(GO) test $(GOFLAGS) $(PACKAGES) + +test-unit-v: ## Run unit tests with verbose output + $(GO) test $(GOFLAGS) -v $(PACKAGES) + +test-integration: ## Run integration tests + $(GO) test $(GOFLAGS) -tags=integration $(INTEGRATION_PKG) + +test-integration-v: ## Run integration tests with verbose output + VERBOSE=1 $(GO) test $(GOFLAGS) -v -tags=integration $(INTEGRATION_PKG) + +test-all: test-unit test-integration ## Run all tests (unit + integration) + +test-v: ## Run unit tests with verbose output (alias for test-unit-v) + $(GO) test $(GOFLAGS) -v $(PACKAGES) + +test-race: ## Run tests with race detector + $(GO) test $(GOFLAGS) -race $(PACKAGES) + +test-cover: ## Run tests with coverage + $(GO) test $(GOFLAGS) -cover $(PACKAGES) + +test-cover-html: ## Run tests and generate HTML coverage report + $(GO) test $(GOFLAGS) -coverprofile=coverage.out $(PACKAGES) + $(GO) tool cover -html=coverage.out -o coverage.html + +lint: ## Run linter (requires golangci-lint) + @which golangci-lint > /dev/null || (echo "golangci-lint not installed, skipping" && exit 0) + golangci-lint run ./... + +fmt: ## Format code + $(GO) fmt ./... + +vet: ## Run go vet + $(GO) vet ./... + +tidy: ## Tidy go.mod + $(GO) mod tidy + +clean: ## Clean build artifacts + rm -f coverage.out coverage.html + +help: ## Show this help + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-15s\033[0m %s\n", $$1, $$2}' diff --git a/README.md b/README.md index f8a2443..a7d5290 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,210 @@ # github-app-setup-go -A Go library for creating and managing GitHub Apps using the [GitHub App Manifest flow](https://docs.github.com/en/apps/sharing-github-apps/registering-a-github-app-from-a-manifest). Provides a web-based installer, multiple credential storage backends, and utilities for configuration management in containerized environments. +A Go library for creating and managing GitHub Apps using the +[GitHub App Manifest flow](https://docs.github.com/en/apps/sharing-github-apps/registering-a-github-app-from-a-manifest). +Provides a web-based installer, multiple credential storage backends, and +utilities for configuration management in containerized environments. +## Features + +- **Web-based installer** - User-friendly UI for creating GitHub Apps with + pre-configured permissions +- **Multiple storage backends** - AWS SSM Parameter Store, `.env` files, or + individual files +- **Hot reload support** - Reload configuration via SIGHUP or programmatic + triggers +- **SSM ARN resolution** - Resolve AWS SSM Parameter Store ARNs in environment + variables (useful for Lambda) +- **Ready gate** - HTTP middleware that returns 503 until configuration is + loaded + +## Installation + +```bash +go get github.com/cruxstack/github-app-setup-go +``` + +## Packages + +| Package | Description | +|---------------|-----------------------------------------------------------| +| `installer` | HTTP handler implementing the GitHub App Manifest flow | +| `configstore` | Storage backends for GitHub App credentials | +| `configwait` | Startup wait logic, ready gate middleware, and reload | +| `ssmresolver` | Resolves SSM Parameter Store ARNs in environment vars | + +## Quick Start + +```go +package main + +import ( + "context" + "log" + "net/http" + + "github.com/cruxstack/github-app-setup-go/configstore" + "github.com/cruxstack/github-app-setup-go/configwait" + "github.com/cruxstack/github-app-setup-go/installer" +) + +func main() { + ctx := context.Background() + + // Create a storage backend (uses STORAGE_MODE env var, defaults to .env file) + store, err := configstore.NewFromEnv() + if err != nil { + log.Fatal(err) + } + + // Define the GitHub App manifest with required permissions + manifest := installer.Manifest{ + URL: "https://example.com", + Public: false, + DefaultPerms: map[string]string{ + "contents": "read", + "pull_requests": "write", + }, + DefaultEvents: []string{"pull_request", "push"}, + } + + // Create the installer handler + installerHandler, err := installer.New(installer.Config{ + Store: store, + Manifest: manifest, + AppDisplayName: "My GitHub App", + }) + if err != nil { + log.Fatal(err) + } + + // Set up routes + mux := http.NewServeMux() + mux.Handle("/setup", installerHandler) + mux.Handle("/callback", installerHandler) + + // Create a ready gate that allows /setup through before app is configured + gate := configwait.NewReadyGate(mux, []string{"/setup", "/callback", "/healthz"}) + + // Start the server + log.Println("Starting server on :8080") + log.Fatal(http.ListenAndServe(":8080", gate)) +} +``` + +## Configuration + +### Environment Variables + +#### Installer + +| Variable | Description | Default | +|--------------------------------|---------------------------------------------|----------------------| +| `GITHUB_URL` | GitHub base URL (for GHE Server) | `https://github.com` | +| `GITHUB_ORG` | Organization (empty = personal account) | - | +| `GITHUB_APP_INSTALLER_ENABLED` | Enable the installer UI (`true`, `1`, `yes`)| - | + +#### Storage + +| Variable | Description | Default | +|---------------------------|----------------------------------------------|-------------| +| `STORAGE_MODE` | Backend: `envfile`, `files`, or `aws-ssm` | `envfile` | +| `STORAGE_DIR` | Directory/path for local storage backends | `./.env` | +| `AWS_SSM_PARAMETER_PREFIX`| SSM parameter path prefix (for `aws-ssm`) | - | +| `AWS_SSM_KMS_KEY_ID` | Custom KMS key for SSM encryption | AWS managed | +| `AWS_SSM_TAGS` | JSON object of tags for SSM parameters | - | + +#### Config Wait + +| Variable | Description | Default | +|-----------------------------|--------------------------------------|---------| +| `CONFIG_WAIT_MAX_RETRIES` | Maximum retry attempts | `30` | +| `CONFIG_WAIT_RETRY_INTERVAL`| Duration between retries (e.g., `2s`)| `2s` | + +## Storage Backends + +### AWS SSM Parameter Store + +Stores credentials as encrypted SecureString parameters: + +```go +store, err := configstore.NewAWSSSMStore("/my-app/prod/", + configstore.WithKMSKey("alias/my-key"), + configstore.WithTags(map[string]string{ + "Environment": "production", + }), +) +``` + +Parameters are stored at paths like `/my-app/prod/GITHUB_APP_ID`, +`/my-app/prod/GITHUB_APP_PRIVATE_KEY`, etc. + +### Local .env File + +Saves credentials to a `.env` file, preserving existing content: + +```go +store := configstore.NewLocalEnvFileStore("./.env") +``` + +### Local Files + +Saves each credential as a separate file: + +```go +store := configstore.NewLocalFileStore("./secrets/") +// Creates: ./secrets/app-id, ./secrets/private-key.pem, etc. +``` + +## Hot Reload + +The library supports hot-reloading configuration via SIGHUP signals or +programmatic triggers: + +```go +// Create a reloader that calls your reload function +reloader := configwait.NewReloader(ctx, gate, func(ctx context.Context) error { + // Reload your configuration here + newHandler := buildHandler() + gate.SetHandler(newHandler) + gate.SetReady() + return nil +}) + +// Set as global reloader (allows installer to trigger reload after saving) +configwait.SetGlobalReloader(reloader) + +// Start listening for SIGHUP +reloader.Start() +``` + +## SSM ARN Resolution + +For Lambda deployments where secrets are passed as SSM ARNs: + +```go +// Resolve all environment variables that contain SSM ARNs +err := ssmresolver.ResolveEnvironmentWithRetry(ctx, ssmresolver.NewRetryConfigFromEnv()) + +// Or resolve a single value +resolver, _ := ssmresolver.New(ctx) +value, err := resolver.ResolveValue(ctx, os.Getenv("MY_SECRET")) +``` + +## Stored Credentials + +After a GitHub App is created, the following credentials are stored: + +| Key | Description | +|---------------------------|---------------------------------------| +| `GITHUB_APP_ID` | The numeric App ID | +| `GITHUB_APP_SLUG` | The app's URL slug | +| `GITHUB_APP_HTML_URL` | URL to the app's GitHub settings page | +| `GITHUB_WEBHOOK_SECRET` | Webhook signature secret | +| `GITHUB_CLIENT_ID` | OAuth client ID | +| `GITHUB_CLIENT_SECRET` | OAuth client secret | +| `GITHUB_APP_PRIVATE_KEY` | Private key (PEM format) | + +## License + +MIT License - Copyright 2025 CruxStack diff --git a/configstore/aws_ssm_store.go b/configstore/aws_ssm_store.go new file mode 100644 index 0000000..127e7f3 --- /dev/null +++ b/configstore/aws_ssm_store.go @@ -0,0 +1,225 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +package configstore + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/ssm" + "github.com/aws/aws-sdk-go-v2/service/ssm/types" +) + +// SSMClient defines the interface for AWS SSM operations. +type SSMClient interface { + PutParameter(ctx context.Context, params *ssm.PutParameterInput, + optFns ...func(*ssm.Options)) (*ssm.PutParameterOutput, error) + GetParameter(ctx context.Context, params *ssm.GetParameterInput, + optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) +} + +// AWSSSMStore saves credentials to AWS SSM Parameter Store with encryption. +type AWSSSMStore struct { + ParameterPrefix string + KMSKeyID string + Tags map[string]string + ssmClient SSMClient +} + +// SSMStoreOption is a functional option for configuring AWSSSMStore. +type SSMStoreOption func(*AWSSSMStore) + +// WithKMSKey sets a custom KMS key ID for parameter encryption. +func WithKMSKey(keyID string) SSMStoreOption { + return func(s *AWSSSMStore) { + s.KMSKeyID = keyID + } +} + +// WithTags adds AWS tags to all created parameters. +func WithTags(tags map[string]string) SSMStoreOption { + return func(s *AWSSSMStore) { + s.Tags = tags + } +} + +// WithSSMClient sets a custom SSM client. +func WithSSMClient(client SSMClient) SSMStoreOption { + return func(s *AWSSSMStore) { + s.ssmClient = client + } +} + +// NewAWSSSMStore creates a new AWS SSM Parameter Store backend. +// The prefix is normalized to always end with a slash. +func NewAWSSSMStore(prefix string, opts ...SSMStoreOption) (*AWSSSMStore, error) { + if prefix == "" { + return nil, fmt.Errorf("parameter prefix cannot be empty") + } + + if !strings.HasSuffix(prefix, "/") { + prefix = prefix + "/" + } + + store := &AWSSSMStore{ + ParameterPrefix: prefix, + } + + for _, opt := range opts { + opt(store) + } + + if store.ssmClient == nil { + cfg, err := config.LoadDefaultConfig(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to load AWS config: %w", err) + } + store.ssmClient = ssm.NewFromConfig(cfg) + } + + return store, nil +} + +// Save writes credentials to AWS SSM as encrypted SecureString parameters. +func (s *AWSSSMStore) Save(ctx context.Context, creds *AppCredentials) error { + parameters := map[string]string{ + EnvGitHubAppID: fmt.Sprintf("%d", creds.AppID), + EnvGitHubWebhookSecret: creds.WebhookSecret, + EnvGitHubClientID: creds.ClientID, + EnvGitHubClientSecret: creds.ClientSecret, + EnvGitHubAppPrivateKey: creds.PrivateKey, + } + + if creds.AppSlug != "" { + parameters[EnvGitHubAppSlug] = creds.AppSlug + } + if creds.HTMLURL != "" { + parameters[EnvGitHubAppHTMLURL] = creds.HTMLURL + } + + for key, value := range creds.CustomFields { + if value != "" { + parameters[key] = value + } + } + + for name, value := range parameters { + if err := s.putParameter(ctx, name, value); err != nil { + return fmt.Errorf("failed to save parameter %s: %w", name, err) + } + } + + return nil +} + +// putParameter creates or updates a single SSM parameter. +func (s *AWSSSMStore) putParameter(ctx context.Context, name, value string) error { + input := &ssm.PutParameterInput{ + Name: aws.String(s.ParameterPrefix + name), + Value: aws.String(value), + Type: types.ParameterTypeSecureString, + Overwrite: aws.Bool(true), + DataType: aws.String("text"), + } + + if s.KMSKeyID != "" { + input.KeyId = aws.String(s.KMSKeyID) + } + + if len(s.Tags) > 0 { + var tags []types.Tag + for key, value := range s.Tags { + tags = append(tags, types.Tag{ + Key: aws.String(key), + Value: aws.String(value), + }) + } + input.Tags = tags + } + + _, err := s.ssmClient.PutParameter(ctx, input) + if err != nil { + return err + } + + return nil +} + +// Status returns the current registration state by checking required SSM parameters. +func (s *AWSSSMStore) Status(ctx context.Context) (*InstallerStatus, error) { + status := &InstallerStatus{} + required := []string{ + EnvGitHubAppID, + EnvGitHubWebhookSecret, + EnvGitHubClientID, + EnvGitHubClientSecret, + EnvGitHubAppPrivateKey, + } + + values := make(map[string]string) + for _, key := range required { + value, err := s.getParameterValue(ctx, key) + if err != nil { + if isParameterNotFound(err) { + return status, nil + } + return nil, err + } + values[key] = value + } + + status.Registered = true + if id, err := strconv.ParseInt(strings.TrimSpace(values[EnvGitHubAppID]), 10, 64); err == nil { + status.AppID = id + } + + if slug, err := s.getParameterValue(ctx, EnvGitHubAppSlug); err == nil { + status.AppSlug = slug + } else if !isParameterNotFound(err) { + return nil, err + } + + if html, err := s.getParameterValue(ctx, EnvGitHubAppHTMLURL); err == nil { + status.HTMLURL = html + } else if !isParameterNotFound(err) { + return nil, err + } + + if flag, err := s.getParameterValue(ctx, EnvGitHubAppInstallerEnabled); err == nil { + status.InstallerDisabled = isFalseString(flag) + } else if !isParameterNotFound(err) { + return nil, err + } + + return status, nil +} + +// DisableInstaller sets a parameter to disable the installer. +func (s *AWSSSMStore) DisableInstaller(ctx context.Context) error { + return s.putParameter(ctx, EnvGitHubAppInstallerEnabled, "false") +} + +func (s *AWSSSMStore) getParameterValue(ctx context.Context, name string) (string, error) { + output, err := s.ssmClient.GetParameter(ctx, &ssm.GetParameterInput{ + Name: aws.String(s.ParameterPrefix + name), + WithDecryption: aws.Bool(true), + }) + if err != nil { + return "", err + } + if output.Parameter == nil || output.Parameter.Value == nil { + return "", fmt.Errorf("parameter %s missing value", name) + } + return aws.ToString(output.Parameter.Value), nil +} + +func isParameterNotFound(err error) bool { + var notFound *types.ParameterNotFound + return errors.As(err, ¬Found) +} diff --git a/configstore/aws_ssm_store_test.go b/configstore/aws_ssm_store_test.go new file mode 100644 index 0000000..4211ed9 --- /dev/null +++ b/configstore/aws_ssm_store_test.go @@ -0,0 +1,451 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +package configstore + +import ( + "context" + "fmt" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ssm" + "github.com/aws/aws-sdk-go-v2/service/ssm/types" +) + +// mockSSMClient implements SSMClient for testing +type mockSSMClient struct { + parameters map[string]string + putCalls []ssm.PutParameterInput + getCalls []ssm.GetParameterInput + putErr error + getErr error +} + +func newMockSSMClient() *mockSSMClient { + return &mockSSMClient{ + parameters: make(map[string]string), + } +} + +func (m *mockSSMClient) PutParameter(ctx context.Context, params *ssm.PutParameterInput, optFns ...func(*ssm.Options)) (*ssm.PutParameterOutput, error) { + m.putCalls = append(m.putCalls, *params) + if m.putErr != nil { + return nil, m.putErr + } + m.parameters[*params.Name] = *params.Value + return &ssm.PutParameterOutput{}, nil +} + +func (m *mockSSMClient) GetParameter(ctx context.Context, params *ssm.GetParameterInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) { + m.getCalls = append(m.getCalls, *params) + if m.getErr != nil { + return nil, m.getErr + } + value, ok := m.parameters[*params.Name] + if !ok { + return nil, &types.ParameterNotFound{} + } + return &ssm.GetParameterOutput{ + Parameter: &types.Parameter{ + Name: params.Name, + Value: aws.String(value), + }, + }, nil +} + +func TestNewAWSSSMStore(t *testing.T) { + t.Run("empty prefix returns error", func(t *testing.T) { + _, err := NewAWSSSMStore("") + if err == nil { + t.Error("NewAWSSSMStore(\"\") should return error") + } + }) + + t.Run("prefix without trailing slash is normalized", func(t *testing.T) { + mock := newMockSSMClient() + store, err := NewAWSSSMStore("/my/prefix", WithSSMClient(mock)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + if store.ParameterPrefix != "/my/prefix/" { + t.Errorf("ParameterPrefix = %q, want %q", store.ParameterPrefix, "/my/prefix/") + } + }) + + t.Run("prefix with trailing slash is preserved", func(t *testing.T) { + mock := newMockSSMClient() + store, err := NewAWSSSMStore("/my/prefix/", WithSSMClient(mock)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + if store.ParameterPrefix != "/my/prefix/" { + t.Errorf("ParameterPrefix = %q, want %q", store.ParameterPrefix, "/my/prefix/") + } + }) +} + +func TestAWSSSMStore_WithOptions(t *testing.T) { + t.Run("WithKMSKey sets KMS key ID", func(t *testing.T) { + mock := newMockSSMClient() + store, err := NewAWSSSMStore("/prefix", WithSSMClient(mock), WithKMSKey("alias/my-key")) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + if store.KMSKeyID != "alias/my-key" { + t.Errorf("KMSKeyID = %q, want %q", store.KMSKeyID, "alias/my-key") + } + }) + + t.Run("WithTags sets tags", func(t *testing.T) { + mock := newMockSSMClient() + tags := map[string]string{"env": "prod", "team": "platform"} + store, err := NewAWSSSMStore("/prefix", WithSSMClient(mock), WithTags(tags)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + if len(store.Tags) != 2 { + t.Errorf("Tags count = %d, want 2", len(store.Tags)) + } + if store.Tags["env"] != "prod" { + t.Errorf("Tags[\"env\"] = %q, want %q", store.Tags["env"], "prod") + } + }) +} + +func TestAWSSSMStore_Save(t *testing.T) { + mock := newMockSSMClient() + store, err := NewAWSSSMStore("/app/github/", WithSSMClient(mock)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + + creds := &AppCredentials{ + AppID: 12345, + AppSlug: "my-app", + ClientID: "Iv1.abc123", + ClientSecret: "secret123", + WebhookSecret: "whsec_123", + PrivateKey: "-----BEGIN RSA PRIVATE KEY-----\nkey\n-----END RSA PRIVATE KEY-----", + HTMLURL: "https://github.com/apps/my-app", + CustomFields: map[string]string{ + "STS_DOMAIN": "sts.example.com", + }, + } + + err = store.Save(context.Background(), creds) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Verify all expected parameters were saved + expectedParams := map[string]string{ + "/app/github/GITHUB_APP_ID": "12345", + "/app/github/GITHUB_APP_SLUG": "my-app", + "/app/github/GITHUB_CLIENT_ID": "Iv1.abc123", + "/app/github/GITHUB_CLIENT_SECRET": "secret123", + "/app/github/GITHUB_WEBHOOK_SECRET": "whsec_123", + "/app/github/GITHUB_APP_PRIVATE_KEY": "-----BEGIN RSA PRIVATE KEY-----\nkey\n-----END RSA PRIVATE KEY-----", + "/app/github/GITHUB_APP_HTML_URL": "https://github.com/apps/my-app", + "/app/github/STS_DOMAIN": "sts.example.com", + } + + for name, wantValue := range expectedParams { + gotValue, ok := mock.parameters[name] + if !ok { + t.Errorf("Parameter %q was not saved", name) + continue + } + if gotValue != wantValue { + t.Errorf("Parameter %q = %q, want %q", name, gotValue, wantValue) + } + } + + // Verify all parameters were saved as SecureString + for _, call := range mock.putCalls { + if call.Type != types.ParameterTypeSecureString { + t.Errorf("Parameter %q type = %v, want SecureString", *call.Name, call.Type) + } + if call.Overwrite == nil || !*call.Overwrite { + t.Errorf("Parameter %q Overwrite should be true", *call.Name) + } + } +} + +func TestAWSSSMStore_Save_WithKMSKey(t *testing.T) { + mock := newMockSSMClient() + store, err := NewAWSSSMStore("/prefix/", WithSSMClient(mock), WithKMSKey("alias/custom-key")) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + + creds := &AppCredentials{ + AppID: 1, + ClientID: "client", + ClientSecret: "secret", + WebhookSecret: "webhook", + PrivateKey: "key", + } + + err = store.Save(context.Background(), creds) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Verify KMS key was used + for _, call := range mock.putCalls { + if call.KeyId == nil || *call.KeyId != "alias/custom-key" { + t.Errorf("Parameter %q KeyId = %v, want alias/custom-key", *call.Name, call.KeyId) + } + } +} + +func TestAWSSSMStore_Save_WithTags(t *testing.T) { + mock := newMockSSMClient() + tags := map[string]string{"env": "test", "app": "myapp"} + store, err := NewAWSSSMStore("/prefix/", WithSSMClient(mock), WithTags(tags)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + + creds := &AppCredentials{ + AppID: 1, + ClientID: "client", + ClientSecret: "secret", + WebhookSecret: "webhook", + PrivateKey: "key", + } + + err = store.Save(context.Background(), creds) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Verify tags were applied to all parameters + for _, call := range mock.putCalls { + if len(call.Tags) != 2 { + t.Errorf("Parameter %q has %d tags, want 2", *call.Name, len(call.Tags)) + } + } +} + +func TestAWSSSMStore_Save_OmitsEmptyOptionalFields(t *testing.T) { + mock := newMockSSMClient() + store, err := NewAWSSSMStore("/prefix/", WithSSMClient(mock)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + + creds := &AppCredentials{ + AppID: 1, + ClientID: "client", + ClientSecret: "secret", + WebhookSecret: "webhook", + PrivateKey: "key", + // AppSlug and HTMLURL are empty + } + + err = store.Save(context.Background(), creds) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Optional parameters should not be saved + for _, name := range []string{"/prefix/GITHUB_APP_SLUG", "/prefix/GITHUB_APP_HTML_URL"} { + if _, ok := mock.parameters[name]; ok { + t.Errorf("Empty optional parameter %q should not be saved", name) + } + } +} + +func TestAWSSSMStore_Save_Error(t *testing.T) { + mock := newMockSSMClient() + mock.putErr = fmt.Errorf("access denied") + + store, err := NewAWSSSMStore("/prefix/", WithSSMClient(mock)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + + creds := &AppCredentials{ + AppID: 1, + ClientID: "client", + ClientSecret: "secret", + WebhookSecret: "webhook", + PrivateKey: "key", + } + + err = store.Save(context.Background(), creds) + if err == nil { + t.Error("Save() should return error when PutParameter fails") + } +} + +func TestAWSSSMStore_Status_Registered(t *testing.T) { + mock := newMockSSMClient() + mock.parameters = map[string]string{ + "/prefix/GITHUB_APP_ID": "12345", + "/prefix/GITHUB_APP_SLUG": "test-app", + "/prefix/GITHUB_APP_HTML_URL": "https://github.com/apps/test-app", + "/prefix/GITHUB_CLIENT_ID": "client123", + "/prefix/GITHUB_CLIENT_SECRET": "secret123", + "/prefix/GITHUB_WEBHOOK_SECRET": "webhook123", + "/prefix/GITHUB_APP_PRIVATE_KEY": "-----BEGIN RSA-----\nkey\n-----END RSA-----", + } + + store, err := NewAWSSSMStore("/prefix/", WithSSMClient(mock)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if !status.Registered { + t.Error("Status.Registered = false, want true") + } + if status.AppID != 12345 { + t.Errorf("Status.AppID = %d, want 12345", status.AppID) + } + if status.AppSlug != "test-app" { + t.Errorf("Status.AppSlug = %q, want %q", status.AppSlug, "test-app") + } + if status.HTMLURL != "https://github.com/apps/test-app" { + t.Errorf("Status.HTMLURL = %q, want %q", status.HTMLURL, "https://github.com/apps/test-app") + } + if status.InstallerDisabled { + t.Error("Status.InstallerDisabled = true, want false") + } +} + +func TestAWSSSMStore_Status_NotRegistered(t *testing.T) { + mock := newMockSSMClient() + // No parameters exist + + store, err := NewAWSSSMStore("/prefix/", WithSSMClient(mock)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if status.Registered { + t.Error("Status.Registered = true, want false (no parameters)") + } +} + +func TestAWSSSMStore_Status_InstallerDisabled(t *testing.T) { + mock := newMockSSMClient() + mock.parameters = map[string]string{ + "/prefix/GITHUB_APP_ID": "12345", + "/prefix/GITHUB_CLIENT_ID": "client", + "/prefix/GITHUB_CLIENT_SECRET": "secret", + "/prefix/GITHUB_WEBHOOK_SECRET": "webhook", + "/prefix/GITHUB_APP_PRIVATE_KEY": "key", + "/prefix/GITHUB_APP_INSTALLER_ENABLED": "false", + } + + store, err := NewAWSSSMStore("/prefix/", WithSSMClient(mock)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if !status.InstallerDisabled { + t.Error("Status.InstallerDisabled = false, want true") + } +} + +func TestAWSSSMStore_Status_Error(t *testing.T) { + mock := newMockSSMClient() + mock.getErr = fmt.Errorf("access denied") + + store, err := NewAWSSSMStore("/prefix/", WithSSMClient(mock)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + + _, err = store.Status(context.Background()) + if err == nil { + t.Error("Status() should return error when GetParameter fails") + } +} + +func TestAWSSSMStore_DisableInstaller(t *testing.T) { + mock := newMockSSMClient() + store, err := NewAWSSSMStore("/prefix/", WithSSMClient(mock)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + + err = store.DisableInstaller(context.Background()) + if err != nil { + t.Fatalf("DisableInstaller() error = %v", err) + } + + // Verify the parameter was set + value, ok := mock.parameters["/prefix/GITHUB_APP_INSTALLER_ENABLED"] + if !ok { + t.Error("DisableInstaller() did not create parameter") + } + if value != "false" { + t.Errorf("Parameter value = %q, want %q", value, "false") + } +} + +func TestAWSSSMStore_DisableInstaller_Error(t *testing.T) { + mock := newMockSSMClient() + mock.putErr = fmt.Errorf("access denied") + + store, err := NewAWSSSMStore("/prefix/", WithSSMClient(mock)) + if err != nil { + t.Fatalf("NewAWSSSMStore() error = %v", err) + } + + err = store.DisableInstaller(context.Background()) + if err == nil { + t.Error("DisableInstaller() should return error when PutParameter fails") + } +} + +func TestIsParameterNotFound(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + { + name: "ParameterNotFound error", + err: &types.ParameterNotFound{}, + want: true, + }, + { + name: "other error", + err: fmt.Errorf("some other error"), + want: false, + }, + { + name: "nil error", + err: nil, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isParameterNotFound(tt.err) + if got != tt.want { + t.Errorf("isParameterNotFound() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/configstore/local_env_store.go b/configstore/local_env_store.go new file mode 100644 index 0000000..b94af76 --- /dev/null +++ b/configstore/local_env_store.go @@ -0,0 +1,245 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +package configstore + +import ( + "bufio" + "context" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" +) + +// LocalEnvFileStore saves credentials to a .env file. +type LocalEnvFileStore struct { + FilePath string +} + +// NewLocalEnvFileStore creates a store that saves credentials to the given path. +func NewLocalEnvFileStore(filepath string) *LocalEnvFileStore { + return &LocalEnvFileStore{FilePath: filepath} +} + +// Save writes credentials to .env format, preserving existing content. +// It also sets the environment variables in the current process so they +// are immediately available to the application. +func (s *LocalEnvFileStore) Save(ctx context.Context, creds *AppCredentials) error { + dir := filepath.Dir(s.FilePath) + if err := os.MkdirAll(dir, 0700); err != nil { + return fmt.Errorf("failed to create directory %s: %w", dir, err) + } + + existingValues, originalLines, err := parseEnvFile(s.FilePath) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to read existing .env file: %w", err) + } + if existingValues == nil { + existingValues = make(map[string]string) + } + + for key, value := range creds.CustomFields { + if value != "" { + existingValues[key] = value + } + } + + singleLinePEM := strings.ReplaceAll(creds.PrivateKey, "\n", "\\n") + + existingValues[EnvGitHubAppID] = fmt.Sprintf("%d", creds.AppID) + existingValues[EnvGitHubWebhookSecret] = creds.WebhookSecret + existingValues[EnvGitHubClientID] = creds.ClientID + existingValues[EnvGitHubClientSecret] = creds.ClientSecret + existingValues[EnvGitHubAppPrivateKey] = singleLinePEM + if creds.AppSlug != "" { + existingValues[EnvGitHubAppSlug] = creds.AppSlug + } + if creds.HTMLURL != "" { + existingValues[EnvGitHubAppHTMLURL] = creds.HTMLURL + } + + if err := writeEnvFile(s.FilePath, existingValues, originalLines); err != nil { + return fmt.Errorf("failed to write .env file: %w", err) + } + + // Set environment variables in the current process so they are + // immediately available for configuration reload. + os.Setenv(EnvGitHubAppID, fmt.Sprintf("%d", creds.AppID)) + os.Setenv(EnvGitHubWebhookSecret, creds.WebhookSecret) + os.Setenv(EnvGitHubClientID, creds.ClientID) + os.Setenv(EnvGitHubClientSecret, creds.ClientSecret) + os.Setenv(EnvGitHubAppPrivateKey, singleLinePEM) + if creds.AppSlug != "" { + os.Setenv(EnvGitHubAppSlug, creds.AppSlug) + } + if creds.HTMLURL != "" { + os.Setenv(EnvGitHubAppHTMLURL, creds.HTMLURL) + } + for key, value := range creds.CustomFields { + if value != "" { + os.Setenv(key, value) + } + } + + return nil +} + +func parseEnvFile(path string) (map[string]string, []string, error) { + file, err := os.Open(path) + if err != nil { + return nil, nil, err + } + defer file.Close() + + values := make(map[string]string) + var lines []string + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + lines = append(lines, line) + + trimmed := strings.TrimSpace(line) + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + + idx := strings.Index(line, "=") + if idx == -1 { + continue + } + + key := strings.TrimSpace(line[:idx]) + value := strings.TrimSpace(line[idx+1:]) + + if len(value) >= 2 { + if (strings.HasPrefix(value, "\"") && strings.HasSuffix(value, "\"")) || + (strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'")) { + value = value[1 : len(value)-1] + } + } + + values[key] = value + } + + if err := scanner.Err(); err != nil { + return nil, nil, err + } + + return values, lines, nil +} + +func writeEnvFile(path string, values map[string]string, originalLines []string) error { + var outputLines []string + writtenKeys := make(map[string]bool) + + for _, line := range originalLines { + trimmed := strings.TrimSpace(line) + + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + outputLines = append(outputLines, line) + continue + } + + idx := strings.Index(line, "=") + if idx == -1 { + outputLines = append(outputLines, line) + continue + } + + key := strings.TrimSpace(line[:idx]) + + if newValue, ok := values[key]; ok { + outputLines = append(outputLines, formatEnvLine(key, newValue)) + writtenKeys[key] = true + } else { + outputLines = append(outputLines, line) + } + } + + for key, value := range values { + if !writtenKeys[key] { + outputLines = append(outputLines, formatEnvLine(key, value)) + } + } + + content := strings.Join(outputLines, "\n") + if len(outputLines) > 0 { + content += "\n" + } + + return os.WriteFile(path, []byte(content), 0600) +} + +func formatEnvLine(key, value string) string { + needsQuotes := strings.ContainsAny(value, " \t\n\r\"'\\#") || strings.Contains(value, "\\n") + + if needsQuotes { + escaped := strings.ReplaceAll(value, "\"", "\\\"") + return fmt.Sprintf("%s=\"%s\"", key, escaped) + } + + return fmt.Sprintf("%s=%s", key, value) +} + +// Status returns the current registration state by checking the .env file. +func (s *LocalEnvFileStore) Status(ctx context.Context) (*InstallerStatus, error) { + values, _, err := parseEnvFile(s.FilePath) + if err != nil { + if os.IsNotExist(err) { + return &InstallerStatus{}, nil + } + return nil, err + } + + status := &InstallerStatus{ + AppSlug: values[EnvGitHubAppSlug], + HTMLURL: values[EnvGitHubAppHTMLURL], + } + + if idStr := strings.TrimSpace(values[EnvGitHubAppID]); idStr != "" { + if id, err := strconv.ParseInt(idStr, 10, 64); err == nil { + status.AppID = id + } + } + + status.Registered = hasAllValues(values, + EnvGitHubAppID, + EnvGitHubWebhookSecret, + EnvGitHubClientID, + EnvGitHubClientSecret, + EnvGitHubAppPrivateKey, + ) + + status.InstallerDisabled = isFalseString(values[EnvGitHubAppInstallerEnabled]) + + return status, nil +} + +// DisableInstaller sets GITHUB_APP_INSTALLER_ENABLED=false in the .env file. +func (s *LocalEnvFileStore) DisableInstaller(ctx context.Context) error { + dir := filepath.Dir(s.FilePath) + if err := os.MkdirAll(dir, 0700); err != nil { + return fmt.Errorf("failed to create directory %s: %w", dir, err) + } + + values, originalLines, err := parseEnvFile(s.FilePath) + if err != nil { + if !os.IsNotExist(err) { + return err + } + } + if values == nil { + values = make(map[string]string) + } + + values[EnvGitHubAppInstallerEnabled] = "false" + + if err := writeEnvFile(s.FilePath, values, originalLines); err != nil { + return fmt.Errorf("failed to persist installer flag: %w", err) + } + + return nil +} diff --git a/configstore/local_env_store_test.go b/configstore/local_env_store_test.go new file mode 100644 index 0000000..fb652cd --- /dev/null +++ b/configstore/local_env_store_test.go @@ -0,0 +1,590 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +package configstore + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestParseEnvFile(t *testing.T) { + tests := []struct { + name string + content string + wantValues map[string]string + wantLines int + }{ + { + name: "simple key-value pairs", + content: "FOO=bar\nBAZ=qux", + wantValues: map[string]string{ + "FOO": "bar", + "BAZ": "qux", + }, + wantLines: 2, + }, + { + name: "double-quoted values", + content: `KEY="value with spaces"`, + wantValues: map[string]string{ + "KEY": "value with spaces", + }, + wantLines: 1, + }, + { + name: "single-quoted values", + content: `KEY='value with spaces'`, + wantValues: map[string]string{ + "KEY": "value with spaces", + }, + wantLines: 1, + }, + { + name: "unquoted value with equals sign", + content: "KEY=value=with=equals", + wantValues: map[string]string{ + "KEY": "value=with=equals", + }, + wantLines: 1, + }, + { + name: "comments are ignored", + content: "# This is a comment\nKEY=value\n# Another comment", + wantValues: map[string]string{ + "KEY": "value", + }, + wantLines: 3, + }, + { + name: "empty lines preserved", + content: "FOO=bar\n\nBAZ=qux", + wantValues: map[string]string{ + "FOO": "bar", + "BAZ": "qux", + }, + wantLines: 3, + }, + { + name: "whitespace around equals", + content: "KEY = value", + wantValues: map[string]string{ + "KEY": "value", + }, + wantLines: 1, + }, + { + name: "PEM key with escaped newlines", + content: `PRIVATE_KEY="-----BEGIN RSA PRIVATE KEY-----\nMIIE...\n-----END RSA PRIVATE KEY-----"`, + wantValues: map[string]string{ + "PRIVATE_KEY": `-----BEGIN RSA PRIVATE KEY-----\nMIIE...\n-----END RSA PRIVATE KEY-----`, + }, + wantLines: 1, + }, + { + name: "empty file", + content: "", + wantValues: map[string]string{}, + wantLines: 0, + }, + { + name: "only comments", + content: "# comment 1\n# comment 2", + wantValues: map[string]string{}, + wantLines: 2, + }, + { + name: "line without equals ignored", + content: "INVALID_LINE\nVALID=true", + wantValues: map[string]string{ + "VALID": "true", + }, + wantLines: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + envPath := filepath.Join(tmpDir, ".env") + + if err := os.WriteFile(envPath, []byte(tt.content), 0600); err != nil { + t.Fatalf("Failed to write test file: %v", err) + } + + values, lines, err := parseEnvFile(envPath) + if err != nil { + t.Fatalf("parseEnvFile() error = %v", err) + } + + if len(values) != len(tt.wantValues) { + t.Errorf("parseEnvFile() got %d values, want %d", len(values), len(tt.wantValues)) + } + + for k, want := range tt.wantValues { + got, ok := values[k] + if !ok { + t.Errorf("parseEnvFile() missing key %q", k) + continue + } + if got != want { + t.Errorf("parseEnvFile()[%q] = %q, want %q", k, got, want) + } + } + + if len(lines) != tt.wantLines { + t.Errorf("parseEnvFile() got %d lines, want %d", len(lines), tt.wantLines) + } + }) + } +} + +func TestParseEnvFile_NotExists(t *testing.T) { + values, lines, err := parseEnvFile("/nonexistent/path/.env") + + if !os.IsNotExist(err) { + t.Errorf("parseEnvFile() error = %v, want os.IsNotExist", err) + } + if values != nil { + t.Errorf("parseEnvFile() values = %v, want nil", values) + } + if lines != nil { + t.Errorf("parseEnvFile() lines = %v, want nil", lines) + } +} + +func TestFormatEnvLine(t *testing.T) { + tests := []struct { + name string + key string + value string + want string + }{ + { + name: "simple value", + key: "KEY", + value: "value", + want: "KEY=value", + }, + { + name: "value with spaces needs quotes", + key: "KEY", + value: "value with spaces", + want: `KEY="value with spaces"`, + }, + { + name: "value with escaped newlines needs quotes", + key: "PEM", + value: `-----BEGIN RSA-----\nMIIE\n-----END RSA-----`, + want: `PEM="-----BEGIN RSA-----\nMIIE\n-----END RSA-----"`, + }, + { + name: "value with hash needs quotes", + key: "KEY", + value: "value#comment", + want: `KEY="value#comment"`, + }, + { + name: "value with double quote is escaped", + key: "KEY", + value: `value"quoted`, + want: `KEY="value\"quoted"`, + }, + { + name: "value with single quote needs quotes", + key: "KEY", + value: "it's", + want: `KEY="it's"`, + }, + { + name: "value with tab needs quotes", + key: "KEY", + value: "value\twith\ttabs", + want: `KEY="value with tabs"`, + }, + { + name: "empty value", + key: "KEY", + value: "", + want: "KEY=", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := formatEnvLine(tt.key, tt.value) + if got != tt.want { + t.Errorf("formatEnvLine(%q, %q) = %q, want %q", tt.key, tt.value, got, tt.want) + } + }) + } +} + +func TestWriteEnvFile_PreservesComments(t *testing.T) { + tmpDir := t.TempDir() + envPath := filepath.Join(tmpDir, ".env") + + originalContent := `# Database config +DB_HOST=localhost + +# App config +APP_NAME=test +` + if err := os.WriteFile(envPath, []byte(originalContent), 0600); err != nil { + t.Fatalf("Failed to write initial file: %v", err) + } + + values, lines, err := parseEnvFile(envPath) + if err != nil { + t.Fatalf("parseEnvFile() error = %v", err) + } + + // Update one value and add a new one + values["APP_NAME"] = "updated" + values["NEW_KEY"] = "new_value" + + if err := writeEnvFile(envPath, values, lines); err != nil { + t.Fatalf("writeEnvFile() error = %v", err) + } + + content, err := os.ReadFile(envPath) + if err != nil { + t.Fatalf("Failed to read file: %v", err) + } + + result := string(content) + + // Comments should be preserved + if !strings.Contains(result, "# Database config") { + t.Error("writeEnvFile() lost comment '# Database config'") + } + if !strings.Contains(result, "# App config") { + t.Error("writeEnvFile() lost comment '# App config'") + } + + // Updated value should be present + if !strings.Contains(result, "APP_NAME=updated") { + t.Error("writeEnvFile() did not update APP_NAME") + } + + // New key should be appended + if !strings.Contains(result, "NEW_KEY=new_value") { + t.Error("writeEnvFile() did not add NEW_KEY") + } + + // Original structure should be maintained + if !strings.Contains(result, "DB_HOST=localhost") { + t.Error("writeEnvFile() lost DB_HOST") + } +} + +func TestLocalEnvFileStore_Save(t *testing.T) { + tmpDir := t.TempDir() + envPath := filepath.Join(tmpDir, ".env") + + store := NewLocalEnvFileStore(envPath) + creds := &AppCredentials{ + AppID: 12345, + AppSlug: "my-app", + ClientID: "Iv1.abc123", + ClientSecret: "secret123", + WebhookSecret: "whsec_123", + PrivateKey: "-----BEGIN RSA PRIVATE KEY-----\nMIIE...\n-----END RSA PRIVATE KEY-----\n", + HTMLURL: "https://github.com/apps/my-app", + CustomFields: map[string]string{ + "STS_DOMAIN": "sts.example.com", + }, + } + + err := store.Save(context.Background(), creds) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Verify file was created with correct permissions + info, err := os.Stat(envPath) + if err != nil { + t.Fatalf("File not created: %v", err) + } + if info.Mode().Perm() != 0600 { + t.Errorf("File permissions = %o, want 0600", info.Mode().Perm()) + } + + // Parse and verify contents + values, _, err := parseEnvFile(envPath) + if err != nil { + t.Fatalf("parseEnvFile() error = %v", err) + } + + checks := map[string]string{ + EnvGitHubAppID: "12345", + EnvGitHubAppSlug: "my-app", + EnvGitHubClientID: "Iv1.abc123", + EnvGitHubClientSecret: "secret123", + EnvGitHubWebhookSecret: "whsec_123", + EnvGitHubAppHTMLURL: "https://github.com/apps/my-app", + "STS_DOMAIN": "sts.example.com", + } + + for key, want := range checks { + got := values[key] + if got != want { + t.Errorf("values[%q] = %q, want %q", key, got, want) + } + } + + // PEM key should have newlines escaped + pemValue := values[EnvGitHubAppPrivateKey] + if strings.Contains(pemValue, "\n") && !strings.Contains(pemValue, "\\n") { + t.Error("Private key should have literal \\n not actual newlines") + } +} + +func TestLocalEnvFileStore_Save_CreatesDirectory(t *testing.T) { + tmpDir := t.TempDir() + envPath := filepath.Join(tmpDir, "nested", "dir", ".env") + + store := NewLocalEnvFileStore(envPath) + creds := &AppCredentials{ + AppID: 1, + ClientID: "client", + ClientSecret: "secret", + WebhookSecret: "webhook", + PrivateKey: "key", + } + + err := store.Save(context.Background(), creds) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + if _, err := os.Stat(envPath); os.IsNotExist(err) { + t.Error("Save() did not create file in nested directory") + } +} + +func TestLocalEnvFileStore_Save_PreservesExistingValues(t *testing.T) { + tmpDir := t.TempDir() + envPath := filepath.Join(tmpDir, ".env") + + // Write initial content + initialContent := `# My app config +EXISTING_KEY=existing_value +PORT=8080 +` + if err := os.WriteFile(envPath, []byte(initialContent), 0600); err != nil { + t.Fatalf("Failed to write initial file: %v", err) + } + + store := NewLocalEnvFileStore(envPath) + creds := &AppCredentials{ + AppID: 1, + ClientID: "client", + ClientSecret: "secret", + WebhookSecret: "webhook", + PrivateKey: "key", + } + + err := store.Save(context.Background(), creds) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + values, _, err := parseEnvFile(envPath) + if err != nil { + t.Fatalf("parseEnvFile() error = %v", err) + } + + // Existing values should be preserved + if values["EXISTING_KEY"] != "existing_value" { + t.Error("Save() overwrote EXISTING_KEY") + } + if values["PORT"] != "8080" { + t.Error("Save() overwrote PORT") + } + + // New values should be present + if values[EnvGitHubAppID] != "1" { + t.Error("Save() did not write GITHUB_APP_ID") + } +} + +func TestLocalEnvFileStore_Status(t *testing.T) { + tmpDir := t.TempDir() + envPath := filepath.Join(tmpDir, ".env") + + content := `GITHUB_APP_ID=12345 +GITHUB_APP_SLUG=my-app +GITHUB_APP_HTML_URL=https://github.com/apps/my-app +GITHUB_APP_PRIVATE_KEY="-----BEGIN RSA-----\nMIIE\n-----END RSA-----" +GITHUB_WEBHOOK_SECRET=whsec_123 +GITHUB_CLIENT_ID=Iv1.abc +GITHUB_CLIENT_SECRET=secret +` + if err := os.WriteFile(envPath, []byte(content), 0600); err != nil { + t.Fatalf("Failed to write file: %v", err) + } + + store := NewLocalEnvFileStore(envPath) + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if !status.Registered { + t.Error("Status.Registered = false, want true") + } + if status.AppID != 12345 { + t.Errorf("Status.AppID = %d, want 12345", status.AppID) + } + if status.AppSlug != "my-app" { + t.Errorf("Status.AppSlug = %q, want %q", status.AppSlug, "my-app") + } + if status.HTMLURL != "https://github.com/apps/my-app" { + t.Errorf("Status.HTMLURL = %q, want %q", status.HTMLURL, "https://github.com/apps/my-app") + } + if status.InstallerDisabled { + t.Error("Status.InstallerDisabled = true, want false") + } +} + +func TestLocalEnvFileStore_Status_NotRegistered(t *testing.T) { + tmpDir := t.TempDir() + envPath := filepath.Join(tmpDir, ".env") + + // Missing required fields + content := `GITHUB_APP_ID=12345 +# Missing other required fields +` + if err := os.WriteFile(envPath, []byte(content), 0600); err != nil { + t.Fatalf("Failed to write file: %v", err) + } + + store := NewLocalEnvFileStore(envPath) + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if status.Registered { + t.Error("Status.Registered = true, want false (missing required fields)") + } +} + +func TestLocalEnvFileStore_Status_FileNotExists(t *testing.T) { + store := NewLocalEnvFileStore("/nonexistent/.env") + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v, want nil for nonexistent file", err) + } + + if status.Registered { + t.Error("Status.Registered = true, want false for nonexistent file") + } +} + +func TestLocalEnvFileStore_Status_InstallerDisabled(t *testing.T) { + tmpDir := t.TempDir() + envPath := filepath.Join(tmpDir, ".env") + + content := `GITHUB_APP_INSTALLER_ENABLED=false +` + if err := os.WriteFile(envPath, []byte(content), 0600); err != nil { + t.Fatalf("Failed to write file: %v", err) + } + + store := NewLocalEnvFileStore(envPath) + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if !status.InstallerDisabled { + t.Error("Status.InstallerDisabled = false, want true") + } +} + +func TestLocalEnvFileStore_DisableInstaller(t *testing.T) { + tmpDir := t.TempDir() + envPath := filepath.Join(tmpDir, ".env") + + store := NewLocalEnvFileStore(envPath) + err := store.DisableInstaller(context.Background()) + if err != nil { + t.Fatalf("DisableInstaller() error = %v", err) + } + + values, _, err := parseEnvFile(envPath) + if err != nil { + t.Fatalf("parseEnvFile() error = %v", err) + } + + if values[EnvGitHubAppInstallerEnabled] != "false" { + t.Errorf("DisableInstaller() did not set %s=false", EnvGitHubAppInstallerEnabled) + } +} + +func TestLocalEnvFileStore_RoundTrip(t *testing.T) { + tmpDir := t.TempDir() + envPath := filepath.Join(tmpDir, ".env") + + store := NewLocalEnvFileStore(envPath) + + // Original credentials with complex PEM key + original := &AppCredentials{ + AppID: 99999, + AppSlug: "test-app", + ClientID: "Iv1.complex123", + ClientSecret: "super-secret-value", + WebhookSecret: "whsec_complex", + PrivateKey: `-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEA0Z3VS5JJcds3xfn/ygWyF8PbnGy0AHB7MhgHW1FZ ++multiline+content+here +-----END RSA PRIVATE KEY----- +`, + HTMLURL: "https://github.com/apps/test-app", + CustomFields: map[string]string{ + "CUSTOM_DOMAIN": "custom.example.com", + "ANOTHER_FIELD": "another-value", + }, + } + + // Save + if err := store.Save(context.Background(), original); err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Verify via Status + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if !status.Registered { + t.Error("Status.Registered = false after Save()") + } + if status.AppID != original.AppID { + t.Errorf("Status.AppID = %d, want %d", status.AppID, original.AppID) + } + if status.AppSlug != original.AppSlug { + t.Errorf("Status.AppSlug = %q, want %q", status.AppSlug, original.AppSlug) + } + if status.HTMLURL != original.HTMLURL { + t.Errorf("Status.HTMLURL = %q, want %q", status.HTMLURL, original.HTMLURL) + } + + // Also verify custom fields were saved + values, _, err := parseEnvFile(envPath) + if err != nil { + t.Fatalf("parseEnvFile() error = %v", err) + } + if values["CUSTOM_DOMAIN"] != "custom.example.com" { + t.Error("Custom field CUSTOM_DOMAIN was not saved") + } + if values["ANOTHER_FIELD"] != "another-value" { + t.Error("Custom field ANOTHER_FIELD was not saved") + } +} diff --git a/configstore/local_file_store.go b/configstore/local_file_store.go new file mode 100644 index 0000000..ca5cc8a --- /dev/null +++ b/configstore/local_file_store.go @@ -0,0 +1,143 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +package configstore + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" +) + +// LocalFileStore saves credentials as individual files in a directory. +type LocalFileStore struct { + Dir string +} + +// NewLocalFileStore creates a store that saves credentials as files in dir. +func NewLocalFileStore(dir string) *LocalFileStore { + return &LocalFileStore{Dir: dir} +} + +// Save writes credentials to individual files in the store directory. +func (s *LocalFileStore) Save(ctx context.Context, creds *AppCredentials) error { + if err := os.MkdirAll(s.Dir, 0700); err != nil { + return fmt.Errorf("failed to create directory %s: %w", s.Dir, err) + } + + files := map[string]struct { + content string + mode os.FileMode + }{ + "app-id": {content: fmt.Sprintf("%d", creds.AppID), mode: 0644}, + "private-key.pem": {content: creds.PrivateKey, mode: 0600}, + "webhook-secret": {content: creds.WebhookSecret, mode: 0600}, + "client-id": {content: creds.ClientID, mode: 0644}, + "client-secret": {content: creds.ClientSecret, mode: 0600}, + } + + if creds.AppSlug != "" { + files["app-slug"] = struct { + content string + mode os.FileMode + }{content: creds.AppSlug, mode: 0644} + } + if creds.HTMLURL != "" { + files["app-html-url"] = struct { + content string + mode os.FileMode + }{content: creds.HTMLURL, mode: 0644} + } + + for key, value := range creds.CustomFields { + if value != "" { + filename := strings.ToLower(strings.ReplaceAll(key, "_", "-")) + files[filename] = struct { + content string + mode os.FileMode + }{content: value, mode: 0644} + } + } + + for name, file := range files { + path := filepath.Join(s.Dir, name) + if err := os.WriteFile(path, []byte(file.content), file.mode); err != nil { + return fmt.Errorf("failed to write %s: %w", path, err) + } + } + + return nil +} + +// Status returns the current registration state by checking required files. +func (s *LocalFileStore) Status(ctx context.Context) (*InstallerStatus, error) { + status := &InstallerStatus{} + + appIDBytes, err := os.ReadFile(filepath.Join(s.Dir, "app-id")) + if err != nil { + if os.IsNotExist(err) { + return status, nil + } + return nil, err + } + + if id, err := strconv.ParseInt(strings.TrimSpace(string(appIDBytes)), 10, 64); err == nil { + status.AppID = id + } + + required := []string{"client-id", "client-secret", "webhook-secret", "private-key.pem"} + for _, name := range required { + if _, err := os.Stat(filepath.Join(s.Dir, name)); err != nil { + if os.IsNotExist(err) { + return status, nil + } + return nil, err + } + } + status.Registered = true + + if slug, err := readTrimmedFile(filepath.Join(s.Dir, "app-slug")); err == nil { + status.AppSlug = slug + } else if !os.IsNotExist(err) { + return nil, err + } + + if html, err := readTrimmedFile(filepath.Join(s.Dir, "app-html-url")); err == nil { + status.HTMLURL = html + } else if !os.IsNotExist(err) { + return nil, err + } + + if _, err := os.Stat(filepath.Join(s.Dir, "installer-disabled")); err == nil { + status.InstallerDisabled = true + } else if !os.IsNotExist(err) { + return nil, err + } + + return status, nil +} + +// DisableInstaller creates a marker file to disable the installer. +func (s *LocalFileStore) DisableInstaller(ctx context.Context) error { + if err := os.MkdirAll(s.Dir, 0700); err != nil { + return fmt.Errorf("failed to create directory %s: %w", s.Dir, err) + } + + path := filepath.Join(s.Dir, "installer-disabled") + if err := os.WriteFile(path, []byte("disabled"), 0600); err != nil { + return fmt.Errorf("failed to write %s: %w", path, err) + } + + return nil +} + +func readTrimmedFile(path string) (string, error) { + data, err := os.ReadFile(path) + if err != nil { + return "", err + } + return strings.TrimSpace(string(data)), nil +} diff --git a/configstore/local_file_store_test.go b/configstore/local_file_store_test.go new file mode 100644 index 0000000..a56840a --- /dev/null +++ b/configstore/local_file_store_test.go @@ -0,0 +1,445 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +package configstore + +import ( + "context" + "os" + "path/filepath" + "testing" +) + +func TestLocalFileStore_Save(t *testing.T) { + tmpDir := t.TempDir() + store := NewLocalFileStore(tmpDir) + + creds := &AppCredentials{ + AppID: 12345, + AppSlug: "my-app", + ClientID: "Iv1.abc123", + ClientSecret: "secret123", + WebhookSecret: "whsec_123", + PrivateKey: "-----BEGIN RSA PRIVATE KEY-----\nMIIE...\n-----END RSA PRIVATE KEY-----\n", + HTMLURL: "https://github.com/apps/my-app", + CustomFields: map[string]string{ + "STS_DOMAIN": "sts.example.com", + "CUSTOM_VALUE": "custom123", + }, + } + + err := store.Save(context.Background(), creds) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Verify core files exist with correct content + checks := map[string]struct { + content string + mode os.FileMode + }{ + "app-id": {content: "12345", mode: 0644}, + "app-slug": {content: "my-app", mode: 0644}, + "client-id": {content: "Iv1.abc123", mode: 0644}, + "client-secret": {content: "secret123", mode: 0600}, + "webhook-secret": {content: "whsec_123", mode: 0600}, + "private-key.pem": {content: "-----BEGIN RSA PRIVATE KEY-----\nMIIE...\n-----END RSA PRIVATE KEY-----\n", mode: 0600}, + "app-html-url": {content: "https://github.com/apps/my-app", mode: 0644}, + "sts-domain": {content: "sts.example.com", mode: 0644}, + "custom-value": {content: "custom123", mode: 0644}, + } + + for name, want := range checks { + path := filepath.Join(tmpDir, name) + + // Check file exists + info, err := os.Stat(path) + if err != nil { + t.Errorf("File %q not created: %v", name, err) + continue + } + + // Check permissions + if info.Mode().Perm() != want.mode { + t.Errorf("File %q permissions = %o, want %o", name, info.Mode().Perm(), want.mode) + } + + // Check content + content, err := os.ReadFile(path) + if err != nil { + t.Errorf("Failed to read %q: %v", name, err) + continue + } + if string(content) != want.content { + t.Errorf("File %q content = %q, want %q", name, string(content), want.content) + } + } +} + +func TestLocalFileStore_Save_CreatesDirectory(t *testing.T) { + tmpDir := t.TempDir() + nestedDir := filepath.Join(tmpDir, "nested", "config", "dir") + + store := NewLocalFileStore(nestedDir) + creds := &AppCredentials{ + AppID: 1, + ClientID: "client", + ClientSecret: "secret", + WebhookSecret: "webhook", + PrivateKey: "key", + } + + err := store.Save(context.Background(), creds) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Directory should be created with 0700 permissions + info, err := os.Stat(nestedDir) + if err != nil { + t.Fatalf("Directory not created: %v", err) + } + if !info.IsDir() { + t.Error("Expected directory, got file") + } + if info.Mode().Perm() != 0700 { + t.Errorf("Directory permissions = %o, want 0700", info.Mode().Perm()) + } +} + +func TestLocalFileStore_Save_OmitsEmptyOptionalFields(t *testing.T) { + tmpDir := t.TempDir() + store := NewLocalFileStore(tmpDir) + + creds := &AppCredentials{ + AppID: 1, + ClientID: "client", + ClientSecret: "secret", + WebhookSecret: "webhook", + PrivateKey: "key", + // AppSlug and HTMLURL are empty + } + + err := store.Save(context.Background(), creds) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Optional files should not exist + for _, name := range []string{"app-slug", "app-html-url"} { + path := filepath.Join(tmpDir, name) + if _, err := os.Stat(path); !os.IsNotExist(err) { + t.Errorf("File %q should not exist for empty value", name) + } + } +} + +func TestLocalFileStore_Save_SkipsEmptyCustomFields(t *testing.T) { + tmpDir := t.TempDir() + store := NewLocalFileStore(tmpDir) + + creds := &AppCredentials{ + AppID: 1, + ClientID: "client", + ClientSecret: "secret", + WebhookSecret: "webhook", + PrivateKey: "key", + CustomFields: map[string]string{ + "EMPTY_FIELD": "", + "VALID_FIELD": "value", + }, + } + + err := store.Save(context.Background(), creds) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Empty custom field should not be created + if _, err := os.Stat(filepath.Join(tmpDir, "empty-field")); !os.IsNotExist(err) { + t.Error("Empty custom field should not create a file") + } + + // Valid custom field should exist + if _, err := os.Stat(filepath.Join(tmpDir, "valid-field")); err != nil { + t.Errorf("Valid custom field not created: %v", err) + } +} + +func TestLocalFileStore_Status_Registered(t *testing.T) { + tmpDir := t.TempDir() + + // Create all required files + files := map[string]string{ + "app-id": "12345", + "app-slug": "test-app", + "app-html-url": "https://github.com/apps/test-app", + "client-id": "client123", + "client-secret": "secret123", + "webhook-secret": "webhook123", + "private-key.pem": "-----BEGIN RSA-----\n...\n-----END RSA-----", + } + for name, content := range files { + if err := os.WriteFile(filepath.Join(tmpDir, name), []byte(content), 0644); err != nil { + t.Fatalf("Failed to create %s: %v", name, err) + } + } + + store := NewLocalFileStore(tmpDir) + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if !status.Registered { + t.Error("Status.Registered = false, want true") + } + if status.AppID != 12345 { + t.Errorf("Status.AppID = %d, want 12345", status.AppID) + } + if status.AppSlug != "test-app" { + t.Errorf("Status.AppSlug = %q, want %q", status.AppSlug, "test-app") + } + if status.HTMLURL != "https://github.com/apps/test-app" { + t.Errorf("Status.HTMLURL = %q, want %q", status.HTMLURL, "https://github.com/apps/test-app") + } + if status.InstallerDisabled { + t.Error("Status.InstallerDisabled = true, want false") + } +} + +func TestLocalFileStore_Status_NotRegistered_MissingAppID(t *testing.T) { + tmpDir := t.TempDir() + + // Directory exists but no app-id file + store := NewLocalFileStore(tmpDir) + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if status.Registered { + t.Error("Status.Registered = true, want false (no app-id)") + } +} + +func TestLocalFileStore_Status_NotRegistered_MissingRequiredFiles(t *testing.T) { + tmpDir := t.TempDir() + + // Create app-id but missing required files + if err := os.WriteFile(filepath.Join(tmpDir, "app-id"), []byte("12345"), 0644); err != nil { + t.Fatalf("Failed to create app-id: %v", err) + } + + store := NewLocalFileStore(tmpDir) + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if status.Registered { + t.Error("Status.Registered = true, want false (missing required files)") + } +} + +func TestLocalFileStore_Status_InstallerDisabled(t *testing.T) { + tmpDir := t.TempDir() + + // Create all required files plus installer-disabled marker + files := map[string]string{ + "app-id": "12345", + "client-id": "client", + "client-secret": "secret", + "webhook-secret": "webhook", + "private-key.pem": "key", + "installer-disabled": "disabled", + } + for name, content := range files { + if err := os.WriteFile(filepath.Join(tmpDir, name), []byte(content), 0644); err != nil { + t.Fatalf("Failed to create %s: %v", name, err) + } + } + + store := NewLocalFileStore(tmpDir) + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if !status.InstallerDisabled { + t.Error("Status.InstallerDisabled = false, want true") + } +} + +func TestLocalFileStore_Status_DirectoryNotExists(t *testing.T) { + store := NewLocalFileStore("/nonexistent/directory") + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v, want nil for nonexistent directory", err) + } + + if status.Registered { + t.Error("Status.Registered = true, want false for nonexistent directory") + } +} + +func TestLocalFileStore_Status_WhitespaceInAppID(t *testing.T) { + tmpDir := t.TempDir() + + // Create app-id with whitespace and newline + files := map[string]string{ + "app-id": " 12345\n", + "client-id": "client", + "client-secret": "secret", + "webhook-secret": "webhook", + "private-key.pem": "key", + } + for name, content := range files { + if err := os.WriteFile(filepath.Join(tmpDir, name), []byte(content), 0644); err != nil { + t.Fatalf("Failed to create %s: %v", name, err) + } + } + + store := NewLocalFileStore(tmpDir) + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if status.AppID != 12345 { + t.Errorf("Status.AppID = %d, want 12345 (whitespace should be trimmed)", status.AppID) + } +} + +func TestLocalFileStore_DisableInstaller(t *testing.T) { + tmpDir := t.TempDir() + store := NewLocalFileStore(tmpDir) + + err := store.DisableInstaller(context.Background()) + if err != nil { + t.Fatalf("DisableInstaller() error = %v", err) + } + + // Check marker file exists + markerPath := filepath.Join(tmpDir, "installer-disabled") + info, err := os.Stat(markerPath) + if err != nil { + t.Fatalf("Marker file not created: %v", err) + } + + // Check permissions (should be 0600 for security) + if info.Mode().Perm() != 0600 { + t.Errorf("Marker file permissions = %o, want 0600", info.Mode().Perm()) + } +} + +func TestLocalFileStore_DisableInstaller_CreatesDirectory(t *testing.T) { + tmpDir := t.TempDir() + nestedDir := filepath.Join(tmpDir, "nested", "dir") + store := NewLocalFileStore(nestedDir) + + err := store.DisableInstaller(context.Background()) + if err != nil { + t.Fatalf("DisableInstaller() error = %v", err) + } + + // Directory should be created + if _, err := os.Stat(nestedDir); err != nil { + t.Errorf("Directory not created: %v", err) + } + + // Marker file should exist + if _, err := os.Stat(filepath.Join(nestedDir, "installer-disabled")); err != nil { + t.Errorf("Marker file not created: %v", err) + } +} + +func TestLocalFileStore_RoundTrip(t *testing.T) { + tmpDir := t.TempDir() + store := NewLocalFileStore(tmpDir) + + original := &AppCredentials{ + AppID: 99999, + AppSlug: "roundtrip-app", + ClientID: "Iv1.roundtrip", + ClientSecret: "roundtrip-secret", + WebhookSecret: "whsec_roundtrip", + PrivateKey: "-----BEGIN RSA PRIVATE KEY-----\nroundtrip-key-content\n-----END RSA PRIVATE KEY-----\n", + HTMLURL: "https://github.com/apps/roundtrip-app", + CustomFields: map[string]string{ + "CUSTOM_FIELD": "custom-value", + }, + } + + // Save + if err := store.Save(context.Background(), original); err != nil { + t.Fatalf("Save() error = %v", err) + } + + // Verify via Status + status, err := store.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + + if !status.Registered { + t.Error("Status.Registered = false after Save()") + } + if status.AppID != original.AppID { + t.Errorf("Status.AppID = %d, want %d", status.AppID, original.AppID) + } + if status.AppSlug != original.AppSlug { + t.Errorf("Status.AppSlug = %q, want %q", status.AppSlug, original.AppSlug) + } + if status.HTMLURL != original.HTMLURL { + t.Errorf("Status.HTMLURL = %q, want %q", status.HTMLURL, original.HTMLURL) + } + + // Verify custom field was saved + content, err := os.ReadFile(filepath.Join(tmpDir, "custom-field")) + if err != nil { + t.Fatalf("Custom field file not found: %v", err) + } + if string(content) != "custom-value" { + t.Errorf("Custom field content = %q, want %q", string(content), "custom-value") + } +} + +func TestReadTrimmedFile(t *testing.T) { + tmpDir := t.TempDir() + + tests := []struct { + name string + content string + want string + }{ + {"no whitespace", "value", "value"}, + {"trailing newline", "value\n", "value"}, + {"leading and trailing spaces", " value ", "value"}, + {"mixed whitespace", "\t value \n", "value"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + path := filepath.Join(tmpDir, "test-"+tt.name) + if err := os.WriteFile(path, []byte(tt.content), 0644); err != nil { + t.Fatalf("Failed to write file: %v", err) + } + + got, err := readTrimmedFile(path) + if err != nil { + t.Fatalf("readTrimmedFile() error = %v", err) + } + if got != tt.want { + t.Errorf("readTrimmedFile() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestReadTrimmedFile_NotExists(t *testing.T) { + _, err := readTrimmedFile("/nonexistent/file") + if !os.IsNotExist(err) { + t.Errorf("readTrimmedFile() error = %v, want os.IsNotExist", err) + } +} diff --git a/configstore/store.go b/configstore/store.go new file mode 100644 index 0000000..53b1427 --- /dev/null +++ b/configstore/store.go @@ -0,0 +1,162 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +// Package configstore provides storage backends for GitHub App credentials. +// It supports multiple storage backends including AWS SSM Parameter Store, +// local .env files, and individual files. +package configstore + +import ( + "context" + "encoding/json" + "fmt" + "os" + "strings" +) + +const ( + EnvGitHubAppID = "GITHUB_APP_ID" + EnvGitHubAppSlug = "GITHUB_APP_SLUG" + EnvGitHubAppHTMLURL = "GITHUB_APP_HTML_URL" + EnvGitHubAppPrivateKey = "GITHUB_APP_PRIVATE_KEY" + EnvGitHubWebhookSecret = "GITHUB_WEBHOOK_SECRET" + EnvGitHubClientID = "GITHUB_CLIENT_ID" + EnvGitHubClientSecret = "GITHUB_CLIENT_SECRET" +) + +const ( + EnvGitHubAppInstallerEnabled = "GITHUB_APP_INSTALLER_ENABLED" + EnvStorageMode = "STORAGE_MODE" + EnvStorageDir = "STORAGE_DIR" + EnvAWSSSMParameterPfx = "AWS_SSM_PARAMETER_PREFIX" + EnvAWSSSMKMSKeyID = "AWS_SSM_KMS_KEY_ID" + EnvAWSSSMTags = "AWS_SSM_TAGS" +) + +// Storage mode constants for STORAGE_MODE environment variable. +const ( + // StorageModeEnvFile saves credentials to a .env file (default mode). + StorageModeEnvFile = "envfile" + // StorageModeFiles saves credentials as individual files in a directory. + StorageModeFiles = "files" + // StorageModeAWSSSM saves credentials to AWS SSM Parameter Store. + StorageModeAWSSSM = "aws-ssm" +) + +// HookConfig contains webhook configuration returned from GitHub. +type HookConfig struct { + URL string `json:"url"` +} + +// AppCredentials holds credentials returned from GitHub App manifest creation. +type AppCredentials struct { + AppID int64 `json:"id"` + AppSlug string `json:"slug"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + WebhookSecret string `json:"webhook_secret"` + PrivateKey string `json:"pem"` + HTMLURL string `json:"html_url"` + HookConfig HookConfig `json:"hook_config"` + + // CustomFields stores additional app-specific values alongside credentials. + CustomFields map[string]string `json:"-"` +} + +// InstallerStatus describes the current GitHub App registration state. +type InstallerStatus struct { + Registered bool + InstallerDisabled bool + AppID int64 + AppSlug string + HTMLURL string +} + +// Store saves app credentials to various backends (local disk, AWS SSM, etc). +type Store interface { + Save(ctx context.Context, creds *AppCredentials) error + Status(ctx context.Context) (*InstallerStatus, error) + DisableInstaller(ctx context.Context) error +} + +// NewFromEnv creates a Store based on environment variable configuration. +// It reads STORAGE_MODE to determine the backend type: +// - "envfile" (default): saves to a .env file at STORAGE_DIR (default: ./.env) +// - "files": saves to individual files in STORAGE_DIR directory +// - "aws-ssm": saves to AWS SSM Parameter Store with AWS_SSM_PARAMETER_PREFIX +// +// Returns an error if configuration is invalid or store creation fails. +func NewFromEnv() (Store, error) { + mode := GetEnvDefault(EnvStorageMode, StorageModeEnvFile) + + switch mode { + case StorageModeFiles: + dir := GetEnvDefault(EnvStorageDir, "./.env") + return NewLocalFileStore(dir), nil + + case StorageModeEnvFile: + path := GetEnvDefault(EnvStorageDir, "./.env") + return NewLocalEnvFileStore(path), nil + + case StorageModeAWSSSM: + prefix := os.Getenv(EnvAWSSSMParameterPfx) + if prefix == "" { + return nil, fmt.Errorf("%s is required when using %s storage mode", EnvAWSSSMParameterPfx, StorageModeAWSSSM) + } + + var opts []SSMStoreOption + + if kmsKeyID := os.Getenv(EnvAWSSSMKMSKeyID); kmsKeyID != "" { + opts = append(opts, WithKMSKey(kmsKeyID)) + } + + if tagsJSON := os.Getenv(EnvAWSSSMTags); tagsJSON != "" { + var tags map[string]string + if err := json.Unmarshal([]byte(tagsJSON), &tags); err != nil { + return nil, fmt.Errorf("failed to parse %s as JSON: %w", EnvAWSSSMTags, err) + } + opts = append(opts, WithTags(tags)) + } + + return NewAWSSSMStore(prefix, opts...) + + default: + return nil, fmt.Errorf("unknown %s: %s (expected '%s', '%s', or '%s')", + EnvStorageMode, mode, StorageModeEnvFile, StorageModeFiles, StorageModeAWSSSM) + } +} + +// InstallerEnabled returns true if the installer is enabled via environment variable. +func InstallerEnabled() bool { + v := strings.ToLower(os.Getenv(EnvGitHubAppInstallerEnabled)) + return v == "true" || v == "1" || v == "yes" +} + +// GetEnvDefault returns an env var value, or defaultValue if not set or empty. +func GetEnvDefault(key, defaultValue string) string { + if v := os.Getenv(key); v != "" { + return v + } + return defaultValue +} + +func hasAllValues(values map[string]string, keys ...string) bool { + if len(values) == 0 { + return false + } + for _, key := range keys { + if strings.TrimSpace(values[key]) == "" { + return false + } + } + return true +} + +func isFalseString(v string) bool { + switch strings.ToLower(strings.TrimSpace(v)) { + case "false", "0", "no", "off": + return true + default: + return false + } +} diff --git a/configstore/store_test.go b/configstore/store_test.go new file mode 100644 index 0000000..e5c4373 --- /dev/null +++ b/configstore/store_test.go @@ -0,0 +1,312 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +package configstore + +import ( + "os" + "testing" +) + +func TestInstallerEnabled(t *testing.T) { + tests := []struct { + name string + value string + want bool + }{ + {"true lowercase", "true", true}, + {"TRUE uppercase", "TRUE", true}, + {"True mixed", "True", true}, + {"1", "1", true}, + {"yes", "yes", true}, + {"YES uppercase", "YES", true}, + {"false", "false", false}, + {"FALSE", "FALSE", false}, + {"0", "0", false}, + {"no", "no", false}, + {"empty string", "", false}, + {"random string", "enabled", false}, + {"on (not supported)", "on", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + os.Setenv(EnvGitHubAppInstallerEnabled, tt.value) + defer os.Unsetenv(EnvGitHubAppInstallerEnabled) + + got := InstallerEnabled() + if got != tt.want { + t.Errorf("InstallerEnabled() with %q = %v, want %v", tt.value, got, tt.want) + } + }) + } +} + +func TestHasAllValues(t *testing.T) { + tests := []struct { + name string + values map[string]string + keys []string + want bool + }{ + { + name: "all keys present with values", + values: map[string]string{"a": "1", "b": "2", "c": "3"}, + keys: []string{"a", "b"}, + want: true, + }, + { + name: "missing key", + values: map[string]string{"a": "1"}, + keys: []string{"a", "b"}, + want: false, + }, + { + name: "empty value fails", + values: map[string]string{"a": "1", "b": ""}, + keys: []string{"a", "b"}, + want: false, + }, + { + name: "whitespace-only value fails", + values: map[string]string{"a": "1", "b": " "}, + keys: []string{"a", "b"}, + want: false, + }, + { + name: "nil map", + values: nil, + keys: []string{"a"}, + want: false, + }, + { + name: "empty map", + values: map[string]string{}, + keys: []string{"a"}, + want: false, + }, + { + name: "no keys required", + values: map[string]string{"a": "1"}, + keys: []string{}, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := hasAllValues(tt.values, tt.keys...) + if got != tt.want { + t.Errorf("hasAllValues() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsFalseString(t *testing.T) { + tests := []struct { + value string + want bool + }{ + {"false", true}, + {"FALSE", true}, + {"False", true}, + {"0", true}, + {"no", true}, + {"NO", true}, + {"off", true}, + {"OFF", true}, + {" false ", true}, // with whitespace + {"true", false}, + {"1", false}, + {"yes", false}, + {"on", false}, + {"", false}, + {"random", false}, + } + + for _, tt := range tests { + t.Run(tt.value, func(t *testing.T) { + got := isFalseString(tt.value) + if got != tt.want { + t.Errorf("isFalseString(%q) = %v, want %v", tt.value, got, tt.want) + } + }) + } +} + +func TestGetEnvDefault(t *testing.T) { + tests := []struct { + name string + envKey string + envValue string + setEnv bool + defaultValue string + want string + }{ + { + name: "env set returns env value", + envKey: "TEST_VAR", + envValue: "custom_value", + setEnv: true, + defaultValue: "default", + want: "custom_value", + }, + { + name: "env not set returns default", + envKey: "TEST_VAR_UNSET", + envValue: "", + setEnv: false, + defaultValue: "default_value", + want: "default_value", + }, + { + name: "empty env returns default", + envKey: "TEST_VAR_EMPTY", + envValue: "", + setEnv: true, + defaultValue: "default", + want: "default", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + os.Unsetenv(tt.envKey) + if tt.setEnv { + os.Setenv(tt.envKey, tt.envValue) + defer os.Unsetenv(tt.envKey) + } + + got := GetEnvDefault(tt.envKey, tt.defaultValue) + if got != tt.want { + t.Errorf("GetEnvDefault(%q, %q) = %q, want %q", tt.envKey, tt.defaultValue, got, tt.want) + } + }) + } +} + +func TestNewFromEnv_StorageModes(t *testing.T) { + t.Run("default mode creates LocalEnvFileStore", func(t *testing.T) { + os.Unsetenv(EnvStorageMode) + defer os.Unsetenv(EnvStorageMode) + + store, err := NewFromEnv() + if err != nil { + t.Fatalf("NewFromEnv() error = %v", err) + } + + if _, ok := store.(*LocalEnvFileStore); !ok { + t.Errorf("NewFromEnv() returned %T, want *LocalEnvFileStore", store) + } + }) + + t.Run("envfile mode creates LocalEnvFileStore", func(t *testing.T) { + os.Setenv(EnvStorageMode, StorageModeEnvFile) + defer os.Unsetenv(EnvStorageMode) + + store, err := NewFromEnv() + if err != nil { + t.Fatalf("NewFromEnv() error = %v", err) + } + + if _, ok := store.(*LocalEnvFileStore); !ok { + t.Errorf("NewFromEnv() returned %T, want *LocalEnvFileStore", store) + } + }) + + t.Run("files mode creates LocalFileStore", func(t *testing.T) { + os.Setenv(EnvStorageMode, StorageModeFiles) + defer os.Unsetenv(EnvStorageMode) + + store, err := NewFromEnv() + if err != nil { + t.Fatalf("NewFromEnv() error = %v", err) + } + + if _, ok := store.(*LocalFileStore); !ok { + t.Errorf("NewFromEnv() returned %T, want *LocalFileStore", store) + } + }) + + t.Run("aws-ssm mode requires prefix", func(t *testing.T) { + os.Setenv(EnvStorageMode, StorageModeAWSSSM) + os.Unsetenv(EnvAWSSSMParameterPfx) + defer os.Unsetenv(EnvStorageMode) + + _, err := NewFromEnv() + if err == nil { + t.Error("NewFromEnv() with aws-ssm and no prefix should return error") + } + }) + + t.Run("unknown mode returns error", func(t *testing.T) { + os.Setenv(EnvStorageMode, "invalid-mode") + defer os.Unsetenv(EnvStorageMode) + + _, err := NewFromEnv() + if err == nil { + t.Error("NewFromEnv() with unknown mode should return error") + } + }) +} + +func TestNewFromEnv_CustomStorageDir(t *testing.T) { + t.Run("envfile mode uses STORAGE_DIR", func(t *testing.T) { + os.Setenv(EnvStorageMode, StorageModeEnvFile) + os.Setenv(EnvStorageDir, "/custom/path/.env") + defer os.Unsetenv(EnvStorageMode) + defer os.Unsetenv(EnvStorageDir) + + store, err := NewFromEnv() + if err != nil { + t.Fatalf("NewFromEnv() error = %v", err) + } + + envStore, ok := store.(*LocalEnvFileStore) + if !ok { + t.Fatalf("NewFromEnv() returned %T, want *LocalEnvFileStore", store) + } + + if envStore.FilePath != "/custom/path/.env" { + t.Errorf("FilePath = %q, want %q", envStore.FilePath, "/custom/path/.env") + } + }) + + t.Run("files mode uses STORAGE_DIR", func(t *testing.T) { + os.Setenv(EnvStorageMode, StorageModeFiles) + os.Setenv(EnvStorageDir, "/custom/dir") + defer os.Unsetenv(EnvStorageMode) + defer os.Unsetenv(EnvStorageDir) + + store, err := NewFromEnv() + if err != nil { + t.Fatalf("NewFromEnv() error = %v", err) + } + + fileStore, ok := store.(*LocalFileStore) + if !ok { + t.Fatalf("NewFromEnv() returned %T, want *LocalFileStore", store) + } + + if fileStore.Dir != "/custom/dir" { + t.Errorf("Dir = %q, want %q", fileStore.Dir, "/custom/dir") + } + }) +} + +func TestNewFromEnv_AWSSSMTags(t *testing.T) { + t.Run("invalid JSON tags returns error", func(t *testing.T) { + os.Setenv(EnvStorageMode, StorageModeAWSSSM) + os.Setenv(EnvAWSSSMParameterPfx, "/test/prefix") + os.Setenv(EnvAWSSSMTags, "not valid json") + defer os.Unsetenv(EnvStorageMode) + defer os.Unsetenv(EnvAWSSSMParameterPfx) + defer os.Unsetenv(EnvAWSSSMTags) + + _, err := NewFromEnv() + if err == nil { + t.Error("NewFromEnv() with invalid JSON tags should return error") + } + }) +} diff --git a/configwait/configwait.go b/configwait/configwait.go new file mode 100644 index 0000000..cb5e6ab --- /dev/null +++ b/configwait/configwait.go @@ -0,0 +1,203 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +// Package configwait provides utilities for waiting on configuration availability +// during startup. +package configwait + +import ( + "context" + "encoding/json" + "net/http" + "os" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/chainguard-dev/clog" +) + +const ( + EnvMaxRetries = "CONFIG_WAIT_MAX_RETRIES" + EnvRetryInterval = "CONFIG_WAIT_RETRY_INTERVAL" +) + +const ( + DefaultMaxRetries = 30 + DefaultRetryInterval = 2 * time.Second +) + +// Config configures the wait behavior. +type Config struct { + MaxRetries int + RetryInterval time.Duration +} + +// NewConfigFromEnv creates a Config from environment variables. +func NewConfigFromEnv() Config { + cfg := Config{ + MaxRetries: DefaultMaxRetries, + RetryInterval: DefaultRetryInterval, + } + + if v := os.Getenv(EnvMaxRetries); v != "" { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + cfg.MaxRetries = n + } + } + + if v := os.Getenv(EnvRetryInterval); v != "" { + if d, err := time.ParseDuration(v); err == nil && d > 0 { + cfg.RetryInterval = d + } + } + + return cfg +} + +// LoadFunc attempts to load configuration; returns nil on success. +type LoadFunc func(ctx context.Context) error + +// Wait blocks until load succeeds or max retries is reached. +func Wait(ctx context.Context, cfg Config, load LoadFunc) error { + log := clog.FromContext(ctx) + var lastErr error + + for attempt := 1; attempt <= cfg.MaxRetries; attempt++ { + if err := load(ctx); err != nil { + lastErr = err + log.Warnf("[configwait] attempt %d/%d failed: %v", attempt, cfg.MaxRetries, err) + + if attempt < cfg.MaxRetries { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(cfg.RetryInterval): + + } + } + } else { + if attempt > 1 { + log.Infof("[configwait] configuration loaded successfully after %d attempts", attempt) + } + return nil + } + } + + return lastErr +} + +// ReadyGate gates HTTP requests until the service is ready. +type ReadyGate struct { + inner http.Handler + allowedPaths []string + ready atomic.Bool + handler atomic.Value // stores http.Handler once ready + + mu sync.Mutex + handlerReady chan struct{} +} + +// NewReadyGate creates a ReadyGate wrapping the given handler. +// The allowedPaths are path prefixes always allowed through (e.g., "/setup"). +// The inner handler can be nil initially; call SetHandler() once ready. +func NewReadyGate(inner http.Handler, allowedPaths []string) *ReadyGate { + rg := &ReadyGate{ + inner: inner, + allowedPaths: allowedPaths, + handlerReady: make(chan struct{}), + } + if inner != nil { + rg.handler.Store(inner) + } + return rg +} + +// SetReady marks the service as ready to handle all requests. +func (rg *ReadyGate) SetReady() { + rg.ready.Store(true) +} + +// SetHandler sets the main handler to use once ready. +func (rg *ReadyGate) SetHandler(h http.Handler) { + rg.handler.Store(h) + rg.mu.Lock() + defer rg.mu.Unlock() + select { + case <-rg.handlerReady: + default: + close(rg.handlerReady) + } +} + +// IsReady returns true if the service is ready. +func (rg *ReadyGate) IsReady() bool { + return rg.ready.Load() +} + +// ServeHTTP implements http.Handler. +func (rg *ReadyGate) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if rg.isAllowedPath(r.URL.Path) { + h := rg.getHandler() + if h != nil { + h.ServeHTTP(w, r) + return + } + rg.serveUnavailable(w, r, "service starting up") + return + } + + if !rg.ready.Load() { + rg.serveUnavailable(w, r, "service not ready, configuration loading") + return + } + + h := rg.getHandler() + if h == nil { + rg.serveUnavailable(w, r, "service starting up") + return + } + h.ServeHTTP(w, r) +} + +// isAllowedPath checks if the path matches any allowed path prefix. +func (rg *ReadyGate) isAllowedPath(path string) bool { + for _, allowed := range rg.allowedPaths { + if allowed == "/" { + if path == "/" { + return true + } + continue + } + if strings.HasPrefix(path, allowed) { + return true + } + } + return false +} + +// getHandler returns the current handler. +func (rg *ReadyGate) getHandler() http.Handler { + h := rg.handler.Load() + if h == nil { + return nil + } + return h.(http.Handler) +} + +// serveUnavailable writes a 503 response. +func (rg *ReadyGate) serveUnavailable(w http.ResponseWriter, r *http.Request, message string) { + log := clog.FromContext(r.Context()) + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Retry-After", "5") + w.WriteHeader(http.StatusServiceUnavailable) + if err := json.NewEncoder(w).Encode(map[string]string{ + "error": "service_unavailable", + "message": message, + }); err != nil { + log.Errorf("[configwait] failed to write unavailable response: %v", err) + } +} diff --git a/configwait/configwait_test.go b/configwait/configwait_test.go new file mode 100644 index 0000000..852c948 --- /dev/null +++ b/configwait/configwait_test.go @@ -0,0 +1,398 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +package configwait + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" +) + +func TestWait_ImmediateSuccess(t *testing.T) { + ctx := context.Background() + cfg := Config{ + MaxRetries: 3, + RetryInterval: 10 * time.Millisecond, + } + + callCount := 0 + err := Wait(ctx, cfg, func(ctx context.Context) error { + callCount++ + return nil + }) + + if err != nil { + t.Errorf("Wait() error = %v, want nil", err) + } + if callCount != 1 { + t.Errorf("Load function called %d times, want 1", callCount) + } +} + +func TestWait_RetryThenSuccess(t *testing.T) { + ctx := context.Background() + cfg := Config{ + MaxRetries: 5, + RetryInterval: 10 * time.Millisecond, + } + + callCount := 0 + err := Wait(ctx, cfg, func(ctx context.Context) error { + callCount++ + if callCount < 3 { + return errors.New("not ready") + } + return nil + }) + + if err != nil { + t.Errorf("Wait() error = %v, want nil", err) + } + if callCount != 3 { + t.Errorf("Load function called %d times, want 3", callCount) + } +} + +func TestWait_MaxRetriesExceeded(t *testing.T) { + ctx := context.Background() + cfg := Config{ + MaxRetries: 3, + RetryInterval: 10 * time.Millisecond, + } + + callCount := 0 + expectedErr := errors.New("always fail") + err := Wait(ctx, cfg, func(ctx context.Context) error { + callCount++ + return expectedErr + }) + + if err != expectedErr { + t.Errorf("Wait() error = %v, want %v", err, expectedErr) + } + if callCount != 3 { + t.Errorf("Load function called %d times, want 3", callCount) + } +} + +func TestWait_ContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cfg := Config{ + MaxRetries: 100, + RetryInterval: 100 * time.Millisecond, + } + + callCount := 0 + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + err := Wait(ctx, cfg, func(ctx context.Context) error { + callCount++ + return errors.New("not ready") + }) + + if err != context.Canceled { + t.Errorf("Wait() error = %v, want %v", err, context.Canceled) + } + // Should have been cancelled after 1-2 attempts + if callCount > 2 { + t.Errorf("Load function called %d times, expected <= 2", callCount) + } +} + +func TestReadyGate_NotReadyReturns503(t *testing.T) { + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) + }) + + gate := NewReadyGate(inner, nil) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + + gate.ServeHTTP(rec, req) + + if rec.Code != http.StatusServiceUnavailable { + t.Errorf("Status = %d, want %d", rec.Code, http.StatusServiceUnavailable) + } +} + +func TestReadyGate_ReadyPassesThrough(t *testing.T) { + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) + }) + + gate := NewReadyGate(inner, nil) + gate.SetReady() + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + + gate.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("Status = %d, want %d", rec.Code, http.StatusOK) + } + if rec.Body.String() != "ok" { + t.Errorf("Body = %q, want %q", rec.Body.String(), "ok") + } +} + +func TestReadyGate_AllowedPathsPassThrough(t *testing.T) { + var called atomic.Bool + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called.Store(true) + w.WriteHeader(http.StatusOK) + w.Write([]byte("setup page")) + }) + + gate := NewReadyGate(inner, []string{"/setup", "/healthz"}) + // Not ready, but /setup should pass through + + tests := []struct { + path string + wantStatus int + wantBody string + }{ + {"/setup", http.StatusOK, "setup page"}, + {"/setup/callback", http.StatusOK, "setup page"}, + {"/healthz", http.StatusOK, "setup page"}, + {"/other", http.StatusServiceUnavailable, ""}, + {"/webhook", http.StatusServiceUnavailable, ""}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + called.Store(false) + req := httptest.NewRequest(http.MethodGet, tt.path, nil) + rec := httptest.NewRecorder() + + gate.ServeHTTP(rec, req) + + if rec.Code != tt.wantStatus { + t.Errorf("Status = %d, want %d", rec.Code, tt.wantStatus) + } + + if tt.wantStatus == http.StatusOK && !called.Load() { + t.Error("Inner handler was not called for allowed path") + } + }) + } +} + +func TestReadyGate_IsReady(t *testing.T) { + gate := NewReadyGate(nil, nil) + + if gate.IsReady() { + t.Error("IsReady() = true, want false initially") + } + + gate.SetReady() + + if !gate.IsReady() { + t.Error("IsReady() = false, want true after SetReady()") + } +} + +func TestReadyGate_SetHandler(t *testing.T) { + gate := NewReadyGate(nil, nil) + + // Initially no handler + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + gate.ServeHTTP(rec, req) + + if rec.Code != http.StatusServiceUnavailable { + t.Errorf("Status = %d, want %d when not ready", rec.Code, http.StatusServiceUnavailable) + } + + // Set handler and mark ready + gate.SetHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("dynamic handler")) + })) + gate.SetReady() + + req = httptest.NewRequest(http.MethodGet, "/test", nil) + rec = httptest.NewRecorder() + gate.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("Status = %d, want %d after SetReady()", rec.Code, http.StatusOK) + } + if rec.Body.String() != "dynamic handler" { + t.Errorf("Body = %q, want %q", rec.Body.String(), "dynamic handler") + } +} + +func TestNewConfigFromEnv_Defaults(t *testing.T) { + cfg := NewConfigFromEnv() + + if cfg.MaxRetries != DefaultMaxRetries { + t.Errorf("MaxRetries = %d, want %d", cfg.MaxRetries, DefaultMaxRetries) + } + if cfg.RetryInterval != DefaultRetryInterval { + t.Errorf("RetryInterval = %v, want %v", cfg.RetryInterval, DefaultRetryInterval) + } +} + +func TestReloader_Trigger(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + gate := NewReadyGate(nil, nil) + + var reloadCount atomic.Int32 + reloadFunc := func(ctx context.Context) error { + reloadCount.Add(1) + return nil + } + + reloader := NewReloader(ctx, gate, reloadFunc) + reloader.Start() + + // Trigger a reload + reloader.Trigger() + + // Wait for reload to complete + time.Sleep(50 * time.Millisecond) + + if got := reloadCount.Load(); got != 1 { + t.Errorf("Reload count = %d, want 1", got) + } +} + +func TestReloader_MultipleTriggers(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + gate := NewReadyGate(nil, nil) + + var reloadCount atomic.Int32 + reloadFunc := func(ctx context.Context) error { + reloadCount.Add(1) + // Simulate some work + time.Sleep(20 * time.Millisecond) + return nil + } + + reloader := NewReloader(ctx, gate, reloadFunc) + reloader.Start() + + // Trigger multiple reloads rapidly + reloader.Trigger() + reloader.Trigger() + reloader.Trigger() + + // Wait for reloads to complete + time.Sleep(100 * time.Millisecond) + + // Only one or two reloads should have occurred due to deduplication + got := reloadCount.Load() + if got > 2 { + t.Errorf("Reload count = %d, want <= 2 (due to deduplication)", got) + } + if got < 1 { + t.Errorf("Reload count = %d, want >= 1", got) + } +} + +func TestReloader_ReloadError(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + gate := NewReadyGate(nil, nil) + + var reloadCount atomic.Int32 + reloadFunc := func(ctx context.Context) error { + reloadCount.Add(1) + return errors.New("reload failed") + } + + reloader := NewReloader(ctx, gate, reloadFunc) + reloader.Start() + + // Trigger a reload + reloader.Trigger() + + // Wait for reload to complete + time.Sleep(50 * time.Millisecond) + + // Reload should have been attempted + if got := reloadCount.Load(); got != 1 { + t.Errorf("Reload count = %d, want 1", got) + } +} + +func TestReloader_ContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + gate := NewReadyGate(nil, nil) + + var reloadCount atomic.Int32 + reloadFunc := func(ctx context.Context) error { + reloadCount.Add(1) + return nil + } + + reloader := NewReloader(ctx, gate, reloadFunc) + done := reloader.Start() + + // Cancel context + cancel() + + // Wait for reloader to stop + select { + case <-done: + // Good, reloader stopped + case <-time.After(100 * time.Millisecond): + t.Error("Reloader did not stop after context cancellation") + } +} + +func TestGlobalReloader(t *testing.T) { + // Clear any existing global reloader + SetGlobalReloader(nil) + + // TriggerReload should be a no-op when no global reloader is set + TriggerReload() // Should not panic + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + gate := NewReadyGate(nil, nil) + + var reloadCount atomic.Int32 + reloadFunc := func(ctx context.Context) error { + reloadCount.Add(1) + return nil + } + + reloader := NewReloader(ctx, gate, reloadFunc) + reloader.Start() + + // Set global reloader + SetGlobalReloader(reloader) + + // Now TriggerReload should work + TriggerReload() + + // Wait for reload to complete + time.Sleep(50 * time.Millisecond) + + if got := reloadCount.Load(); got != 1 { + t.Errorf("Reload count = %d, want 1", got) + } + + // Clean up + SetGlobalReloader(nil) +} diff --git a/configwait/reloader.go b/configwait/reloader.go new file mode 100644 index 0000000..a1541fd --- /dev/null +++ b/configwait/reloader.go @@ -0,0 +1,152 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +package configwait + +import ( + "context" + "os" + "os/signal" + "sync" + "syscall" + + "github.com/chainguard-dev/clog" +) + +// ReloadFunc is called when a reload is triggered. +type ReloadFunc func(ctx context.Context) error + +// Reloader manages configuration reloading via SIGHUP or programmatic triggers. +type Reloader struct { + gate *ReadyGate + reloadFunc ReloadFunc + ctx context.Context + + mu sync.Mutex + reloading bool + reloadCh chan struct{} +} + +// NewReloader creates a Reloader that calls reloadFunc when triggered. +func NewReloader(ctx context.Context, gate *ReadyGate, reloadFunc ReloadFunc) *Reloader { + return &Reloader{ + gate: gate, + reloadFunc: reloadFunc, + ctx: ctx, + reloadCh: make(chan struct{}, 1), + } +} + +// Start begins listening for SIGHUP signals and programmatic triggers. +// Returns a channel that closes when the reloader stops. +func (r *Reloader) Start() <-chan struct{} { + done := make(chan struct{}) + log := clog.FromContext(r.ctx) + + sighupCh := make(chan os.Signal, 1) + signal.Notify(sighupCh, syscall.SIGHUP) + + go func() { + defer close(done) + defer signal.Stop(sighupCh) + + for { + select { + case <-r.ctx.Done(): + return + case <-sighupCh: + log.Infof("[reloader] received SIGHUP, triggering reload") + r.doReload() + case <-r.reloadCh: + log.Infof("[reloader] programmatic reload triggered") + r.doReload() + } + } + }() + + return done +} + +// Trigger requests a configuration reload. Safe to call from any goroutine. +func (r *Reloader) Trigger() { + log := clog.FromContext(r.ctx) + + select { + case r.reloadCh <- struct{}{}: + default: + log.Infof("[reloader] reload already pending, ignoring trigger") + } +} + +// doReload performs the reload operation. +func (r *Reloader) doReload() { + log := clog.FromContext(r.ctx) + + r.mu.Lock() + if r.reloading { + r.mu.Unlock() + log.Infof("[reloader] reload already in progress, skipping") + return + } + r.reloading = true + r.mu.Unlock() + + defer func() { + r.mu.Lock() + r.reloading = false + r.mu.Unlock() + }() + + log.Infof("[reloader] starting configuration reload...") + + if err := r.reloadFunc(r.ctx); err != nil { + log.Errorf("[reloader] reload failed: %v", err) + return + } + + log.Infof("[reloader] configuration reloaded successfully") +} + +var ( + globalReloaderMu sync.RWMutex + globalReloader *Reloader + + reloadCounter int64 + reloadCounterMu sync.Mutex +) + +// SetGlobalReloader sets the global reloader instance. +func SetGlobalReloader(r *Reloader) { + globalReloaderMu.Lock() + defer globalReloaderMu.Unlock() + globalReloader = r +} + +// TriggerReload triggers a reload using the global reloader (no-op if unset). +func TriggerReload() { + reloadCounterMu.Lock() + reloadCounter++ + reloadCounterMu.Unlock() + + globalReloaderMu.RLock() + r := globalReloader + globalReloaderMu.RUnlock() + + if r != nil { + r.Trigger() + } +} + +// GetReloadCount returns the number of times TriggerReload has been called. +func GetReloadCount() int64 { + reloadCounterMu.Lock() + defer reloadCounterMu.Unlock() + return reloadCounter +} + +// ResetReloadCounter resets the reload counter to zero. +func ResetReloadCounter() { + reloadCounterMu.Lock() + defer reloadCounterMu.Unlock() + reloadCounter = 0 +} diff --git a/examples/simple/.env.example b/examples/simple/.env.example new file mode 100644 index 0000000..880339c --- /dev/null +++ b/examples/simple/.env.example @@ -0,0 +1,52 @@ +# Simple GitHub App Example - Configuration +# Copy this file to .env and customize as needed + +# ------------------------------------------------------------------------------ +# Logging +# ------------------------------------------------------------------------------ + +# Log output format: "text" (default) or "json" +LOG_FORMAT=text + +# ------------------------------------------------------------------------------ +# GitHub App Installer +# ------------------------------------------------------------------------------ + +# Enable the web-based installer UI at /setup +# Set to "false" after creating your GitHub App +GITHUB_APP_INSTALLER_ENABLED=true + +# GitHub base URL (change for GitHub Enterprise Server) +GITHUB_URL=https://github.com + +# Organization to create the app under (leave empty for personal account) +GITHUB_ORG= + +# ------------------------------------------------------------------------------ +# Server +# ------------------------------------------------------------------------------ + +# HTTP port to listen on +PORT=8080 + +# ------------------------------------------------------------------------------ +# Config Wait (startup behavior) +# ------------------------------------------------------------------------------ + +# Maximum retry attempts when waiting for configuration +CONFIG_WAIT_MAX_RETRIES=30 + +# Duration between retry attempts +CONFIG_WAIT_RETRY_INTERVAL=2s + +# ------------------------------------------------------------------------------ +# GitHub App Credentials (auto-populated by installer) +# These are filled in automatically when you create the app via /setup +# ------------------------------------------------------------------------------ + +# GITHUB_APP_ID= +# GITHUB_APP_SLUG= +# GITHUB_WEBHOOK_SECRET= +# GITHUB_CLIENT_ID= +# GITHUB_CLIENT_SECRET= +# GITHUB_APP_PRIVATE_KEY= diff --git a/examples/simple/Dockerfile b/examples/simple/Dockerfile new file mode 100644 index 0000000..0898fc9 --- /dev/null +++ b/examples/simple/Dockerfile @@ -0,0 +1,44 @@ +# syntax=docker/dockerfile:1 + +# ------------------------------------------------------------------ builder --- +FROM golang:1.25-alpine AS builder + +RUN apk add --no-cache git ca-certificates + +WORKDIR /build + +# Copy go.mod files first for better caching +COPY go.mod go.sum ./ +RUN go mod download + +# Copy source code +COPY . . + +# Build the example binary +WORKDIR /build/examples/simple +RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o /out/app . + +# ------------------------------------------------------------------ runtime --- +FROM alpine:3.21 + +RUN apk add --no-cache ca-certificates + +# Create non-root user +RUN addgroup -g 1000 app && \ + adduser -u 1000 -G app -s /bin/sh -D app + +# Create data directory for .env file storage +RUN mkdir -p /data && chown app:app /data + +COPY --from=builder /out/app /usr/local/bin/app + +USER app +WORKDIR /data + +EXPOSE 8080 + +ENV PORT=8080 +ENV STORAGE_MODE=envfile +ENV STORAGE_DIR=/data/.env + +CMD ["app"] diff --git a/examples/simple/README.md b/examples/simple/README.md new file mode 100644 index 0000000..0cf8995 --- /dev/null +++ b/examples/simple/README.md @@ -0,0 +1,163 @@ +# Simple GitHub App Example + +A minimal example demonstrating a GitHub App with webhook handling using Docker +and Docker Compose. + +## Features + +- Web-based GitHub App installer at `/setup` +- Webhook endpoint at `/webhook` that logs received events +- Configurable logging format (text or JSON via slog) +- Credentials stored in `.env` file (envfile storage backend) +- Health check endpoint at `/healthz` + +## Prerequisites + +- Docker Engine 24.0+ and Docker Compose v2.20+ +- [ngrok](https://ngrok.com/) or similar tunnel for webhook delivery +- GitHub account with permission to create GitHub Apps + +## Quick Start + +### 1. Start the Application + +```bash +cd examples/simple +docker compose up --build +``` + +The app starts with the installer enabled by default. + +### 2. Expose the Application + +GitHub needs to reach your webhook endpoint. Use ngrok or a similar tunnel: + +```bash +ngrok http 8080 +``` + +Copy the HTTPS forwarding URL (e.g., `https://abc123.ngrok-free.app`). + +### 3. Create the GitHub App + +1. Open your ngrok URL with `/setup` path in a browser: + ``` + https://abc123.ngrok-free.app/setup + ``` + +2. Enter your desired app name + +3. Update the webhook URL to your ngrok URL + `/webhook`: + ``` + https://abc123.ngrok-free.app/webhook + ``` + +4. Click "Create GitHub App" and authorize the app creation + +5. The installer saves credentials automatically and the app reloads + +### 4. Install the App + +After creation, click "Install App" to grant the app access to your +repositories. Select which repositories should send webhook events. + +### 5. Test Webhooks + +Push a commit or create a pull request in an installed repository. You should +see log output like: + +``` +level=INFO msg="received webhook" event=push action="" delivery_id=abc123 repository=owner/repo sender=username payload_size=1234 +``` + +## Configuration + +Copy `.env.example` to `.env` to customize settings: + +```bash +cp .env.example .env +``` + +### Environment Variables + +| Variable | Description | Default | +|--------------------------------|------------------------------|---------------------| +| `LOG_FORMAT` | Log format: `text` or `json` | `text` | +| `PORT` | HTTP port | `8080` | +| `GITHUB_APP_INSTALLER_ENABLED` | Enable installer UI | `true` | +| `GITHUB_URL` | GitHub base URL | `https://github.com`| +| `GITHUB_ORG` | Organization for app | - | + +### Disabling the Installer + +After setup, disable the installer for security: + +1. Click "Disable Setup & Continue" in the success page, or +2. Set `GITHUB_APP_INSTALLER_ENABLED=false` in `.env` and restart + +## Endpoints + +| Path | Description | +|------------|--------------------------------| +| `/setup` | GitHub App installer (enabled) | +| `/webhook` | Webhook receiver | +| `/healthz` | Health check | + +## Logs + +The application uses Go's `slog` package with configurable output format. + +**Text format** (default): +``` +level=INFO msg="received webhook" event=push repository=owner/repo +``` + +**JSON format** (`LOG_FORMAT=json`): +```json +{"level":"INFO","msg":"received webhook","event":"push","repository":"owner/repo"} +``` + +## Architecture + +``` + +-------------------+ + | ngrok | + | (public HTTPS) | + +---------+---------+ + | + v ++-------------------------------+-------------------------------+ +| App Container | +| | +| /setup --> Installer (creates GitHub App) | +| /webhook --> Webhook Handler (logs events) | +| /healthz --> Health Check | +| | +| Storage: /data/.env (persisted via Docker volume) | ++---------------------------------------------------------------+ +``` + +## Troubleshooting + +### App not receiving webhooks + +1. Verify ngrok is running and the URL hasn't changed +2. Check webhook URL in GitHub App settings matches your ngrok URL + `/webhook` +3. View webhook deliveries in GitHub App settings for error details + +### Configuration not loading + +Check logs for retry messages: +```bash +docker compose logs -f +``` + +The app retries configuration loading for 60 seconds by default. + +### Resetting the app + +To start fresh, remove the Docker volume: +```bash +docker compose down -v +docker compose up --build +``` diff --git a/examples/simple/compose.yaml b/examples/simple/compose.yaml new file mode 100644 index 0000000..6578b9c --- /dev/null +++ b/examples/simple/compose.yaml @@ -0,0 +1,33 @@ +services: + app: + build: + context: ../.. + dockerfile: examples/simple/Dockerfile + ports: + - "${PORT:-8080}:8080" + volumes: + # persist credentials across restarts + - app-data:/data + environment: + - LOG_FORMAT=${LOG_FORMAT:-text} + - GITHUB_APP_INSTALLER_ENABLED=${GITHUB_APP_INSTALLER_ENABLED:-true} + - GITHUB_URL=${GITHUB_URL:-https://github.com} + - GITHUB_ORG=${GITHUB_ORG:-} + - STORAGE_MODE=envfile + - STORAGE_DIR=/data/.env + - CONFIG_WAIT_MAX_RETRIES=${CONFIG_WAIT_MAX_RETRIES:-30} + - CONFIG_WAIT_RETRY_INTERVAL=${CONFIG_WAIT_RETRY_INTERVAL:-2s} + env_file: + # load .env file if it exists (for pre-configured credentials) + - path: .env + required: false + restart: unless-stopped + healthcheck: + test: ["CMD", "wget", "-q", "--spider", "http://localhost:8080/healthz"] + interval: 10s + timeout: 5s + retries: 3 + start_period: 5s + +volumes: + app-data: diff --git a/examples/simple/main.go b/examples/simple/main.go new file mode 100644 index 0000000..05bd9f9 --- /dev/null +++ b/examples/simple/main.go @@ -0,0 +1,273 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +// Example demonstrating a GitHub App with webhook handling. +package main + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "os" + "os/signal" + "strings" + "time" + + "github.com/cruxstack/github-app-setup-go/configstore" + "github.com/cruxstack/github-app-setup-go/configwait" + "github.com/cruxstack/github-app-setup-go/installer" +) + +const ( + defaultPort = 8080 + defaultReadHeaderTimeout = 10 * time.Second + defaultShutdownTimeout = 30 * time.Second +) + +func main() { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + + log := setupLogger() + ctx = withLogger(ctx, log) + + port := defaultPort + if p := os.Getenv("PORT"); p != "" { + fmt.Sscanf(p, "%d", &port) + } + + allowedPaths := []string{"/healthz"} + installerEnabled := configstore.InstallerEnabled() + if installerEnabled { + allowedPaths = append(allowedPaths, "/setup", "/callback", "/") + } + + gate := configwait.NewReadyGate(nil, allowedPaths) + mux := http.NewServeMux() + + mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { + if gate.IsReady() { + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) + } else { + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte("not ready")) + } + }) + + if installerEnabled { + store, err := configstore.NewFromEnv() + if err != nil { + log.Error("failed to create config store", "error", err) + os.Exit(1) + } + + manifest := installer.Manifest{ + URL: "https://github.com/cruxstack/github-app-setup-go", + Public: false, + DefaultPerms: map[string]string{ + "contents": "read", + "pull_requests": "read", + }, + DefaultEvents: []string{ + "push", + "pull_request", + }, + } + + installerCfg := installer.NewConfigFromEnv() + installerCfg.Store = store + installerCfg.Manifest = manifest + installerCfg.AppDisplayName = "Simple Webhook App" + + installerHandler, err := installer.New(installerCfg) + if err != nil { + log.Error("failed to create installer handler", "error", err) + os.Exit(1) + } + + mux.Handle("/setup", installerHandler) + mux.Handle("/setup/", installerHandler) + mux.Handle("/callback", installerHandler) + mux.Handle("/", installerHandler) + + log.Info("installer enabled, visit /setup to create GitHub App") + } + + gate.SetHandler(mux) + + srv := &http.Server{ + Addr: fmt.Sprintf(":%d", port), + ReadHeaderTimeout: defaultReadHeaderTimeout, + Handler: gate, + } + + log.Info("starting HTTP server", "port", port, "installer_enabled", installerEnabled) + + go func() { + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Error("server error", "error", err) + os.Exit(1) + } + }() + + go func() { + waitCfg := configwait.NewConfigFromEnv() + + err := configwait.Wait(ctx, waitCfg, func(ctx context.Context) error { + return loadConfig(ctx, log, mux) + }) + + if err != nil { + log.Error("failed to load configuration after retries", "error", err) + os.Exit(1) + } + + log.Info("configuration loaded, service is ready") + gate.SetReady() + + reloader := configwait.NewReloader(ctx, gate, func(ctx context.Context) error { + return loadConfig(ctx, log, mux) + }) + configwait.SetGlobalReloader(reloader) + reloader.Start() + + log.Info("configuration reloader started (send SIGHUP to reload)") + }() + + <-ctx.Done() + log.Info("shutting down server...") + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), defaultShutdownTimeout) + defer shutdownCancel() + + if err := srv.Shutdown(shutdownCtx); err != nil { + log.Error("server shutdown error", "error", err) + os.Exit(1) + } +} + +// loadConfig loads configuration and sets up the webhook handler. +func loadConfig(_ context.Context, log *slog.Logger, mux *http.ServeMux) error { + webhookSecret := os.Getenv(configstore.EnvGitHubWebhookSecret) + if webhookSecret == "" { + return fmt.Errorf("%s is not set", configstore.EnvGitHubWebhookSecret) + } + + appID := os.Getenv(configstore.EnvGitHubAppID) + if appID == "" { + return fmt.Errorf("%s is not set", configstore.EnvGitHubAppID) + } + + log.Info("loaded GitHub App configuration", "app_id", appID) + + mux.HandleFunc("/webhook", webhookHandler(log, webhookSecret)) + + return nil +} + +// webhookHandler returns an HTTP handler that processes GitHub webhooks. +func webhookHandler(log *slog.Logger, secret string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + log.Error("failed to read webhook body", "error", err) + http.Error(w, "failed to read body", http.StatusBadRequest) + return + } + defer r.Body.Close() + + signature := r.Header.Get("X-Hub-Signature-256") + if !validateSignature(body, signature, secret) { + log.Warn("webhook signature validation failed", + "remote_addr", r.RemoteAddr, + "has_signature", signature != "", + ) + http.Error(w, "invalid signature", http.StatusUnauthorized) + return + } + + eventType := r.Header.Get("X-GitHub-Event") + deliveryID := r.Header.Get("X-GitHub-Delivery") + + var payload struct { + Action string `json:"action"` + Repository struct { + FullName string `json:"full_name"` + } `json:"repository"` + Sender struct { + Login string `json:"login"` + } `json:"sender"` + } + if err := json.Unmarshal(body, &payload); err != nil { + log.Warn("failed to parse webhook payload", "error", err) + } + + log.Info("received webhook", + "event", eventType, + "action", payload.Action, + "delivery_id", deliveryID, + "repository", payload.Repository.FullName, + "sender", payload.Sender.Login, + "payload_size", len(body), + ) + + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status":"ok"}`)) + } +} + +// validateSignature validates the GitHub webhook signature. +func validateSignature(payload []byte, signature, secret string) bool { + if signature == "" || secret == "" { + return false + } + + if !strings.HasPrefix(signature, "sha256=") { + return false + } + sig := strings.TrimPrefix(signature, "sha256=") + + mac := hmac.New(sha256.New, []byte(secret)) + mac.Write(payload) + expected := hex.EncodeToString(mac.Sum(nil)) + + return hmac.Equal([]byte(sig), []byte(expected)) +} + +// setupLogger creates a slog.Logger based on LOG_FORMAT environment variable. +func setupLogger() *slog.Logger { + format := strings.ToLower(os.Getenv("LOG_FORMAT")) + + var handler slog.Handler + opts := &slog.HandlerOptions{ + Level: slog.LevelInfo, + } + + switch format { + case "json": + handler = slog.NewJSONHandler(os.Stderr, opts) + default: + handler = slog.NewTextHandler(os.Stderr, opts) + } + + return slog.New(handler) +} + +type loggerKey struct{} + +// withLogger adds a logger to the context. +func withLogger(ctx context.Context, log *slog.Logger) context.Context { + return context.WithValue(ctx, loggerKey{}, log) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..e3aa7f5 --- /dev/null +++ b/go.mod @@ -0,0 +1,26 @@ +module github.com/cruxstack/github-app-setup-go + +go 1.25 + +require ( + github.com/aws/aws-sdk-go-v2 v1.41.0 + github.com/aws/aws-sdk-go-v2/config v1.32.5 + github.com/aws/aws-sdk-go-v2/service/ssm v1.67.7 + github.com/chainguard-dev/clog v1.8.0 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + github.com/aws/aws-sdk-go-v2/credentials v1.19.5 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.4 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.7 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 // indirect + github.com/aws/smithy-go v1.24.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..8e9ab60 --- /dev/null +++ b/go.sum @@ -0,0 +1,36 @@ +github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= +github.com/aws/aws-sdk-go-v2 v1.41.0/go.mod h1:MayyLB8y+buD9hZqkCW3kX1AKq07Y5pXxtgB+rRFhz0= +github.com/aws/aws-sdk-go-v2/config v1.32.5 h1:pz3duhAfUgnxbtVhIK39PGF/AHYyrzGEyRD9Og0QrE8= +github.com/aws/aws-sdk-go-v2/config v1.32.5/go.mod h1:xmDjzSUs/d0BB7ClzYPAZMmgQdrodNjPPhd6bGASwoE= +github.com/aws/aws-sdk-go-v2/credentials v1.19.5 h1:xMo63RlqP3ZZydpJDMBsH9uJ10hgHYfQFIk1cHDXrR4= +github.com/aws/aws-sdk-go-v2/credentials v1.19.5/go.mod h1:hhbH6oRcou+LpXfA/0vPElh/e0M3aFeOblE1sssAAEk= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 h1:80+uETIWS1BqjnN9uJ0dBUaETh+P1XwFy5vwHwK5r9k= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16/go.mod h1:wOOsYuxYuB/7FlnVtzeBYRcjSRtQpAW0hCP7tIULMwo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 h1:rgGwPzb82iBYSvHMHXc8h9mRoOUBZIGFgKb9qniaZZc= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16/go.mod h1:L/UxsGeKpGoIj6DxfhOWHWQ/kGKcd4I1VncE4++IyKA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16 h1:1jtGzuV7c82xnqOVfx2F0xmJcOw5374L7N6juGW6x6U= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16/go.mod h1:M2E5OQf+XLe+SZGmmpaI2yy+J326aFf6/+54PoxSANc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 h1:0ryTNEdJbzUCEWkVXEXoqlXV72J5keC1GvILMOuD00E= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4/go.mod h1:HQ4qwNZh32C3CBeO6iJLQlgtMzqeG17ziAA/3KDJFow= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16 h1:oHjJHeUy0ImIV0bsrX0X91GkV5nJAyv1l1CC9lnO0TI= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16/go.mod h1:iRSNGgOYmiYwSCXxXaKb9HfOEj40+oTKn8pTxMlYkRM= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.4 h1:HpI7aMmJ+mm1wkSHIA2t5EaFFv5EFYXePW30p1EIrbQ= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.4/go.mod h1:C5RdGMYGlfM0gYq/tifqgn4EbyX99V15P2V3R+VHbQU= +github.com/aws/aws-sdk-go-v2/service/ssm v1.67.7 h1:0q42w8/mywPCzQD1IoWIBUCYfBJc5+fLwtZNpHffBSM= +github.com/aws/aws-sdk-go-v2/service/ssm v1.67.7/go.mod h1:urlU9nfKJEfi0+8T9luB3f3Y0UnomH/yxI7tTrfH9es= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.7 h1:eYnlt6QxnFINKzwxP5/Ucs1vkG7VT3Iezmvfgc2waUw= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.7/go.mod h1:+fWt2UHSb4kS7Pu8y+BMBvJF0EWx+4H0hzNwtDNRTrg= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12 h1:AHDr0DaHIAo8c9t1emrzAlVDFp+iMMKnPdYy6XO4MCE= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12/go.mod h1:GQ73XawFFiWxyWXMHWfhiomvP3tXtdNar/fi8z18sx0= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 h1:SciGFVNZ4mHdm7gpD1dgZYnCuVdX1s+lFTg4+4DOy70= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.5/go.mod h1:iW40X4QBmUxdP+fZNOpfmkdMZqsovezbAeO+Ubiv2pk= +github.com/aws/smithy-go v1.24.0 h1:LpilSUItNPFr1eY85RYgTIg5eIEPtvFbskaFcmmIUnk= +github.com/aws/smithy-go v1.24.0/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/chainguard-dev/clog v1.8.0 h1:frlTMEdg3XQR+ioQ6O9i92uigY8GTUcWKpuCFkhcCHA= +github.com/chainguard-dev/clog v1.8.0/go.mod h1:5MQOZi+Iu7fV7GcJG8ag8rCB5elEOpqRMKEASgnGVdo= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/installer/installer.go b/installer/installer.go new file mode 100644 index 0000000..140c2c3 --- /dev/null +++ b/installer/installer.go @@ -0,0 +1,461 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +// Package installer provides a web-based installer for creating GitHub Apps +// using the GitHub App Manifest flow. +package installer + +import ( + "bytes" + "context" + "embed" + "encoding/json" + "fmt" + "html/template" + "io" + "net/http" + "os" + "strings" + "time" + + "github.com/chainguard-dev/clog" + + "github.com/cruxstack/github-app-setup-go/configstore" + "github.com/cruxstack/github-app-setup-go/configwait" +) + +//go:embed templates/* +var templateFS embed.FS + +var indexTemplate = template.Must(template.ParseFS(templateFS, "templates/index.html")) +var successTemplate = template.Must(template.ParseFS(templateFS, "templates/success.html")) + +const ( + httpClientTimeout = 30 * time.Second + EnvGitHubURL = "GITHUB_URL" + EnvGitHubOrg = "GITHUB_ORG" + disableSetupPath = "/setup/disable" +) + +// CredentialsSavedFunc is called after credentials are saved. +type CredentialsSavedFunc func(ctx context.Context, creds *configstore.AppCredentials) error + +// Config holds the installer configuration. +type Config struct { + Store configstore.Store + Manifest Manifest + AppDisplayName string + GitHubURL string + GitHubOrg string + RedirectURL string + WebhookURL string + OnCredentialsSaved CredentialsSavedFunc +} + +// NewConfigFromEnv creates a Config from environment variables. +func NewConfigFromEnv() Config { + return Config{ + GitHubURL: configstore.GetEnvDefault(EnvGitHubURL, "https://github.com"), + GitHubOrg: os.Getenv(EnvGitHubOrg), + } +} + +// Handler handles the GitHub App manifest installation flow. +type Handler struct { + config Config +} + +type indexTemplateData struct { + AppDisplayName string + GitHubURL string + GitHubOrg string + FormActionURL string + ManifestJSON template.JS + WebhookURL string + NeedsWebhook bool + DefaultAppName string +} + +type successTemplateData struct { + AppDisplayName string + AppID int64 + AppSlug string + HTMLURL string + InstallURL string + DisableActionURL string + InstallerDisabled bool +} + +// New creates a new installer Handler with the given configuration. +func New(cfg Config) (*Handler, error) { + if cfg.Store == nil { + return nil, fmt.Errorf("store is required") + } + if cfg.GitHubURL == "" { + cfg.GitHubURL = "https://github.com" + } + if cfg.AppDisplayName == "" { + cfg.AppDisplayName = "GitHub App" + } + return &Handler{config: cfg}, nil +} + +// ServeHTTP implements http.Handler. +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + + switch { + case (r.Method == http.MethodGet || r.Method == http.MethodHead) && (path == "/" || path == ""): + h.handleRoot(w, r) + case (r.Method == http.MethodGet || r.Method == http.MethodHead) && (path == "/setup" || path == "/setup/"): + h.handleIndex(w, r) + case (r.Method == http.MethodGet || r.Method == http.MethodHead) && path == "/callback": + h.handleCallback(w, r) + + case r.Method == http.MethodPost && (path == disableSetupPath || path == disableSetupPath+"/"): + h.handleDisable(w, r) + default: + http.NotFound(w, r) + } +} + +// handleRoot redirects to /setup if enabled, otherwise returns 404. +func (h *Handler) handleRoot(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + log := clog.FromContext(ctx) + + status, err := h.config.Store.Status(ctx) + if err != nil { + log.Errorf("[installer] failed to read installer status: %v", err) + http.Error(w, "Failed to load installer status", http.StatusInternalServerError) + return + } + + if status != nil && status.InstallerDisabled { + http.NotFound(w, r) + return + } + + http.Redirect(w, r, "/setup", http.StatusFound) +} + +// handleIndex serves the main page. +func (h *Handler) handleIndex(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + log := clog.FromContext(ctx) + + status, err := h.config.Store.Status(ctx) + + if err != nil { + log.Errorf("[installer] failed to read installer status: %v", err) + http.Error(w, "Failed to load installer status", http.StatusInternalServerError) + return + } + if status != nil && status.Registered { + data := h.successDataFromStatus(status) + h.renderSuccess(w, r, data) + return + } + + redirectURL := h.config.RedirectURL + if redirectURL == "" { + redirectURL = getBaseURL(ctx, r) + log.Infof("[installer] auto-detected redirect url: url=%s host=%s forwarded_host=%s", + redirectURL, r.Host, r.Header.Get("X-Forwarded-Host")) + } + + webhookURL := h.config.WebhookURL + if webhookURL == "" { + webhookURL = r.FormValue("webhook_url") + if webhookURL == "" { + webhookURL = getBaseURL(ctx, r) + "/webhook" + log.Infof("[installer] auto-detected webhook url: url=%s", webhookURL) + } + } + + manifest := h.config.Manifest.Clone() + if manifest == nil { + manifest = &Manifest{} + } + manifest.RedirectURL = redirectURL + "/callback" + manifest.HookAttributes.URL = webhookURL + manifest.HookAttributes.Active = webhookURL != "" + + log.Infof("[installer] manifest redirect_url: %s", manifest.RedirectURL) + manifestJSON, err := json.Marshal(manifest) + if err != nil { + http.Error(w, "Failed to generate manifest", http.StatusInternalServerError) + return + } + + var formActionURL string + if h.config.GitHubOrg != "" { + formActionURL = fmt.Sprintf("%s/organizations/%s/settings/apps/new", + h.config.GitHubURL, h.config.GitHubOrg) + } else { + formActionURL = fmt.Sprintf("%s/settings/apps/new", h.config.GitHubURL) + } + + defaultAppName := manifest.Name + if defaultAppName == "" { + defaultAppName = strings.ToLower(strings.ReplaceAll(h.config.AppDisplayName, " ", "-")) + } + + data := indexTemplateData{ + AppDisplayName: h.config.AppDisplayName, + GitHubURL: h.config.GitHubURL, + GitHubOrg: h.config.GitHubOrg, + FormActionURL: formActionURL, + ManifestJSON: template.JS(manifestJSON), + WebhookURL: webhookURL, + NeedsWebhook: h.config.WebhookURL == "", + DefaultAppName: defaultAppName, + } + + var buf bytes.Buffer + if err := indexTemplate.Execute(&buf, data); err != nil { + log.Errorf("[installer] failed to render index template: %v", err) + http.Error(w, "Failed to render page", http.StatusInternalServerError) + return + } + setSecurityHeaders(w) + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + if _, err := buf.WriteTo(w); err != nil { + log.Errorf("[installer] failed to write response: %v", err) + } +} + +// handleCallback handles the GitHub redirect after app creation. +func (h *Handler) handleCallback(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + log := clog.FromContext(ctx) + + code := r.URL.Query().Get("code") + if code == "" { + http.Error(w, "Missing code parameter", http.StatusBadRequest) + return + } + if !isValidOAuthCode(code) { + http.Error(w, "Invalid code parameter", http.StatusBadRequest) + return + } + + var customDomain string + if cookie, err := r.Cookie("custom_domain"); err == nil { + customDomain = cookie.Value + http.SetCookie(w, &http.Cookie{ + Name: "custom_domain", + Value: "", + Path: "/", + MaxAge: -1, + }) + } + + creds, err := h.exchangeCode(ctx, code) + if err != nil { + log.Errorf("[installer] failed to exchange code: %v", err) + http.Error(w, "Failed to exchange code", http.StatusInternalServerError) + return + } + + if creds.CustomFields == nil { + creds.CustomFields = make(map[string]string) + } + + if customDomain != "" { + creds.CustomFields["CUSTOM_DOMAIN"] = customDomain + } + + if h.config.OnCredentialsSaved != nil { + if err := h.config.OnCredentialsSaved(ctx, creds); err != nil { + log.Errorf("[installer] OnCredentialsSaved callback failed: %v", err) + } + } + + if err := h.config.Store.Save(ctx, creds); err != nil { + log.Errorf("[installer] failed to save credentials: %v", err) + http.Error(w, "Failed to save credentials", http.StatusInternalServerError) + return + } + + log.Infof("[installer] successfully created github app: slug=%s app_id=%d", creds.AppSlug, creds.AppID) + + log.Infof("[installer] triggering configuration reload") + configwait.TriggerReload() + + data := h.successDataFromCreds(creds) + h.renderSuccess(w, r, data) +} + +// exchangeCode exchanges the temporary code for app credentials. +func (h *Handler) exchangeCode(ctx context.Context, code string) (*configstore.AppCredentials, error) { + url := fmt.Sprintf("%s/api/v3/app-manifests/%s/conversions", h.config.GitHubURL, code) + + if h.config.GitHubURL == "https://github.com" { + url = fmt.Sprintf("https://api.github.com/app-manifests/%s/conversions", code) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Accept", "application/vnd.github+json") + + client := &http.Client{ + Timeout: httpClientTimeout, + } + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to call GitHub API: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusCreated { + return nil, fmt.Errorf("GitHub API returned %d: %s", resp.StatusCode, string(body)) + } + + var creds configstore.AppCredentials + if err := json.Unmarshal(body, &creds); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return &creds, nil +} + +func (h *Handler) handleDisable(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + log := clog.FromContext(ctx) + + status, err := h.config.Store.Status(ctx) + if err != nil { + log.Errorf("[installer] failed to check status: %v", err) + http.Error(w, "Failed to check installer status", http.StatusInternalServerError) + return + } + if status == nil || !status.Registered { + http.Error(w, "Cannot disable installer before app is registered", http.StatusBadRequest) + return + } + + if err := h.config.Store.DisableInstaller(ctx); err != nil { + log.Errorf("[installer] failed to disable installer: %v", err) + http.Error(w, "Failed to disable installer", http.StatusInternalServerError) + return + } + + log.Infof("[installer] installer disabled via setup UI") + http.Redirect(w, r, "/healthz", http.StatusSeeOther) +} + +func (h *Handler) successDataFromCreds(creds *configstore.AppCredentials) successTemplateData { + data := successTemplateData{ + AppDisplayName: h.config.AppDisplayName, + AppID: creds.AppID, + AppSlug: creds.AppSlug, + HTMLURL: creds.HTMLURL, + DisableActionURL: disableSetupPath, + } + data.InstallURL = h.installURLFor(creds.AppSlug, creds.HTMLURL) + return data +} + +func (h *Handler) successDataFromStatus(status *configstore.InstallerStatus) successTemplateData { + if status == nil { + return successTemplateData{AppDisplayName: h.config.AppDisplayName} + } + data := successTemplateData{ + AppDisplayName: h.config.AppDisplayName, + AppID: status.AppID, + AppSlug: status.AppSlug, + HTMLURL: status.HTMLURL, + InstallerDisabled: status.InstallerDisabled, + DisableActionURL: disableSetupPath, + } + data.InstallURL = h.installURLFor(status.AppSlug, status.HTMLURL) + return data +} + +func (h *Handler) installURLFor(slug, htmlURL string) string { + if slug != "" { + githubURL := h.config.GitHubURL + if githubURL == "" { + githubURL = "https://github.com" + } + return fmt.Sprintf("%s/apps/%s/installations/new", githubURL, slug) + } + if htmlURL != "" { + trimmed := strings.TrimRight(htmlURL, "/") + return trimmed + "/installations/new" + } + return "" +} + +func (h *Handler) renderSuccess(w http.ResponseWriter, r *http.Request, data successTemplateData) { + ctx := r.Context() + log := clog.FromContext(ctx) + + var buf bytes.Buffer + if err := successTemplate.Execute(&buf, data); err != nil { + log.Errorf("[installer] failed to render success template: %v", err) + http.Error(w, "Failed to render page", http.StatusInternalServerError) + return + } + setSecurityHeaders(w) + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + if _, err := buf.WriteTo(w); err != nil { + log.Errorf("[installer] failed to write response: %v", err) + } +} + +// setSecurityHeaders sets common security headers. +func setSecurityHeaders(w http.ResponseWriter) { + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("X-Frame-Options", "DENY") + w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") +} + +// isValidOAuthCode validates the OAuth code format from GitHub. +func isValidOAuthCode(code string) bool { + if len(code) < 10 || len(code) > 100 { + return false + } + for _, c := range code { + if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')) { + return false + } + } + return true +} + +// getBaseURL derives the base URL from the request headers. +func getBaseURL(ctx context.Context, r *http.Request) string { + log := clog.FromContext(ctx) + + host := r.Header.Get("X-Forwarded-Host") + if host == "" { + host = r.Host + } + + scheme := r.Header.Get("X-Forwarded-Proto") + if scheme == "" { + scheme = "https" + if host == "localhost" || strings.HasPrefix(host, "localhost:") || + host == "127.0.0.1" || strings.HasPrefix(host, "127.0.0.1:") { + scheme = "http" + } + } else if scheme == "http" && !strings.HasPrefix(host, "localhost") && !strings.HasPrefix(host, "127.0.0.1") { + scheme = "https" + } + + baseURL := scheme + "://" + host + log.Debugf("[installer] getBaseURL: scheme=%s host=%s r.Host=%s X-Forwarded-Proto=%s X-Forwarded-Host=%s result=%s", + scheme, host, r.Host, r.Header.Get("X-Forwarded-Proto"), r.Header.Get("X-Forwarded-Host"), baseURL) + return baseURL +} diff --git a/installer/installer_test.go b/installer/installer_test.go new file mode 100644 index 0000000..c2c7d43 --- /dev/null +++ b/installer/installer_test.go @@ -0,0 +1,503 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +package installer + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/cruxstack/github-app-setup-go/configstore" +) + +func TestGetBaseURL(t *testing.T) { + tests := []struct { + name string + host string + xForwardedHost string + xForwardedProto string + want string + }{ + { + name: "localhost defaults to http", + host: "localhost:8080", + want: "http://localhost:8080", + }, + { + name: "localhost without port", + host: "localhost", + want: "http://localhost", + }, + { + name: "127.0.0.1 defaults to http", + host: "127.0.0.1:8080", + want: "http://127.0.0.1:8080", + }, + { + name: "non-localhost defaults to https", + host: "example.com", + want: "https://example.com", + }, + { + name: "non-localhost with port defaults to https", + host: "example.com:8443", + want: "https://example.com:8443", + }, + { + name: "X-Forwarded-Host takes precedence", + host: "internal-lb:8080", + xForwardedHost: "api.example.com", + want: "https://api.example.com", + }, + { + name: "X-Forwarded-Proto http allowed for localhost", + host: "localhost:3000", + xForwardedProto: "http", + want: "http://localhost:3000", + }, + { + name: "X-Forwarded-Proto http upgraded to https for non-localhost", + host: "example.com", + xForwardedProto: "http", + want: "https://example.com", + }, + { + name: "X-Forwarded-Proto https respected", + host: "example.com", + xForwardedProto: "https", + want: "https://example.com", + }, + { + name: "both forwarded headers", + host: "internal:8080", + xForwardedHost: "app.example.com", + xForwardedProto: "https", + want: "https://app.example.com", + }, + { + name: "X-Forwarded-Proto http with localhost forwarded host", + host: "production:8080", + xForwardedHost: "localhost:3000", + xForwardedProto: "http", + want: "http://localhost:3000", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Host = tt.host + + if tt.xForwardedHost != "" { + req.Header.Set("X-Forwarded-Host", tt.xForwardedHost) + } + if tt.xForwardedProto != "" { + req.Header.Set("X-Forwarded-Proto", tt.xForwardedProto) + } + + got := getBaseURL(context.Background(), req) + if got != tt.want { + t.Errorf("getBaseURL() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestInstallURLFor(t *testing.T) { + tests := []struct { + name string + githubURL string + slug string + htmlURL string + want string + }{ + { + name: "slug provided uses GitHub URL", + githubURL: "https://github.com", + slug: "my-app", + htmlURL: "", + want: "https://github.com/apps/my-app/installations/new", + }, + { + name: "slug with GHE URL", + githubURL: "https://github.mycompany.com", + slug: "internal-app", + htmlURL: "", + want: "https://github.mycompany.com/apps/internal-app/installations/new", + }, + { + name: "no slug uses htmlURL", + githubURL: "https://github.com", + slug: "", + htmlURL: "https://github.com/apps/my-app", + want: "https://github.com/apps/my-app/installations/new", + }, + { + name: "htmlURL with trailing slash", + githubURL: "https://github.com", + slug: "", + htmlURL: "https://github.com/apps/my-app/", + want: "https://github.com/apps/my-app/installations/new", + }, + { + name: "neither slug nor htmlURL", + githubURL: "https://github.com", + slug: "", + htmlURL: "", + want: "", + }, + { + name: "slug takes precedence over htmlURL", + githubURL: "https://github.com", + slug: "preferred-app", + htmlURL: "https://github.com/apps/other-app", + want: "https://github.com/apps/preferred-app/installations/new", + }, + { + name: "empty githubURL defaults to github.com", + githubURL: "", + slug: "my-app", + htmlURL: "", + want: "https://github.com/apps/my-app/installations/new", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &Handler{ + config: Config{ + GitHubURL: tt.githubURL, + }, + } + + got := h.installURLFor(tt.slug, tt.htmlURL) + if got != tt.want { + t.Errorf("installURLFor(%q, %q) = %q, want %q", tt.slug, tt.htmlURL, got, tt.want) + } + }) + } +} + +func TestManifestClone(t *testing.T) { + t.Run("nil manifest returns nil", func(t *testing.T) { + var m *Manifest = nil + got := m.Clone() + if got != nil { + t.Errorf("Clone() = %v, want nil", got) + } + }) + + t.Run("clones all scalar fields", func(t *testing.T) { + original := &Manifest{ + Name: "test-app", + URL: "https://example.com", + RedirectURL: "https://example.com/callback", + Public: true, + HookAttributes: HookAttributes{ + URL: "https://example.com/webhook", + Active: true, + }, + } + + clone := original.Clone() + + if clone.Name != original.Name { + t.Errorf("Clone().Name = %q, want %q", clone.Name, original.Name) + } + if clone.URL != original.URL { + t.Errorf("Clone().URL = %q, want %q", clone.URL, original.URL) + } + if clone.RedirectURL != original.RedirectURL { + t.Errorf("Clone().RedirectURL = %q, want %q", clone.RedirectURL, original.RedirectURL) + } + if clone.Public != original.Public { + t.Errorf("Clone().Public = %v, want %v", clone.Public, original.Public) + } + if clone.HookAttributes.URL != original.HookAttributes.URL { + t.Errorf("Clone().HookAttributes.URL = %q, want %q", clone.HookAttributes.URL, original.HookAttributes.URL) + } + if clone.HookAttributes.Active != original.HookAttributes.Active { + t.Errorf("Clone().HookAttributes.Active = %v, want %v", clone.HookAttributes.Active, original.HookAttributes.Active) + } + }) + + t.Run("DefaultPerms is deep copied", func(t *testing.T) { + original := &Manifest{ + DefaultPerms: map[string]string{ + "contents": "read", + "pull_requests": "write", + }, + } + + clone := original.Clone() + + // Verify values are the same + if clone.DefaultPerms["contents"] != "read" { + t.Error("Clone().DefaultPerms missing expected value") + } + + // Modify clone, original should be unchanged + clone.DefaultPerms["contents"] = "write" + clone.DefaultPerms["new_perm"] = "read" + + if original.DefaultPerms["contents"] != "read" { + t.Error("Modifying clone affected original DefaultPerms") + } + if _, exists := original.DefaultPerms["new_perm"]; exists { + t.Error("Adding to clone affected original DefaultPerms") + } + }) + + t.Run("DefaultEvents is deep copied", func(t *testing.T) { + original := &Manifest{ + DefaultEvents: []string{"push", "pull_request"}, + } + + clone := original.Clone() + + // Verify values are the same + if len(clone.DefaultEvents) != 2 { + t.Error("Clone().DefaultEvents has wrong length") + } + + // Modify clone, original should be unchanged + clone.DefaultEvents[0] = "modified" + clone.DefaultEvents = append(clone.DefaultEvents, "new_event") + + if original.DefaultEvents[0] != "push" { + t.Error("Modifying clone affected original DefaultEvents") + } + if len(original.DefaultEvents) != 2 { + t.Error("Appending to clone affected original DefaultEvents") + } + }) + + t.Run("nil maps and slices handled", func(t *testing.T) { + original := &Manifest{ + Name: "test", + DefaultPerms: nil, + DefaultEvents: nil, + } + + clone := original.Clone() + + if clone.DefaultPerms != nil { + t.Error("Clone() should have nil DefaultPerms when original is nil") + } + if clone.DefaultEvents != nil { + t.Error("Clone() should have nil DefaultEvents when original is nil") + } + }) +} + +func TestNew_Validation(t *testing.T) { + t.Run("nil store returns error", func(t *testing.T) { + _, err := New(Config{Store: nil}) + if err == nil { + t.Error("New() with nil store should return error") + } + }) + + t.Run("valid config succeeds", func(t *testing.T) { + store := &mockStore{} + h, err := New(Config{Store: store}) + if err != nil { + t.Errorf("New() error = %v, want nil", err) + } + if h == nil { + t.Error("New() returned nil handler") + } + }) + + t.Run("empty GitHubURL defaults to github.com", func(t *testing.T) { + store := &mockStore{} + h, _ := New(Config{Store: store, GitHubURL: ""}) + if h.config.GitHubURL != "https://github.com" { + t.Errorf("GitHubURL = %q, want %q", h.config.GitHubURL, "https://github.com") + } + }) + + t.Run("empty AppDisplayName defaults", func(t *testing.T) { + store := &mockStore{} + h, _ := New(Config{Store: store, AppDisplayName: ""}) + if h.config.AppDisplayName != "GitHub App" { + t.Errorf("AppDisplayName = %q, want %q", h.config.AppDisplayName, "GitHub App") + } + }) +} + +func TestHandler_ServeHTTP_Routing(t *testing.T) { + store := &mockStore{ + statusFunc: func(ctx context.Context) (*configstore.InstallerStatus, error) { + return &configstore.InstallerStatus{Registered: false}, nil + }, + } + + h, _ := New(Config{Store: store}) + + tests := []struct { + name string + method string + path string + wantStatus int + }{ + {"GET root redirects to setup", http.MethodGet, "/", http.StatusFound}, + {"GET /setup returns page", http.MethodGet, "/setup", http.StatusOK}, + {"GET /setup/ returns page", http.MethodGet, "/setup/", http.StatusOK}, + {"GET /callback without code returns 400", http.MethodGet, "/callback", http.StatusBadRequest}, + {"POST /setup/disable without registration returns 400", http.MethodPost, "/setup/disable", http.StatusBadRequest}, + {"unknown path returns 404", http.MethodGet, "/unknown", http.StatusNotFound}, + {"POST to GET-only path returns 404", http.MethodPost, "/setup", http.StatusNotFound}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(tt.method, tt.path, nil) + rec := httptest.NewRecorder() + + h.ServeHTTP(rec, req) + + if rec.Code != tt.wantStatus { + t.Errorf("ServeHTTP(%s %s) status = %d, want %d", tt.method, tt.path, rec.Code, tt.wantStatus) + } + }) + } +} + +func TestHandler_handleRoot_DisabledInstaller(t *testing.T) { + store := &mockStore{ + statusFunc: func(ctx context.Context) (*configstore.InstallerStatus, error) { + return &configstore.InstallerStatus{InstallerDisabled: true}, nil + }, + } + + h, _ := New(Config{Store: store}) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + h.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Errorf("handleRoot() with disabled installer status = %d, want %d", rec.Code, http.StatusNotFound) + } +} + +func TestHandler_handleIndex_AlreadyRegistered(t *testing.T) { + store := &mockStore{ + statusFunc: func(ctx context.Context) (*configstore.InstallerStatus, error) { + return &configstore.InstallerStatus{ + Registered: true, + AppID: 12345, + AppSlug: "my-app", + }, nil + }, + } + + h, _ := New(Config{Store: store}) + + req := httptest.NewRequest(http.MethodGet, "/setup", nil) + rec := httptest.NewRecorder() + + h.ServeHTTP(rec, req) + + // Should show success page + if rec.Code != http.StatusOK { + t.Errorf("handleIndex() with registered app status = %d, want %d", rec.Code, http.StatusOK) + } +} + +func TestIsValidOAuthCode(t *testing.T) { + tests := []struct { + name string + code string + want bool + }{ + {"valid alphanumeric code", "abc123DEF456xyz789", true}, + {"valid 20 character code", "abcdefghij1234567890", true}, + {"code too short", "abc", false}, + {"code too long", string(make([]byte, 101)), false}, + {"contains hyphen", "abc-123", false}, + {"contains underscore", "abc_123", false}, + {"contains space", "abc 123", false}, + {"contains special chars", "abc!@#123", false}, + {"empty string", "", false}, + {"minimum valid length", "abcdefghij", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isValidOAuthCode(tt.code) + if got != tt.want { + t.Errorf("isValidOAuthCode(%q) = %v, want %v", tt.code, got, tt.want) + } + }) + } +} + +func TestHandler_handleCallback_InvalidCode(t *testing.T) { + store := &mockStore{ + statusFunc: func(ctx context.Context) (*configstore.InstallerStatus, error) { + return &configstore.InstallerStatus{Registered: false}, nil + }, + } + + h, _ := New(Config{Store: store}) + + tests := []struct { + name string + code string + wantStatus int + }{ + {"missing code", "", http.StatusBadRequest}, + {"invalid code with special chars", "abc!@#123", http.StatusBadRequest}, + {"code too short", "abc", http.StatusBadRequest}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + url := "/callback" + if tt.code != "" { + url += "?code=" + tt.code + } + req := httptest.NewRequest(http.MethodGet, url, nil) + rec := httptest.NewRecorder() + + h.ServeHTTP(rec, req) + + if rec.Code != tt.wantStatus { + t.Errorf("handleCallback() with code %q status = %d, want %d", tt.code, rec.Code, tt.wantStatus) + } + }) + } +} + +// mockStore implements configstore.Store for testing +type mockStore struct { + saveFunc func(ctx context.Context, creds *configstore.AppCredentials) error + statusFunc func(ctx context.Context) (*configstore.InstallerStatus, error) + disableInstallerFunc func(ctx context.Context) error +} + +func (m *mockStore) Save(ctx context.Context, creds *configstore.AppCredentials) error { + if m.saveFunc != nil { + return m.saveFunc(ctx, creds) + } + return nil +} + +func (m *mockStore) Status(ctx context.Context) (*configstore.InstallerStatus, error) { + if m.statusFunc != nil { + return m.statusFunc(ctx) + } + return &configstore.InstallerStatus{}, nil +} + +func (m *mockStore) DisableInstaller(ctx context.Context) error { + if m.disableInstallerFunc != nil { + return m.disableInstallerFunc(ctx) + } + return nil +} diff --git a/installer/manifest.go b/installer/manifest.go new file mode 100644 index 0000000..9f28713 --- /dev/null +++ b/installer/manifest.go @@ -0,0 +1,50 @@ +// Copyright 2025 CruxStack +// SPDX-License-Identifier: MIT + +package installer + +// Manifest represents a GitHub App manifest. +type Manifest struct { + Name string `json:"name,omitempty"` + URL string `json:"url"` + HookAttributes HookAttributes `json:"hook_attributes"` + RedirectURL string `json:"redirect_url"` + Public bool `json:"public"` + DefaultPerms map[string]string `json:"default_permissions"` + DefaultEvents []string `json:"default_events"` +} + +// HookAttributes configures the webhook for the GitHub App. +type HookAttributes struct { + URL string `json:"url"` + Active bool `json:"active"` +} + +// Clone returns a deep copy of the manifest. +func (m *Manifest) Clone() *Manifest { + if m == nil { + return nil + } + + clone := &Manifest{ + Name: m.Name, + URL: m.URL, + RedirectURL: m.RedirectURL, + Public: m.Public, + HookAttributes: m.HookAttributes, + } + + if m.DefaultPerms != nil { + clone.DefaultPerms = make(map[string]string, len(m.DefaultPerms)) + for k, v := range m.DefaultPerms { + clone.DefaultPerms[k] = v + } + } + + if m.DefaultEvents != nil { + clone.DefaultEvents = make([]string, len(m.DefaultEvents)) + copy(clone.DefaultEvents, m.DefaultEvents) + } + + return clone +} diff --git a/installer/templates/index.html b/installer/templates/index.html new file mode 100644 index 0000000..c79d0e7 --- /dev/null +++ b/installer/templates/index.html @@ -0,0 +1,178 @@ + + +
+ + +Create a GitHub App with all the required permissions for {{.AppDisplayName}}.
+ +GITHUB_ORG to create for an organization instead.
+ Your GitHub App {{.AppSlug}} has been created and credentials have been saved.
+Setup has already been disabled. Restart the service to stop exposing the installer.
+ {{else}} +Once credentials are saved, disable this page so future visitors can't modify your configuration.
+ + {{end}} ++ Credentials have been saved to the configured storage backend. + After installing the app, you can configure your services to use the saved credentials. +
+