diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5012e79..d11df1e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,14 +2,14 @@ name: CI on: push: - branches: [ main ] + branches: [main] pull_request: - branches: [ main ] + branches: [main] jobs: lint-and-test: name: Lint, Format Check, and Test - runs-on: ubuntu-latest + runs-on: macos-latest steps: - name: Checkout code @@ -18,7 +18,7 @@ jobs: - name: Set up Go 1.25 uses: actions/setup-go@v5 with: - go-version: '1.25.0' + go-version: "1.25.0" cache: true - name: Install golangci-lint @@ -49,8 +49,13 @@ jobs: exit 1 fi - - name: Run golangci-lint - run: make lint + - name: Generate routes + run: make generate + + - name: Run golintci + uses: golangci/golangci-lint-action@v8 + with: + version: v2.6.0 - name: Run tests with Anvil run: make test diff --git a/.golangci.yml b/.golangci.yml index 6092019..e1b3655 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,3 +1,4 @@ +version: "2" run: timeout: 5m tests: true @@ -5,217 +6,45 @@ run: linters: enable: # Enabled by default - - errcheck # Check for unchecked errors - - gosimple # Simplify code - - govet # Vet examines Go source code - - ineffassign # Detect ineffectual assignments - - staticcheck # Advanced Go linter - - unused # Check for unused constants, variables, functions and types - - # Style and code quality - - gofmt # Check whether code was gofmt-ed - - goimports # Check import statements are formatted - - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go - - stylecheck # Replacement for golint + - errcheck # Check for unchecked errors + - govet # Vet examines Go source code + - ineffassign # Detect ineffectual assignments + - staticcheck # Advanced Go linter + - unused # Check for unused constants, variables, functions and types # Complexity and maintainability - - gocyclo # Cyclomatic complexity - - gocognit # Cognitive complexity - - cyclop # Package complexity - - nestif # Reports deeply nested if statements + - gocyclo # Cyclomatic complexity + - gocognit # Cognitive complexity + - nestif # Reports deeply nested if statements # Bugs and correctness - - bodyclose # Check HTTP response body is closed - - noctx # Finds HTTP requests without context.Context - - rowserrcheck # Check whether Rows.Err is checked + - bodyclose # Check HTTP response body is closed + - noctx # Finds HTTP requests without context.Context + - rowserrcheck # Check whether Rows.Err is checked - sqlclosecheck # Check sql.Rows and sql.Stmt are closed - - gosec # Security issues + - gosec # Security issues # Performance - - prealloc # Find slice declarations that could potentially be preallocated + - prealloc # Find slice declarations that could potentially be preallocated # Error handling - - errname # Check error naming conventions - - errorlint # Find code that will cause problems with Go 1.13 error wrapping - - wrapcheck # Check that errors from external packages are wrapped + - errname # Check error naming conventions + - errorlint # Find code that will cause problems with Go 1.13 error wrapping + - wrapcheck # Check that errors from external packages are wrapped # Code organization - - depguard # Checks if package imports are in whitelist - - gomodguard # Check for blocked dependencies + - gomodguard # Check for blocked dependencies # Formatting and style - - whitespace # Check for unnecessary whitespace - - unconvert # Remove unnecessary type conversions - - unparam # Reports unused function parameters - - wastedassign # Find wasted assignment statements + - whitespace # Check for unnecessary whitespace + - unconvert # Remove unnecessary type conversions + - unparam # Reports unused function parameters + - wastedassign # Find wasted assignment statements # Naming conventions - - predeclared # Find code that shadows predeclared identifiers - - varnamelen # Check variable name length + - predeclared # Find code that shadows predeclared identifiers + - varnamelen # Check variable name length # Comments and documentation - - godot # Check if comments end in a period - - misspell # Finds commonly misspelled English words - -linters-settings: - revive: - rules: - - name: comment-spacings - severity: warning - disabled: false - # Private functions should appear above public methods - - name: function-result-limit - severity: warning - disabled: false - arguments: [3] - - name: cognitive-complexity - severity: warning - disabled: false - arguments: [20] - - name: cyclomatic - severity: warning - disabled: false - arguments: [15] - - gocyclo: - min-complexity: 15 - - gocognit: - min-complexity: 20 - - nestif: - min-complexity: 4 - - varnamelen: - min-name-length: 2 - ignore-names: - - err - - tx - - id - - ok - - i - - j - - k - ignore-decls: - - t testing.T - - e error - - errcheck: - check-type-assertions: true - check-blank: true - - govet: - enable-all: true - disable: - - shadow # Can be too strict - - fieldalignment # Struct field alignment is a micro-optimization - - stylecheck: - checks: ["all"] - initialisms: ["ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS", "ABI"] - - godot: - scope: declarations - capital: true - - depguard: - rules: - main: - deny: - - pkg: "io/ioutil" - desc: "Use os and io packages instead" - - gosec: - excludes: - - G104 # Audit errors not checked (covered by errcheck) - - wrapcheck: - ignorePackageGlobs: - - github.com/rxtech-lab/smart-contract-cli/* - -issues: - exclude-dirs: - - vendor - - third_party - exclude-files: - - ".*\\.pb\\.go$" - exclude-rules: - # Exclude some linters from running on tests files - - path: _test\.go - linters: - - gocyclo - - errcheck - - gosec - - wrapcheck - - varnamelen - - godot - - # Exclude cognitive complexity for test files - - path: _test\.go - text: "cognitive complexity" - - # Allow longer functions in test suites - - path: _test\.go - text: "Function 'Test" - linters: - - gocognit - - gocyclo - - # Exclude storage models from stylecheck naming (database/JSON field names) - - path: internal/contract/evm/storage/models/ - linters: - - stylecheck - - # Exclude abi.go from stylecheck naming (legacy API compatibility) - - path: internal/contract/evm/abi/abi\.go - linters: - - stylecheck - - # Exclude component package from godot (type names like "If" cause false positives) - - path: internal/ui/component/ - linters: - - godot - - # Exclude generated files from all linters - - path: _gen\.go - linters: - - gofmt - - goimports - - revive - - stylecheck - - godot - - misspell - - whitespace - - unconvert - - unparam - - wastedassign - - predeclared - - varnamelen - - errcheck - - gosimple - - govet - - ineffassign - - staticcheck - - unused - - gocyclo - - gocognit - - cyclop - - nestif - - bodyclose - - noctx - - rowserrcheck - - sqlclosecheck - - gosec - - prealloc - - errname - - errorlint - - wrapcheck - - depguard - - gomodguard - - max-issues-per-linter: 0 - max-same-issues: 0 - -output: - formats: - - format: colored-line-number + - godot # Check if comments end in a period + - misspell # Finds commonly misspelled English words diff --git a/app/evm/storage/page.go b/app/evm/storage/page.go index a6a6050..5367f7b 100644 --- a/app/evm/storage/page.go +++ b/app/evm/storage/page.go @@ -32,12 +32,12 @@ type StorageOption struct { var storageOptions = []StorageOption{ { Label: "SQLite", - Value: "sqlite", + Value: config.StorageClientTypeSQLite, Description: "Local file-based database", }, { Label: "Postgres", - Value: "postgres", + Value: config.StorageClientTypePostgres, Description: "PostgreSQL database server", }, } @@ -103,10 +103,10 @@ func NewPage(router view.Router, sharedMemory storage.SharedMemory) view.View { } // Create text input for later use - ti := textinput.New() - ti.Focus() - ti.CharLimit = 256 - ti.Width = 50 + textInput := textinput.New() + textInput.Focus() + textInput.CharLimit = 256 + textInput.Width = 50 model := Model{ router: router, @@ -115,7 +115,7 @@ func NewPage(router view.Router, sharedMemory storage.SharedMemory) view.View { options: storageOptions, selectedIndex: 0, inputMode: InputModeNone, - textInput: ti, + textInput: textInput, confirmOptions: []string{ "Use existing configuration", "Change configuration", @@ -227,9 +227,10 @@ func (m Model) handleInputSubmit() Model { return m } - if m.inputMode == InputModeSqlitePath { + switch m.inputMode { + case InputModeSqlitePath: return m.saveSQLiteConfiguration(value) - } else if m.inputMode == InputModePostgresURL { + case InputModePostgresURL: return m.savePostgresConfiguration(value) } return m @@ -312,11 +313,12 @@ func (m Model) useExistingConfiguration(clientType string) Model { } func (m Model) changeConfiguration(clientType string) Model { - if clientType == "sqlite" { + switch clientType { + case "sqlite": m.inputMode = InputModeSqlitePath m.textInput.SetValue(m.sqlitePath) m.textInput.Placeholder = "Enter SQLite file path" - } else if clientType == "postgres" { + case "postgres": m.inputMode = InputModePostgresURL m.textInput.SetValue(m.postgresURL) m.textInput.Placeholder = "Enter PostgreSQL connection URL" @@ -333,9 +335,10 @@ func (m Model) removeConfiguration(clientType string) Model { return m } - if clientType == "sqlite" { + switch clientType { + case "sqlite": m.sqlitePath = "" - } else if clientType == "postgres" { + case "postgres": m.postgresURL = "" } @@ -366,9 +369,10 @@ func (m *Model) handleClientSelection() { // Check if configuration exists var hasConfig bool - if selectedOption.Value == "sqlite" { + switch selectedOption.Value { + case "sqlite": hasConfig = m.sqlitePath != "" - } else if selectedOption.Value == "postgres" { + case "postgres": hasConfig = m.postgresURL != "" } @@ -378,11 +382,12 @@ func (m *Model) handleClientSelection() { m.confirmIndex = 0 } else { // Show input dialog - if selectedOption.Value == "sqlite" { + switch selectedOption.Value { + case "sqlite": m.inputMode = InputModeSqlitePath m.textInput.SetValue("") m.textInput.Placeholder = "Enter SQLite file path (e.g., ~/.smart-contract-cli/data.db)" - } else if selectedOption.Value == "postgres" { + case "postgres": m.inputMode = InputModePostgresURL m.textInput.SetValue("") m.textInput.Placeholder = "Enter PostgreSQL URL (e.g., postgres://user:pass@localhost:5432/db)" @@ -398,20 +403,24 @@ func (m *Model) saveStorageClient(clientType string, value string) error { // Save the value var key string - if clientType == "sqlite" { + switch clientType { + case "sqlite": key = config.StorageKeySqlitePathKey - } else if clientType == "postgres" { + case "postgres": key = config.StorageKeyPostgresURLKey - } else { + default: return fmt.Errorf("invalid client type: %s", clientType) } if err := m.secureStorage.Set(key, value); err != nil { - return err + return fmt.Errorf("failed to save storage client configuration: %w", err) } // Save as active client - return m.secureStorage.Set(config.StorageKeyTypeKey, clientType) + if err := m.secureStorage.Set(config.StorageKeyTypeKey, clientType); err != nil { + return fmt.Errorf("failed to set active storage client: %w", err) + } + return nil } // switchActiveClient switches the active storage client. @@ -419,7 +428,10 @@ func (m *Model) switchActiveClient(clientType string) error { if m.secureStorage == nil { return fmt.Errorf("secure storage not initialized") } - return m.secureStorage.Set(config.StorageKeyTypeKey, clientType) + if err := m.secureStorage.Set(config.StorageKeyTypeKey, clientType); err != nil { + return fmt.Errorf("failed to switch active storage client: %w", err) + } + return nil } // removeStorageClient removes a storage client configuration. @@ -428,16 +440,21 @@ func (m *Model) removeStorageClient(clientType string) error { return fmt.Errorf("secure storage not initialized") } + // Determine which key to delete var key string - if clientType == "sqlite" { + switch clientType { + case "sqlite": key = config.StorageKeySqlitePathKey - } else if clientType == "postgres" { + case "postgres": key = config.StorageKeyPostgresURLKey - } else { + default: return fmt.Errorf("invalid client type: %s", clientType) } - return m.secureStorage.Delete(key) + if err := m.secureStorage.Delete(key); err != nil { + return fmt.Errorf("failed to delete storage client configuration: %w", err) + } + return nil } // maskPostgresURL masks the password in a Postgres URL for display. @@ -463,11 +480,12 @@ func maskPostgresURL(url string) string { } func (m Model) Help() (string, view.HelpDisplayOption) { - if m.inputMode == InputModeNone { + switch m.inputMode { + case InputModeNone: return "↑/k: up • ↓/j: down • enter: select • esc/q: back", view.HelpDisplayOptionAppend - } else if m.inputMode == InputModeSqlitePath || m.inputMode == InputModePostgresURL { + case InputModeSqlitePath, InputModePostgresURL: return "enter: save • esc: cancel", view.HelpDisplayOptionAppend - } else if m.inputMode == InputModeConfirmation { + case InputModeConfirmation: return "↑/k: up • ↓/j: down • enter: confirm • esc: cancel", view.HelpDisplayOptionAppend } return "", view.HelpDisplayOptionAppend @@ -486,9 +504,10 @@ func (m Model) View() string { } // Input mode - show text input - if m.inputMode == InputModeSqlitePath { + switch m.inputMode { + case InputModeSqlitePath: return m.renderInputView("Configure SQLite", "Enter the path for your SQLite database file:") - } else if m.inputMode == InputModePostgresURL { + case InputModePostgresURL: return m.renderInputView("Configure Postgres", "Enter the PostgreSQL connection URL:") } @@ -518,9 +537,10 @@ func (m Model) renderInputView(title string, prompt string) string { func (m Model) renderConfirmationView() string { selectedOption := m.options[m.selectedIndex] currentValue := "" - if selectedOption.Value == "sqlite" { + switch selectedOption.Value { + case "sqlite": currentValue = m.sqlitePath - } else if selectedOption.Value == "postgres" { + case "postgres": currentValue = maskPostgresURL(m.postgresURL) } @@ -549,7 +569,7 @@ func (m Model) renderConfirmationView() string { func (m Model) renderNormalView() string { // Build list items with descriptions items := make([]component.ListItem, len(m.options)) - for i, opt := range m.options { + for idx, opt := range m.options { desc := opt.Description // Add stored path/URL to description if available @@ -559,7 +579,7 @@ func (m Model) renderNormalView() string { desc = desc + "\nURL: " + maskPostgresURL(m.postgresURL) } - items[i] = component.Item(opt.Label, opt.Value, desc) + items[idx] = component.Item(opt.Label, opt.Value, desc) } // Highlight the active client diff --git a/app/evm/storage/page_test.go b/app/evm/storage/page_test.go index 3b778b5..3b9df85 100644 --- a/app/evm/storage/page_test.go +++ b/app/evm/storage/page_test.go @@ -35,7 +35,8 @@ func (s *StoragePageTestSuite) SetupTest() { s.testStoragePath = tmpDir // Override the storage path for tests - os.Setenv("HOME", tmpDir) + err = os.Setenv("HOME", tmpDir) + s.NoError(err, "Should set HOME environment variable") // Set up password and storage s.password = "testpassword123" @@ -62,7 +63,8 @@ func (s *StoragePageTestSuite) SetupTest() { func (s *StoragePageTestSuite) TearDownTest() { // Clean up test storage if s.testStoragePath != "" { - os.RemoveAll(s.testStoragePath) + err := os.RemoveAll(s.testStoragePath) + s.NoError(err, "Should clean up test storage directory") } } @@ -82,7 +84,7 @@ func (s *StoragePageTestSuite) TestInitialState() { model := NewPage(s.router, s.sharedMemory) - tm := teatest.NewTestModel( + testModel := teatest.NewTestModel( s.T(), model, teatest.WithInitialTermSize(300, 100), @@ -91,15 +93,15 @@ func (s *StoragePageTestSuite) TestInitialState() { // Wait for initial render time.Sleep(100 * time.Millisecond) - output := s.getOutput(tm) + output := s.getOutput(testModel) s.Contains(output, "Storage Client Configuration", "Should show page title") s.Contains(output, "SQLite", "Should show SQLite option") s.Contains(output, "Postgres", "Should show Postgres option") s.Contains(output, "Legend:", "Should show legend") // Quit - tm.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) - tm.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) + testModel.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) + testModel.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) } // TestNavigationUpDown tests keyboard navigation between storage options. @@ -130,7 +132,7 @@ func (s *StoragePageTestSuite) TestNavigationUpDown() { func (s *StoragePageTestSuite) TestSelectSQLiteFirstTime() { model := NewPage(s.router, s.sharedMemory) - tm := teatest.NewTestModel( + testModel := teatest.NewTestModel( s.T(), model, teatest.WithInitialTermSize(300, 100), @@ -140,18 +142,18 @@ func (s *StoragePageTestSuite) TestSelectSQLiteFirstTime() { time.Sleep(100 * time.Millisecond) // Press Enter to select SQLite (first option) - tm.Send(tea.KeyMsg{Type: tea.KeyEnter}) + testModel.Send(tea.KeyMsg{Type: tea.KeyEnter}) time.Sleep(100 * time.Millisecond) // Should show input mode for SQLite path - output := s.getOutput(tm) + output := s.getOutput(testModel) s.Contains(output, "Configure SQLite", "Should show SQLite configuration") s.Contains(output, "Enter the path", "Should prompt for path") // Type a path testPath := "/tmp/test.db" for _, char := range testPath { - tm.Send(tea.KeyMsg{ + testModel.Send(tea.KeyMsg{ Type: tea.KeyRunes, Runes: []rune{char}, }) @@ -159,24 +161,24 @@ func (s *StoragePageTestSuite) TestSelectSQLiteFirstTime() { } // Submit the path - tm.Send(tea.KeyMsg{Type: tea.KeyEnter}) + testModel.Send(tea.KeyMsg{Type: tea.KeyEnter}) time.Sleep(200 * time.Millisecond) // Should return to normal view - output = s.getOutput(tm) + output = s.getOutput(testModel) s.Contains(output, "Storage Client Configuration", "Should return to main view") s.Contains(output, testPath, "Should display the configured path") // Quit - tm.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) - tm.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) + testModel.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) + testModel.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) } // TestSelectPostgresFirstTime tests selecting Postgres for the first time. func (s *StoragePageTestSuite) TestSelectPostgresFirstTime() { model := NewPage(s.router, s.sharedMemory) - tm := teatest.NewTestModel( + testModel := teatest.NewTestModel( s.T(), model, teatest.WithInitialTermSize(300, 100), @@ -186,22 +188,22 @@ func (s *StoragePageTestSuite) TestSelectPostgresFirstTime() { time.Sleep(100 * time.Millisecond) // Navigate to Postgres (second option) - tm.Send(tea.KeyMsg{Type: tea.KeyDown}) + testModel.Send(tea.KeyMsg{Type: tea.KeyDown}) time.Sleep(50 * time.Millisecond) // Press Enter to select Postgres - tm.Send(tea.KeyMsg{Type: tea.KeyEnter}) + testModel.Send(tea.KeyMsg{Type: tea.KeyEnter}) time.Sleep(100 * time.Millisecond) // Should show input mode for Postgres URL - output := s.getOutput(tm) + output := s.getOutput(testModel) s.Contains(output, "Configure Postgres", "Should show Postgres configuration") s.Contains(output, "PostgreSQL connection URL", "Should prompt for URL") // Type a URL testURL := "postgres://user:pass@localhost:5432/db" for _, char := range testURL { - tm.Send(tea.KeyMsg{ + testModel.Send(tea.KeyMsg{ Type: tea.KeyRunes, Runes: []rune{char}, }) @@ -209,24 +211,24 @@ func (s *StoragePageTestSuite) TestSelectPostgresFirstTime() { } // Submit the URL - tm.Send(tea.KeyMsg{Type: tea.KeyEnter}) + testModel.Send(tea.KeyMsg{Type: tea.KeyEnter}) time.Sleep(200 * time.Millisecond) // Should return to normal view with masked password - output = s.getOutput(tm) + output = s.getOutput(testModel) s.Contains(output, "Storage Client Configuration", "Should return to main view") s.Contains(output, "postgres://user:****@localhost:5432/db", "Should display masked URL") // Quit - tm.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) - tm.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) + testModel.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) + testModel.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) } // TestCancelInput tests canceling input with Escape key. func (s *StoragePageTestSuite) TestCancelInput() { model := NewPage(s.router, s.sharedMemory) - tm := teatest.NewTestModel( + testModel := teatest.NewTestModel( s.T(), model, teatest.WithInitialTermSize(300, 100), @@ -236,32 +238,32 @@ func (s *StoragePageTestSuite) TestCancelInput() { time.Sleep(100 * time.Millisecond) // Select SQLite - tm.Send(tea.KeyMsg{Type: tea.KeyEnter}) + testModel.Send(tea.KeyMsg{Type: tea.KeyEnter}) time.Sleep(100 * time.Millisecond) // Should be in input mode - output := s.getOutput(tm) + output := s.getOutput(testModel) s.Contains(output, "Configure SQLite", "Should be in SQLite config mode") // Press Escape to cancel - tm.Send(tea.KeyMsg{Type: tea.KeyEsc}) + testModel.Send(tea.KeyMsg{Type: tea.KeyEsc}) time.Sleep(100 * time.Millisecond) // Should return to normal view - output = s.getOutput(tm) + output = s.getOutput(testModel) s.Contains(output, "Storage Client Configuration", "Should return to main view") s.NotContains(output, "Configure SQLite", "Should exit config mode") // Quit - tm.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) - tm.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) + testModel.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) + testModel.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) } // TestEmptyPathValidation tests that empty path/URL is rejected. func (s *StoragePageTestSuite) TestEmptyPathValidation() { model := NewPage(s.router, s.sharedMemory) - tm := teatest.NewTestModel( + testModel := teatest.NewTestModel( s.T(), model, teatest.WithInitialTermSize(300, 100), @@ -271,20 +273,20 @@ func (s *StoragePageTestSuite) TestEmptyPathValidation() { time.Sleep(100 * time.Millisecond) // Select SQLite - tm.Send(tea.KeyMsg{Type: tea.KeyEnter}) + testModel.Send(tea.KeyMsg{Type: tea.KeyEnter}) time.Sleep(100 * time.Millisecond) // Submit without entering anything - tm.Send(tea.KeyMsg{Type: tea.KeyEnter}) + testModel.Send(tea.KeyMsg{Type: tea.KeyEnter}) time.Sleep(100 * time.Millisecond) // Should show error - output := s.getOutput(tm) + output := s.getOutput(testModel) s.Contains(output, "Path/URL cannot be empty", "Should show validation error") // Quit - tm.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) - tm.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) + testModel.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) + testModel.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) } // TestActiveClientHighlighting tests that active client is highlighted. @@ -297,7 +299,7 @@ func (s *StoragePageTestSuite) TestActiveClientHighlighting() { model := NewPage(s.router, s.sharedMemory) - tm := teatest.NewTestModel( + testModel := teatest.NewTestModel( s.T(), model, teatest.WithInitialTermSize(300, 100), @@ -307,13 +309,13 @@ func (s *StoragePageTestSuite) TestActiveClientHighlighting() { time.Sleep(100 * time.Millisecond) // Should show active client marker - output := s.getOutput(tm) + output := s.getOutput(testModel) s.Contains(output, "★", "Should show star marker for active client") s.Contains(output, "/tmp/test.db", "Should show configured path") // Quit - tm.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) - tm.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) + testModel.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) + testModel.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) } // TestConfirmationDialog tests the confirmation dialog for existing config. @@ -326,7 +328,7 @@ func (s *StoragePageTestSuite) TestConfirmationDialog() { model := NewPage(s.router, s.sharedMemory) - tm := teatest.NewTestModel( + testModel := teatest.NewTestModel( s.T(), model, teatest.WithInitialTermSize(300, 100), @@ -336,11 +338,11 @@ func (s *StoragePageTestSuite) TestConfirmationDialog() { time.Sleep(100 * time.Millisecond) // Select SQLite (already configured) - tm.Send(tea.KeyMsg{Type: tea.KeyEnter}) + testModel.Send(tea.KeyMsg{Type: tea.KeyEnter}) time.Sleep(100 * time.Millisecond) // Should show confirmation dialog - output := s.getOutput(tm) + output := s.getOutput(testModel) s.Contains(output, "SQLite Configuration", "Should show configuration dialog") s.Contains(output, "Use existing configuration", "Should show use existing option") s.Contains(output, "Change configuration", "Should show change option") @@ -348,8 +350,8 @@ func (s *StoragePageTestSuite) TestConfirmationDialog() { s.Contains(output, "/tmp/existing.db", "Should show current path") // Quit - tm.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) - tm.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) + testModel.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) + testModel.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) } // TestPasswordMasking tests that Postgres password is properly masked. @@ -396,7 +398,8 @@ func (s *StoragePageTestSuite) TestBackNavigation() { } // Store password for initialization - s.sharedMemory.Set(config.SecureStoragePasswordKey, s.password) + err := s.sharedMemory.Set(config.SecureStoragePasswordKey, s.password) + s.NoError(err, "Should store password in shared memory") model := NewPage(mockRouter, s.sharedMemory) diff --git a/app/page.go b/app/page.go index 6a01cf0..ab314d9 100644 --- a/app/page.go +++ b/app/page.go @@ -178,7 +178,10 @@ func (m *Model) ensureSecureStorageInitialized() error { var err error m.secureStorage, err = storage.NewSecureStorageWithEncryption("smart-contract-cli-key", "") - return err + if err != nil { + return fmt.Errorf("failed to create secure storage: %w", err) + } + return nil } // createStorageIfNeeded creates storage if it doesn't exist. @@ -186,15 +189,21 @@ func (m Model) createStorageIfNeeded(password string) error { if m.secureStorage.Exists() { return nil } - return m.secureStorage.Create(password) + if err := m.secureStorage.Create(password); err != nil { + return fmt.Errorf("failed to create storage: %w", err) + } + return nil } // unlockAndStorePassword unlocks storage and stores password in shared memory. func (m Model) unlockAndStorePassword(password string) error { if err := m.secureStorage.Unlock(password); err != nil { - return err + return fmt.Errorf("failed to unlock storage: %w", err) + } + if err := m.sharedMemory.Set(config.SecureStoragePasswordKey, password); err != nil { + return fmt.Errorf("failed to store password in shared memory: %w", err) } - return m.sharedMemory.Set(config.SecureStoragePasswordKey, password) + return nil } func (m Model) moveUp(currentIndex int) int { diff --git a/app/page_test.go b/app/page_test.go index a86b865..ee3aa48 100644 --- a/app/page_test.go +++ b/app/page_test.go @@ -13,7 +13,7 @@ import ( "github.com/stretchr/testify/suite" ) -// PagePasswordTestSuite tests password unlock functionality using teatest +// PagePasswordTestSuite tests password unlock functionality using teatest. type PagePasswordTestSuite struct { suite.Suite testStoragePath string @@ -32,7 +32,8 @@ func (s *PagePasswordTestSuite) SetupTest() { s.testStoragePath = tmpDir // Override the storage path for tests - os.Setenv("HOME", tmpDir) + err = os.Setenv("HOME", tmpDir) + s.NoError(err, "Should set HOME environment variable") // Create shared memory and router for each test s.sharedMemory = storage.NewSharedMemory() @@ -42,7 +43,8 @@ func (s *PagePasswordTestSuite) SetupTest() { func (s *PagePasswordTestSuite) TearDownTest() { // Clean up test storage if s.testStoragePath != "" { - os.RemoveAll(s.testStoragePath) + err := os.RemoveAll(s.testStoragePath) + s.NoError(err, "Should clean up test storage directory") } } @@ -52,7 +54,7 @@ func (s *PagePasswordTestSuite) getOutput(tm *teatest.TestModel) string { return string(output) } -// TestInitialStateNewStorage tests that a new storage creation prompt is shown +// TestInitialStateNewStorage tests that a new storage creation prompt is shown. func (s *PagePasswordTestSuite) TestInitialStateNewStorage() { model := NewPage(s.router, s.sharedMemory) pageModel := model.(Model) @@ -61,7 +63,7 @@ func (s *PagePasswordTestSuite) TestInitialStateNewStorage() { s.False(pageModel.isUnlocked, "Should not be unlocked initially") s.True(pageModel.isCreatingNew, "Should be in creating new storage mode") - tm := teatest.NewTestModel( + testModel := teatest.NewTestModel( s.T(), model, teatest.WithInitialTermSize(300, 100), @@ -70,20 +72,20 @@ func (s *PagePasswordTestSuite) TestInitialStateNewStorage() { // Wait for initial render time.Sleep(100 * time.Millisecond) - output := s.getOutput(tm) + output := s.getOutput(testModel) s.Contains(output, "Create a password", "Should show create password prompt") s.Contains(output, "Password:", "Should show password input field") // Quit - tm.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) - tm.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) + testModel.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) + testModel.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) } -// TestSuccessfulPasswordCreation tests creating a new storage with a password +// TestSuccessfulPasswordCreation tests creating a new storage with a password. func (s *PagePasswordTestSuite) TestSuccessfulPasswordCreation() { model := NewPage(s.router, s.sharedMemory) - tm := teatest.NewTestModel( + testModel := teatest.NewTestModel( s.T(), model, teatest.WithInitialTermSize(300, 100), @@ -95,7 +97,7 @@ func (s *PagePasswordTestSuite) TestSuccessfulPasswordCreation() { // Type password password := "testpass123" for _, char := range password { - tm.Send(tea.KeyMsg{ + testModel.Send(tea.KeyMsg{ Type: tea.KeyRunes, Runes: []rune{char}, }) @@ -103,7 +105,7 @@ func (s *PagePasswordTestSuite) TestSuccessfulPasswordCreation() { } // Submit password with Enter - tm.Send(tea.KeyMsg{Type: tea.KeyEnter}) + testModel.Send(tea.KeyMsg{Type: tea.KeyEnter}) // Wait for unlock to process time.Sleep(200 * time.Millisecond) @@ -114,15 +116,15 @@ func (s *PagePasswordTestSuite) TestSuccessfulPasswordCreation() { s.Equal(password, storedPassword, "Password should match") // Verify main menu is shown - output := s.getOutput(tm) + output := s.getOutput(testModel) s.Contains(output, "Select a blockchain", "Should show main menu after unlock") // Quit - tm.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) - tm.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) + testModel.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) + testModel.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) } -// TestExistingStorageUnlock tests unlocking existing storage with correct password +// TestExistingStorageUnlock tests unlocking existing storage with correct password. func (s *PagePasswordTestSuite) TestExistingStorageUnlock() { // Pre-create storage with a known password password := "mypassword" @@ -139,7 +141,7 @@ func (s *PagePasswordTestSuite) TestExistingStorageUnlock() { s.False(pageModel.isUnlocked, "Should not be unlocked initially") s.False(pageModel.isCreatingNew, "Should not be in create mode for existing storage") - tm := teatest.NewTestModel( + testModel := teatest.NewTestModel( s.T(), model, teatest.WithInitialTermSize(300, 100), @@ -148,12 +150,12 @@ func (s *PagePasswordTestSuite) TestExistingStorageUnlock() { // Wait for initial render time.Sleep(100 * time.Millisecond) - output := s.getOutput(tm) + output := s.getOutput(testModel) s.Contains(output, "Enter password to unlock", "Should show unlock prompt") // Type the correct password for _, char := range password { - tm.Send(tea.KeyMsg{ + testModel.Send(tea.KeyMsg{ Type: tea.KeyRunes, Runes: []rune{char}, }) @@ -161,21 +163,21 @@ func (s *PagePasswordTestSuite) TestExistingStorageUnlock() { } // Submit password - tm.Send(tea.KeyMsg{Type: tea.KeyEnter}) + testModel.Send(tea.KeyMsg{Type: tea.KeyEnter}) // Wait for unlock time.Sleep(200 * time.Millisecond) // Verify unlocked - output = s.getOutput(tm) + output = s.getOutput(testModel) s.Contains(output, "Select a blockchain", "Should show main menu after successful unlock") // Quit - tm.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) - tm.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) + testModel.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) + testModel.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) } -// TestInvalidPasswordError tests that wrong password shows error +// TestInvalidPasswordError tests that wrong password shows error. func (s *PagePasswordTestSuite) TestInvalidPasswordError() { // Pre-create storage with a known password correctPassword := "correctpass" @@ -188,7 +190,7 @@ func (s *PagePasswordTestSuite) TestInvalidPasswordError() { // Create model model := NewPage(s.router, s.sharedMemory) - tm := teatest.NewTestModel( + testModel2 := teatest.NewTestModel( s.T(), model, teatest.WithInitialTermSize(300, 100), @@ -200,7 +202,7 @@ func (s *PagePasswordTestSuite) TestInvalidPasswordError() { // Type wrong password wrongPassword := "wrongpass" for _, char := range wrongPassword { - tm.Send(tea.KeyMsg{ + testModel2.Send(tea.KeyMsg{ Type: tea.KeyRunes, Runes: []rune{char}, }) @@ -208,13 +210,13 @@ func (s *PagePasswordTestSuite) TestInvalidPasswordError() { } // Submit wrong password - tm.Send(tea.KeyMsg{Type: tea.KeyEnter}) + testModel2.Send(tea.KeyMsg{Type: tea.KeyEnter}) // Wait for error processing time.Sleep(200 * time.Millisecond) // Verify error message shown - output := s.getOutput(tm) + output := s.getOutput(testModel2) s.Contains(output, "Failed to unlock", "Should show unlock failure message") s.NotContains(output, "Select a blockchain", "Should not show main menu") @@ -224,15 +226,15 @@ func (s *PagePasswordTestSuite) TestInvalidPasswordError() { s.Nil(storedPassword, "Should not have password in shared memory after failed unlock") // Quit - tm.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) - tm.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) + testModel2.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) + testModel2.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) } -// TestEmptyPasswordValidation tests that empty password is rejected +// TestEmptyPasswordValidation tests that empty password is rejected. func (s *PagePasswordTestSuite) TestEmptyPasswordValidation() { model := NewPage(s.router, s.sharedMemory) - tm := teatest.NewTestModel( + testModel := teatest.NewTestModel( s.T(), model, teatest.WithInitialTermSize(300, 100), @@ -242,21 +244,21 @@ func (s *PagePasswordTestSuite) TestEmptyPasswordValidation() { time.Sleep(100 * time.Millisecond) // Submit without typing anything - tm.Send(tea.KeyMsg{Type: tea.KeyEnter}) + testModel.Send(tea.KeyMsg{Type: tea.KeyEnter}) // Wait for validation time.Sleep(200 * time.Millisecond) // Verify error message - output := s.getOutput(tm) + output := s.getOutput(testModel) s.Contains(output, "Password cannot be empty", "Should show empty password error") // Quit - tm.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) - tm.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) + testModel.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) + testModel.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) } -// TestSharedMemoryIntegration tests that pre-existing password in shared memory skips unlock +// TestSharedMemoryIntegration tests that pre-existing password in shared memory skips unlock. func (s *PagePasswordTestSuite) TestSharedMemoryIntegration() { // Pre-create storage and set password in shared memory password := "presetpass" @@ -276,7 +278,7 @@ func (s *PagePasswordTestSuite) TestSharedMemoryIntegration() { s.True(pageModel.isUnlocked, "Should be unlocked immediately with password in shared memory") - tm := teatest.NewTestModel( + testModel := teatest.NewTestModel( s.T(), model, teatest.WithInitialTermSize(300, 100), @@ -286,20 +288,20 @@ func (s *PagePasswordTestSuite) TestSharedMemoryIntegration() { time.Sleep(100 * time.Millisecond) // Verify main menu shown immediately - output := s.getOutput(tm) + output := s.getOutput(testModel) s.Contains(output, "Select a blockchain", "Should show main menu immediately") s.NotContains(output, "Enter password", "Should not show password prompt") // Quit - tm.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) - tm.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) + testModel.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) + testModel.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) } -// TestQuitDuringPasswordEntry tests that Ctrl+C works during password entry +// TestQuitDuringPasswordEntry tests that Ctrl+C works during password entry. func (s *PagePasswordTestSuite) TestQuitDuringPasswordEntry() { model := NewPage(s.router, s.sharedMemory) - tm := teatest.NewTestModel( + testModel3 := teatest.NewTestModel( s.T(), model, teatest.WithInitialTermSize(300, 100), @@ -309,14 +311,14 @@ func (s *PagePasswordTestSuite) TestQuitDuringPasswordEntry() { time.Sleep(100 * time.Millisecond) // Type some password - tm.Send(tea.KeyMsg{ + testModel3.Send(tea.KeyMsg{ Type: tea.KeyRunes, Runes: []rune{'t', 'e', 's', 't'}, }) // Quit with Ctrl+C - tm.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) - tm.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) + testModel3.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) + testModel3.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) // If we get here without hanging, the test passes s.True(true, "Should quit cleanly during password entry") diff --git a/internal/config/shared_memory_keys.go b/internal/config/shared_memory_keys.go index 5ece94b..9b56833 100644 --- a/internal/config/shared_memory_keys.go +++ b/internal/config/shared_memory_keys.go @@ -5,4 +5,5 @@ const ( StorageKeyTypeKey = "storage_client_type" StorageKeySqlitePathKey = "storage_client_sqlite_path" StorageKeyPostgresURLKey = "storage_client_postgres_url" + StorageClientKey = "storage_client" ) diff --git a/internal/config/storage.go b/internal/config/storage.go new file mode 100644 index 0000000..2c57f05 --- /dev/null +++ b/internal/config/storage.go @@ -0,0 +1,6 @@ +package config + +const ( + StorageClientTypeSQLite = "sqlite" + StorageClientTypePostgres = "postgres" +) diff --git a/internal/contract/evm/abi/abi.go b/internal/contract/evm/abi/abi.go index 4021dff..b532450 100644 --- a/internal/contract/evm/abi/abi.go +++ b/internal/contract/evm/abi/abi.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "os" + "path/filepath" "strings" "github.com/rxtech-lab/smart-contract-cli/internal/errors" @@ -69,7 +70,12 @@ func downloadAbi(url string) (AbiArray, error) { if err != nil { return nil, errors.WrapABIError(err, errors.ErrCodeABIParseFailed, fmt.Sprintf("failed to download ABI from URL: %s", url)) } - defer resp.Body.Close() + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil { + // Log error but don't fail the operation + _ = closeErr + } + }() // Check status code if resp.StatusCode != http.StatusOK { @@ -87,11 +93,17 @@ func downloadAbi(url string) (AbiArray, error) { } // readAbiFromFile reads an ABI from a local file and parses it. -func readAbiFromFile(filepath string) (AbiArray, error) { +func readAbiFromFile(filePath string) (AbiArray, error) { + // Validate filePath to prevent directory traversal + cleaned := filepath.Clean(filePath) + if strings.Contains(cleaned, "..") { + return nil, errors.NewABIError(errors.ErrCodeABIParseFailed, fmt.Sprintf("invalid file path: %s", filePath)) + } + // Read file - data, err := os.ReadFile(filepath) + data, err := os.ReadFile(filePath) if err != nil { - return nil, errors.WrapABIError(err, errors.ErrCodeABIParseFailed, fmt.Sprintf("failed to read ABI file: %s", filepath)) + return nil, errors.WrapABIError(err, errors.ErrCodeABIParseFailed, fmt.Sprintf("failed to read ABI file: %s", filePath)) } // Parse ABI diff --git a/internal/contract/evm/abi/abi_test.go b/internal/contract/evm/abi/abi_test.go index efe9462..af1141e 100644 --- a/internal/contract/evm/abi/abi_test.go +++ b/internal/contract/evm/abi/abi_test.go @@ -199,20 +199,20 @@ func TestParseAbi(t *testing.T) { }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := ParseAbi(tt.input) + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + result, err := ParseAbi(testCase.input) - if tt.wantErr { + if testCase.wantErr { assert.Error(t, err) return } require.NoError(t, err) - assert.Len(t, result, tt.expectedLen) + assert.Len(t, result, testCase.expectedLen) - if tt.validate != nil { - tt.validate(t, result) + if testCase.validate != nil { + testCase.validate(t, result) } }) } @@ -293,7 +293,7 @@ func TestReadAbi(t *testing.T) { "stateMutability": "nonpayable" } ]` - err := os.WriteFile(filePath, []byte(abiContent), 0644) + err := os.WriteFile(filePath, []byte(abiContent), 0600) require.NoError(t, err) return filePath }, @@ -323,7 +323,7 @@ func TestReadAbi(t *testing.T) { ], "bytecode": "0x608060405234801561001057600080fd5b50" }` - err := os.WriteFile(filePath, []byte(abiContent), 0644) + err := os.WriteFile(filePath, []byte(abiContent), 0600) require.NoError(t, err) return filePath }, @@ -408,25 +408,25 @@ func TestReadAbi(t *testing.T) { }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - path := tt.setup(t) - defer tt.cleanup(t, path) + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + path := testCase.setup(t) + defer testCase.cleanup(t, path) result, err := ReadAbi(path) - if tt.wantErr { + if testCase.wantErr { assert.Error(t, err) return } require.NoError(t, err) - if tt.expectedLen > 0 { - assert.Len(t, result, tt.expectedLen) + if testCase.expectedLen > 0 { + assert.Len(t, result, testCase.expectedLen) } - if tt.validate != nil { - tt.validate(t, result) + if testCase.validate != nil { + testCase.validate(t, result) } }) } diff --git a/internal/contract/evm/contract/signer/privatekey.go b/internal/contract/evm/contract/signer/privatekey.go index 45c2dff..4fb7109 100644 --- a/internal/contract/evm/contract/signer/privatekey.go +++ b/internal/contract/evm/contract/signer/privatekey.go @@ -47,9 +47,9 @@ func (p *PrivateKeySigner) SignMessageString(message string) (signature string, } // SignTransaction implements Signer. -func (p *PrivateKeySigner) SignTransaction(tx *types.Transaction) (signedTx *types.Transaction, err error) { +func (p *PrivateKeySigner) SignTransaction(transaction *types.Transaction) (signedTx *types.Transaction, err error) { // Get the chain ID from the transaction - chainID := tx.ChainId() + chainID := transaction.ChainId() if chainID == nil { return nil, errors.NewSignerError(errors.ErrCodeInvalidChainID, "transaction has no chain ID") } @@ -58,7 +58,7 @@ func (p *PrivateKeySigner) SignTransaction(tx *types.Transaction) (signedTx *typ signer := types.NewLondonSigner(chainID) // Sign the transaction - signedTx, err = types.SignTx(tx, signer, p.PrivateKey) + signedTx, err = types.SignTx(transaction, signer, p.PrivateKey) if err != nil { return nil, errors.WrapSignerError(err, errors.ErrCodeTransactionSignFailed, "failed to sign transaction") } diff --git a/internal/contract/evm/contract/signer/privatekey_test.go b/internal/contract/evm/contract/signer/privatekey_test.go index e1395f7..2f2c3cd 100644 --- a/internal/contract/evm/contract/signer/privatekey_test.go +++ b/internal/contract/evm/contract/signer/privatekey_test.go @@ -77,14 +77,14 @@ func (suite *PrivateKeySignerTestSuite) TestNewPrivateKeySigner() { }, } - for _, tt := range tests { - suite.Run(tt.name, func() { - signer, err := NewPrivateKeySigner(tt.privateKey) + for _, testCase := range tests { + suite.Run(testCase.name, func() { + signer, err := NewPrivateKeySigner(testCase.privateKey) - if tt.wantErr { + if testCase.wantErr { suite.Error(err, "expected error but got none") - if tt.errContains != "" && err != nil { - suite.Contains(err.Error(), tt.errContains, "error should contain expected text") + if testCase.errContains != "" && err != nil { + suite.Contains(err.Error(), testCase.errContains, "error should contain expected text") } suite.Nil(signer, "expected nil signer on error") return @@ -135,11 +135,11 @@ func (suite *PrivateKeySignerTestSuite) TestSignMessageString() { }, } - for _, tt := range tests { - suite.Run(tt.name, func() { - signature, err := suite.signer.SignMessageString(tt.message) + for _, testCase := range tests { + suite.Run(testCase.name, func() { + signature, err := suite.signer.SignMessageString(testCase.message) - if tt.wantErr { + if testCase.wantErr { suite.Error(err, "expected error but got none") return } @@ -150,13 +150,13 @@ func (suite *PrivateKeySignerTestSuite) TestSignMessageString() { suite.Equal("0x", signature[:2], "signature should start with 0x") // Signature should be 65 bytes (130 hex chars) + 2 for 0x prefix = 132 total suite.Equal(132, len(signature), "signature should be 132 characters (0x + 130 hex)") - suite.T().Logf("Message: %q", tt.message) + suite.T().Logf("Message: %q", testCase.message) suite.T().Logf("Signature: %s", signature) }) } } -// TestVerifyMessageString tests message verification with invalid cases +// TestVerifyMessageString tests message verification with invalid cases. func (suite *PrivateKeySignerTestSuite) TestVerifyMessageString() { message := "Test message for verification" @@ -224,19 +224,19 @@ func (suite *PrivateKeySignerTestSuite) TestVerifyMessageString() { }, } - for _, tt := range tests { - suite.Run(tt.name, func() { - isValid, recoveredAddr, err := suite.signer.VerifyMessageString(tt.address, tt.message, tt.signature) + for _, testCase := range tests { + suite.Run(testCase.name, func() { + isValid, recoveredAddr, err := suite.signer.VerifyMessageString(testCase.address, testCase.message, testCase.signature) - if tt.wantErr { + if testCase.wantErr { suite.Error(err, "expected error but got none") return } suite.NoError(err, "VerifyMessageString should not return error") - suite.Equal(tt.wantValid, isValid, "signature validity should match expected") + suite.Equal(testCase.wantValid, isValid, "signature validity should match expected") - if tt.checkRecoveredAddr && tt.wantValid { + if testCase.checkRecoveredAddr && testCase.wantValid { suite.Equal(suite.testAddress, recoveredAddr, "recovered address should match signer address") suite.T().Logf("Recovered address: %s", recoveredAddr.Hex()) } @@ -244,7 +244,7 @@ func (suite *PrivateKeySignerTestSuite) TestVerifyMessageString() { } } -// TestSignAndVerifyMessageRoundtrip tests the full sign and verify flow +// TestSignAndVerifyMessageRoundtrip tests the full sign and verify flow. func (suite *PrivateKeySignerTestSuite) TestSignAndVerifyMessageRoundtrip() { tests := []struct { name string @@ -264,15 +264,15 @@ func (suite *PrivateKeySignerTestSuite) TestSignAndVerifyMessageRoundtrip() { }, } - for _, tt := range tests { - suite.Run(tt.name, func() { + for _, testCase := range tests { + suite.Run(testCase.name, func() { // Step 1: Sign the message - signature, err := suite.signer.SignMessageString(tt.message) + signature, err := suite.signer.SignMessageString(testCase.message) suite.Require().NoError(err, "failed to sign message") - suite.T().Logf("Signed message %q with signature: %s", tt.message, signature) + suite.T().Logf("Signed message %q with signature: %s", testCase.message, signature) // Step 2: Verify the signature with correct address - isValid, recoveredAddr, err := suite.signer.VerifyMessageString(suite.testAddress, tt.message, signature) + isValid, recoveredAddr, err := suite.signer.VerifyMessageString(suite.testAddress, testCase.message, signature) suite.Require().NoError(err, "failed to verify message") suite.True(isValid, "signature should be valid for correct address") suite.Equal(suite.testAddress, recoveredAddr, "recovered address should match signer address") @@ -280,14 +280,14 @@ func (suite *PrivateKeySignerTestSuite) TestSignAndVerifyMessageRoundtrip() { // Step 3: Verify the signature with wrong address (should fail) wrongAddress := common.HexToAddress("0x0000000000000000000000000000000000000001") - isValid, _, err = suite.signer.VerifyMessageString(wrongAddress, tt.message, signature) + isValid, _, err = suite.signer.VerifyMessageString(wrongAddress, testCase.message, signature) suite.NoError(err, "verify should not error even with wrong address") suite.False(isValid, "signature should be invalid for wrong address") }) } } -// TestVerifyMessageWithMetaMaskFormat tests v=27/28 signature format +// TestVerifyMessageWithMetaMaskFormat tests v=27/28 signature format. func (suite *PrivateKeySignerTestSuite) TestVerifyMessageWithMetaMaskFormat() { message := "Test MetaMask format" @@ -322,43 +322,43 @@ func (suite *PrivateKeySignerTestSuite) TestVerifyMessageWithMetaMaskFormat() { func (suite *PrivateKeySignerTestSuite) TestSignTransaction() { // Create a simple transaction nonce := uint64(0) - to := common.HexToAddress("0x1234567890123456789012345678901234567890") + toAddress := common.HexToAddress("0x1234567890123456789012345678901234567890") amount := big.NewInt(1000000000000000000) // 1 ETH gasLimit := uint64(21000) gasFeeCap := big.NewInt(30000000000) // 30 gwei gasTipCap := big.NewInt(2000000000) // 2 gwei chainID := big.NewInt(testChainID) - tx := types.NewTx(&types.DynamicFeeTx{ + transaction := types.NewTx(&types.DynamicFeeTx{ ChainID: chainID, Nonce: nonce, GasTipCap: gasTipCap, GasFeeCap: gasFeeCap, Gas: gasLimit, - To: &to, + To: &toAddress, Value: amount, Data: nil, }) // Sign the transaction - signedTx, err := suite.signer.SignTransaction(tx) + signedTx, err := suite.signer.SignTransaction(transaction) suite.NoError(err, "SignTransaction should not return error") suite.NotNil(signedTx, "signed transaction should not be nil") // Verify the transaction has a signature - v, r, s := signedTx.RawSignatureValues() - suite.NotNil(v, "v value should not be nil") - suite.NotNil(r, "r value should not be nil") - suite.NotNil(s, "s value should not be nil") - suite.True(r.Sign() > 0, "r should be positive") - suite.True(s.Sign() > 0, "s should be positive") + vValue, rValue, sValue := signedTx.RawSignatureValues() + suite.NotNil(vValue, "v value should not be nil") + suite.NotNil(rValue, "r value should not be nil") + suite.NotNil(sValue, "s value should not be nil") + suite.True(rValue.Sign() > 0, "r should be positive") + suite.True(sValue.Sign() > 0, "s should be positive") suite.T().Logf("Transaction signed successfully") suite.T().Logf("Transaction hash: %s", signedTx.Hash().Hex()) - suite.T().Logf("V: %s, R: %s, S: %s", v.String(), r.String(), s.String()) + suite.T().Logf("V: %s, R: %s, S: %s", vValue.String(), rValue.String(), sValue.String()) } -// TestSignTransactionAndSendToE2E is an E2E integration test +// TestSignTransactionAndSendToE2E is an E2E integration test. func (suite *PrivateKeySignerTestSuite) TestSignTransactionAndSendToE2E() { // Create transport transport, err := transport.NewHTTPTransport(testEndpoint, 30*time.Second) @@ -379,28 +379,28 @@ func (suite *PrivateKeySignerTestSuite) TestSignTransactionAndSendToE2E() { // Create a transaction to send ETH to another address // Use Anvil's second default account as recipient - to := common.HexToAddress("0x70997970C51812dc3A010C7d01b50e0d17dc79C8") + toAddress := common.HexToAddress("0x70997970C51812dc3A010C7d01b50e0d17dc79C8") amount := big.NewInt(1000000000000000000) // 1 ETH gasLimit := uint64(21000) gasFeeCap := big.NewInt(30000000000) // 30 gwei gasTipCap := big.NewInt(2000000000) // 2 gwei chainID := big.NewInt(testChainID) - tx := types.NewTx(&types.DynamicFeeTx{ + transaction := types.NewTx(&types.DynamicFeeTx{ ChainID: chainID, Nonce: nonce, GasTipCap: gasTipCap, GasFeeCap: gasFeeCap, Gas: gasLimit, - To: &to, + To: &toAddress, Value: amount, Data: nil, }) - suite.T().Logf("Created transaction to send %s wei to %s", amount.String(), to.Hex()) + suite.T().Logf("Created transaction to send %s wei to %s", amount.String(), toAddress.Hex()) // Sign the transaction - signedTx, err := suite.signer.SignTransaction(tx) + signedTx, err := suite.signer.SignTransaction(transaction) suite.Require().NoError(err, "failed to sign transaction") suite.T().Logf("Transaction signed, hash: %s", signedTx.Hash().Hex()) @@ -427,7 +427,15 @@ func (suite *PrivateKeySignerTestSuite) TestSignTransactionAndSendToE2E() { suite.T().Logf("Balance after: %s wei", balanceAfter.String()) // Balance should decrease by amount + gas costs - expectedMaxDecrease := new(big.Int).Add(amount, new(big.Int).Mul(big.NewInt(int64(receipt.GasUsed)), gasFeeCap)) + // Check for potential overflow before converting uint64 to int64 + var gasUsedInt64 int64 + const maxInt64 = 1<<63 - 1 + if receipt.GasUsed > uint64(maxInt64) { + gasUsedInt64 = maxInt64 // Max int64 value + } else { + gasUsedInt64 = int64(receipt.GasUsed) + } + expectedMaxDecrease := new(big.Int).Add(amount, new(big.Int).Mul(big.NewInt(gasUsedInt64), gasFeeCap)) actualDecrease := new(big.Int).Sub(balanceBefore, balanceAfter) suite.T().Logf("Actual balance decrease: %s wei", actualDecrease.String()) suite.T().Logf("Expected max decrease: %s wei", expectedMaxDecrease.String()) @@ -438,7 +446,7 @@ func (suite *PrivateKeySignerTestSuite) TestSignTransactionAndSendToE2E() { suite.LessOrEqual(actualDecrease.Cmp(expectedMaxDecrease), 0, "balance should not decrease more than amount + gas") } -// TestPrivateKeySignerTestSuite runs the test suite +// TestPrivateKeySignerTestSuite runs the test suite. func TestPrivateKeySignerTestSuite(t *testing.T) { suite.Run(t, new(PrivateKeySignerTestSuite)) } diff --git a/internal/contract/evm/contract/signer/privatekey_with_transport.go b/internal/contract/evm/contract/signer/privatekey_with_transport.go index e6482a6..c2fc182 100644 --- a/internal/contract/evm/contract/signer/privatekey_with_transport.go +++ b/internal/contract/evm/contract/signer/privatekey_with_transport.go @@ -58,7 +58,7 @@ func (p *PrivateKeySignerWithTransport) executeReadOnlyCall(contractAddress comm // Call the contract using transport rawResult, err := p.transport.CallContract(contractAddress, contractABI, methodName, args...) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to call contract method %s: %w", methodName, err) } // If no outputs, return nil @@ -117,7 +117,7 @@ func (p *PrivateKeySignerWithTransport) buildTransaction(contractAddress common. // Get chain ID from transport chainID, err := p.transport.GetChainID() if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get chain ID: %w", err) } // Use EIP-1559 transaction for better compatibility @@ -146,14 +146,14 @@ func (p *PrivateKeySignerWithTransport) buildTransaction(contractAddress common. estimatedGas, err := p.transport.EstimateGas(signedTempTx) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to estimate gas: %w", err) } // Add 50% buffer to gas estimate to avoid out-of-gas errors // Gas estimation can be inaccurate, especially for complex contracts gasLimit = estimatedGas + (estimatedGas / 2) } - tx := types.NewTx(&types.DynamicFeeTx{ + transaction := types.NewTx(&types.DynamicFeeTx{ ChainID: chainID, Nonce: nonce, GasTipCap: gasTipCap, @@ -163,7 +163,7 @@ func (p *PrivateKeySignerWithTransport) buildTransaction(contractAddress common. Value: value, Data: data, }) - return tx, nil + return transaction, nil } // executeWriteTransaction signs and sends a transaction, then waits for receipt. @@ -177,13 +177,13 @@ func (p *PrivateKeySignerWithTransport) executeWriteTransaction(tx *types.Transa // Send the transaction txHash, err := p.transport.SendTransaction(signedTx) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to send transaction: %w", err) } // Wait for transaction receipt receipt, err := p.transport.WaitForTransactionReceipt(txHash) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to wait for transaction receipt: %w", err) } // Return status and transaction hash @@ -213,25 +213,29 @@ func (p *PrivateKeySignerWithTransport) CallContractMethod(contractAddress commo signerAddress := p.PrivateKeySigner.GetAddress() nonce, err := p.transport.GetTransactionCount(signerAddress) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get transaction count: %w", err) } // Set default parameters setDefaultTransactionParams(&value, &gasPrice) // Build transaction with gas estimation - tx, err := p.buildTransaction(contractAddress, nonce, value, gasLimit, gasPrice, data) + transaction, err := p.buildTransaction(contractAddress, nonce, value, gasLimit, gasPrice, data) if err != nil { return nil, err } // Execute the transaction - return p.executeWriteTransaction(tx) + return p.executeWriteTransaction(transaction) } // EstimateGas implements SignerWithTransport. func (p *PrivateKeySignerWithTransport) EstimateGas(tx *types.Transaction) (gas uint64, err error) { - return p.transport.EstimateGas(tx) + gas, err = p.transport.EstimateGas(tx) + if err != nil { + return 0, fmt.Errorf("failed to estimate gas: %w", err) + } + return gas, nil } // GetAddress implements SignerWithTransport. @@ -241,12 +245,20 @@ func (p *PrivateKeySignerWithTransport) GetAddress() (address common.Address, er // GetBalance implements SignerWithTransport. func (p *PrivateKeySignerWithTransport) GetBalance(address common.Address) (balance *big.Int, err error) { - return p.transport.GetBalance(address) + balance, err = p.transport.GetBalance(address) + if err != nil { + return nil, fmt.Errorf("failed to get balance: %w", err) + } + return balance, nil } // GetTransactionCount implements SignerWithTransport. func (p *PrivateKeySignerWithTransport) GetTransactionCount(address common.Address) (nonce uint64, err error) { - return p.transport.GetTransactionCount(address) + nonce, err = p.transport.GetTransactionCount(address) + if err != nil { + return 0, fmt.Errorf("failed to get transaction count: %w", err) + } + return nonce, nil } // SendTransaction implements SignerWithTransport. @@ -258,7 +270,11 @@ func (p *PrivateKeySignerWithTransport) SendTransaction(tx *types.Transaction) ( } // Send the signed transaction - return p.transport.SendTransaction(signedTx) + txHash, err = p.transport.SendTransaction(signedTx) + if err != nil { + return common.Hash{}, fmt.Errorf("failed to send transaction: %w", err) + } + return txHash, nil } // SignMessageString implements SignerWithTransport. @@ -278,5 +294,9 @@ func (p *PrivateKeySignerWithTransport) VerifyMessageString(address common.Addre // WaitForTransactionReceipt implements SignerWithTransport. func (p *PrivateKeySignerWithTransport) WaitForTransactionReceipt(txHash common.Hash) (receipt *types.Receipt, err error) { - return p.transport.WaitForTransactionReceipt(txHash) + receipt, err = p.transport.WaitForTransactionReceipt(txHash) + if err != nil { + return nil, fmt.Errorf("failed to wait for transaction receipt: %w", err) + } + return receipt, nil } diff --git a/internal/contract/evm/contract/signer/privatekey_with_transport_test.go b/internal/contract/evm/contract/signer/privatekey_with_transport_test.go index 5f1234c..dcc24f1 100644 --- a/internal/contract/evm/contract/signer/privatekey_with_transport_test.go +++ b/internal/contract/evm/contract/signer/privatekey_with_transport_test.go @@ -59,7 +59,7 @@ contract TestContract { } ` -// PrivateKeySignerWithTransportTestSuite is the test suite +// PrivateKeySignerWithTransportTestSuite is the test suite. type PrivateKeySignerWithTransportTestSuite struct { suite.Suite signer SignerWithTransport @@ -71,7 +71,7 @@ type PrivateKeySignerWithTransportTestSuite struct { chainID *big.Int } -// SetupSuite runs once before all tests +// SetupSuite runs once before all tests. func (suite *PrivateKeySignerWithTransportTestSuite) SetupSuite() { // Anvil test account private key suite.testPrivateKey = "ac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80" @@ -91,15 +91,20 @@ func (suite *PrivateKeySignerWithTransportTestSuite) SetupSuite() { baseSigner, err := NewPrivateKeySigner(suite.testPrivateKey) suite.Require().NoError(err, "Failed to create signer") - pkSigner, ok := baseSigner.(*PrivateKeySigner) - suite.Require().True(ok, "Failed to cast to PrivateKeySigner") + pkSigner, isValid := baseSigner.(*PrivateKeySigner) + suite.Require().True(isValid, "Failed to cast to PrivateKeySigner") suite.signer = pkSigner.WithTransport(suite.transport) // Compile contract using solc-go compiler, err := solc.NewWithVersion("0.8.20") suite.Require().NoError(err, "Failed to create compiler") - defer compiler.Close() + defer func() { + if closeErr := compiler.Close(); closeErr != nil { + // Log error but don't fail the test + _ = closeErr + } + }() // Create input for compilation input := &solc.Input{ @@ -160,7 +165,7 @@ func (suite *PrivateKeySignerWithTransportTestSuite) SetupSuite() { suite.deployContract(bytecode) } -// deployContract deploys the test contract +// deployContract deploys the test contract. func (suite *PrivateKeySignerWithTransportTestSuite) deployContract(bytecode string) { // Get nonce nonce, err := suite.transport.GetTransactionCount(suite.testAddress) @@ -170,7 +175,7 @@ func (suite *PrivateKeySignerWithTransportTestSuite) deployContract(bytecode str deployData := common.FromHex(bytecode) // Use EIP-1559 transaction - tx := types.NewTx(&types.DynamicFeeTx{ + transaction := types.NewTx(&types.DynamicFeeTx{ ChainID: suite.chainID, Nonce: nonce, GasTipCap: big.NewInt(1000000000), // 1 gwei @@ -182,7 +187,7 @@ func (suite *PrivateKeySignerWithTransportTestSuite) deployContract(bytecode str }) // Send transaction - txHash, err := suite.signer.SendTransaction(tx) + txHash, err := suite.signer.SendTransaction(transaction) suite.Require().NoError(err, "Failed to send deployment transaction") suite.T().Logf("Deployment transaction sent: %s", txHash.Hex()) @@ -195,14 +200,14 @@ func (suite *PrivateKeySignerWithTransportTestSuite) deployContract(bytecode str suite.Require().NotEqual(common.Address{}, suite.contractAddress, "Contract address is empty") } -// TestGetAddress tests the GetAddress method +// TestGetAddress tests the GetAddress method. func (suite *PrivateKeySignerWithTransportTestSuite) TestGetAddress() { address, err := suite.signer.GetAddress() suite.Require().NoError(err) suite.Assert().Equal(suite.testAddress, address) } -// TestGetBalance tests the GetBalance method +// TestGetBalance tests the GetBalance method. func (suite *PrivateKeySignerWithTransportTestSuite) TestGetBalance() { balance, err := suite.signer.GetBalance(suite.testAddress) suite.Require().NoError(err) @@ -211,14 +216,14 @@ func (suite *PrivateKeySignerWithTransportTestSuite) TestGetBalance() { suite.Assert().True(balance.Cmp(big.NewInt(0)) > 0) } -// TestGetTransactionCount tests nonce retrieval +// TestGetTransactionCount tests nonce retrieval. func (suite *PrivateKeySignerWithTransportTestSuite) TestGetTransactionCount() { nonce, err := suite.signer.GetTransactionCount(suite.testAddress) suite.Require().NoError(err) suite.Assert().True(nonce > 0, "Nonce should be > 0 after deployment") } -// TestCallContractMethod_PureFunction tests calling a pure function +// TestCallContractMethod_PureFunction tests calling a pure function. func (suite *PrivateKeySignerWithTransportTestSuite) TestCallContractMethod_PureFunction() { // Call add(10, 20) result, err := suite.signer.CallContractMethod( @@ -247,7 +252,7 @@ func (suite *PrivateKeySignerWithTransportTestSuite) TestCallContractMethod_Pure suite.Assert().True(method.IsReadOnly()) } -// TestCallContractMethod_ViewFunction tests calling a view function +// TestCallContractMethod_ViewFunction tests calling a view function. func (suite *PrivateKeySignerWithTransportTestSuite) TestCallContractMethod_ViewFunction() { // First set a value so we have something to read _, err := suite.signer.CallContractMethod( @@ -275,8 +280,8 @@ func (suite *PrivateKeySignerWithTransportTestSuite) TestCallContractMethod_View suite.Require().Len(result, 2, "Expected 2 return values") // Check value - valueBigInt, ok := result[0].(*big.Int) - suite.Require().True(ok, "First result should be *big.Int") + valueBigInt, isOk := result[0].(*big.Int) + suite.Require().True(isOk, "First result should be *big.Int") suite.Assert().Equal(int64(42), valueBigInt.Int64()) // Check address @@ -291,7 +296,7 @@ func (suite *PrivateKeySignerWithTransportTestSuite) TestCallContractMethod_View suite.Assert().True(method.IsReadOnly()) } -// TestCallContractMethod_NonPayableWrite tests a non-payable write function +// TestCallContractMethod_NonPayableWrite tests a non-payable write function. func (suite *PrivateKeySignerWithTransportTestSuite) TestCallContractMethod_NonPayableWrite() { // Call setValue(123) result, err := suite.signer.CallContractMethod( @@ -308,12 +313,12 @@ func (suite *PrivateKeySignerWithTransportTestSuite) TestCallContractMethod_NonP suite.Require().Len(result, 2, "Expected 2 return values (status, txHash)") // Check transaction succeeded - status, ok := result[0].(uint64) - suite.Require().True(ok, "First result should be uint64 status") + status, statusOk := result[0].(uint64) + suite.Require().True(statusOk, "First result should be uint64 status") // Check tx hash - txHash, ok := result[1].(string) - suite.Require().True(ok, "Second result should be string tx hash") + txHash, hashOk := result[1].(string) + suite.Require().True(hashOk, "Second result should be string tx hash") suite.Assert().NotEmpty(txHash) suite.T().Logf("setValue transaction hash: %s, status: %d", txHash, status) @@ -332,8 +337,8 @@ func (suite *PrivateKeySignerWithTransportTestSuite) TestCallContractMethod_NonP suite.Require().NoError(err) suite.Require().Len(readResult, 1) - value, ok := readResult[0].(*big.Int) - suite.Require().True(ok) + value, isOk := readResult[0].(*big.Int) + suite.Require().True(isOk) suite.Assert().Equal(int64(123), value.Int64(), "Value should be updated to 123") // Verify method is nonpayable using enum @@ -344,7 +349,7 @@ func (suite *PrivateKeySignerWithTransportTestSuite) TestCallContractMethod_NonP suite.Assert().False(method.IsPayable()) } -// TestCallContractMethod_PayableFunction tests a payable function +// TestCallContractMethod_PayableFunction tests a payable function. func (suite *PrivateKeySignerWithTransportTestSuite) TestCallContractMethod_PayableFunction() { // Get contract balance before balanceBefore, err := suite.transport.GetBalance(suite.contractAddress) @@ -366,8 +371,8 @@ func (suite *PrivateKeySignerWithTransportTestSuite) TestCallContractMethod_Paya suite.Require().Len(result, 2) // Check status - status, ok := result[0].(uint64) - suite.Require().True(ok) + status, statusOk := result[0].(uint64) + suite.Require().True(statusOk) suite.Assert().Equal(uint64(1), status) // Verify contract balance increased @@ -386,7 +391,7 @@ func (suite *PrivateKeySignerWithTransportTestSuite) TestCallContractMethod_Paya suite.Assert().True(method.IsWriteOperation()) } -// TestEstimateGas tests gas estimation +// TestEstimateGas tests gas estimation. func (suite *PrivateKeySignerWithTransportTestSuite) TestEstimateGas() { // Create a transaction with actual contract call data nonce, err := suite.transport.GetTransactionCount(suite.testAddress) @@ -394,7 +399,7 @@ func (suite *PrivateKeySignerWithTransportTestSuite) TestEstimateGas() { // Simple ETH transfer (not a contract call) for gas estimation recipient := common.HexToAddress("0x70997970C51812dc3A010C7d01b50e0d17dc79C8") - tx := types.NewTx(&types.DynamicFeeTx{ + transaction2 := types.NewTx(&types.DynamicFeeTx{ ChainID: suite.chainID, Nonce: nonce, GasTipCap: big.NewInt(1000000000), // 1 gwei @@ -406,13 +411,13 @@ func (suite *PrivateKeySignerWithTransportTestSuite) TestEstimateGas() { }) // Estimate gas - gas, err := suite.signer.EstimateGas(tx) + gas, err := suite.signer.EstimateGas(transaction2) suite.Require().NoError(err) suite.Assert().True(gas > 0, "Gas estimate should be positive") suite.Assert().True(gas >= 21000, "Gas estimate should be at least 21000 for a simple transfer") } -// TestSignAndVerifyMessage tests message signing and verification +// TestSignAndVerifyMessage tests message signing and verification. func (suite *PrivateKeySignerWithTransportTestSuite) TestSignAndVerifyMessage() { message := "Hello, Ethereum!" @@ -434,7 +439,7 @@ func (suite *PrivateKeySignerWithTransportTestSuite) TestSignAndVerifyMessage() suite.Assert().False(isValid, "Signature should be invalid for wrong address") } -// TestSendTransaction_Manual tests manual transaction sending +// TestSendTransaction_Manual tests manual transaction sending. func (suite *PrivateKeySignerWithTransportTestSuite) TestSendTransaction_Manual() { // Get nonce nonce, err := suite.transport.GetTransactionCount(suite.testAddress) @@ -445,7 +450,7 @@ func (suite *PrivateKeySignerWithTransportTestSuite) TestSendTransaction_Manual( value := big.NewInt(1000000000000000) // 0.001 ETH // Use EIP-1559 transaction - tx := types.NewTx(&types.DynamicFeeTx{ + transaction3 := types.NewTx(&types.DynamicFeeTx{ ChainID: suite.chainID, Nonce: nonce, GasTipCap: big.NewInt(1000000000), // 1 gwei @@ -457,7 +462,7 @@ func (suite *PrivateKeySignerWithTransportTestSuite) TestSendTransaction_Manual( }) // Send transaction - txHash, err := suite.signer.SendTransaction(tx) + txHash, err := suite.signer.SendTransaction(transaction3) suite.Require().NoError(err) suite.Assert().NotEqual(common.Hash{}, txHash) @@ -467,7 +472,7 @@ func (suite *PrivateKeySignerWithTransportTestSuite) TestSendTransaction_Manual( suite.Assert().Equal(uint64(1), receipt.Status) } -// TestStateMutabilityHelpers tests enum helper methods +// TestStateMutabilityHelpers tests enum helper methods. func (suite *PrivateKeySignerWithTransportTestSuite) TestStateMutabilityHelpers() { // Test pure function addMethod := findMethodInABI(suite.contractABI, "add") @@ -501,7 +506,7 @@ func (suite *PrivateKeySignerWithTransportTestSuite) TestStateMutabilityHelpers( suite.Assert().True(depositMethod.IsWritable()) } -// TestErrorHandling tests various error cases +// TestErrorHandling tests various error cases. func (suite *PrivateKeySignerWithTransportTestSuite) TestErrorHandling() { // Test non-existent method _, err := suite.signer.CallContractMethod( @@ -527,7 +532,7 @@ func (suite *PrivateKeySignerWithTransportTestSuite) TestErrorHandling() { suite.Assert().Error(err, "Should error for invalid contract address") } -// TestRunSuite runs the test suite +// TestRunSuite runs the test suite. func TestRunSuite(t *testing.T) { suite.Run(t, new(PrivateKeySignerWithTransportTestSuite)) } diff --git a/internal/contract/evm/contract/transport/http.go b/internal/contract/evm/contract/transport/http.go index c1b1b19..90ca84b 100644 --- a/internal/contract/evm/contract/transport/http.go +++ b/internal/contract/evm/contract/transport/http.go @@ -100,19 +100,19 @@ func (h *HTTPTransport) CallContract(contractAddress common.Address, customABI c } // EstimateGas implements Transport. -func (h *HTTPTransport) EstimateGas(tx *types.Transaction) (gas uint64, err error) { +func (h *HTTPTransport) EstimateGas(transaction *types.Transaction) (gas uint64, err error) { ctx := context.Background() // Estimate gas for the transaction msg := ethereum.CallMsg{ - To: tx.To(), - Gas: tx.Gas(), - GasPrice: tx.GasPrice(), - GasFeeCap: tx.GasFeeCap(), - GasTipCap: tx.GasTipCap(), - Value: tx.Value(), - Data: tx.Data(), - AccessList: tx.AccessList(), + To: transaction.To(), + Gas: transaction.Gas(), + GasPrice: transaction.GasPrice(), + GasFeeCap: transaction.GasFeeCap(), + GasTipCap: transaction.GasTipCap(), + Value: transaction.Value(), + Data: transaction.Data(), + AccessList: transaction.AccessList(), } gas, err = h.client.EstimateGas(ctx, msg) @@ -148,15 +148,15 @@ func (h *HTTPTransport) GetTransactionCount(address common.Address) (nonce uint6 } // SendTransaction implements Transport. -func (h *HTTPTransport) SendTransaction(tx *types.Transaction) (txHash common.Hash, err error) { +func (h *HTTPTransport) SendTransaction(transaction *types.Transaction) (txHash common.Hash, err error) { ctx := context.Background() - err = h.client.SendTransaction(ctx, tx) + err = h.client.SendTransaction(ctx, transaction) if err != nil { return common.Hash{}, errors.WrapTransportError(err, errors.ErrCodeTransactionSendFailed, "failed to send transaction") } - return tx.Hash(), nil + return transaction.Hash(), nil } // WaitForTransactionReceipt implements Transport. diff --git a/internal/contract/evm/contract/transport/http_test.go b/internal/contract/evm/contract/transport/http_test.go index 4c6467e..7cd32b7 100644 --- a/internal/contract/evm/contract/transport/http_test.go +++ b/internal/contract/evm/contract/transport/http_test.go @@ -12,13 +12,13 @@ import ( ) const ( - // Anvil default endpoint + // Anvil default endpoint. testEndpoint = "http://localhost:8545" - // Anvil test account with pre-funded balance + // Anvil test account with pre-funded balance. testAddress = "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266" ) -// Simple ERC20-like ABI for testing +// Simple ERC20-like ABI for testing. var testABI = `[ { "type": "function", @@ -51,7 +51,7 @@ var testABI = `[ } ]` -// HTTPTransportTestSuite defines the test suite for HTTP transport +// HTTPTransportTestSuite defines the test suite for HTTP transport. type HTTPTransportTestSuite struct { suite.Suite transport Transport @@ -60,7 +60,7 @@ type HTTPTransportTestSuite struct { contractAddress common.Address } -// SetupSuite runs once before all tests in the suite +// SetupSuite runs once before all tests in the suite. func (suite *HTTPTransportTestSuite) SetupSuite() { // Parse test ABI err := json.Unmarshal([]byte(testABI), &suite.testABIObj) @@ -78,22 +78,22 @@ func (suite *HTTPTransportTestSuite) SetupSuite() { suite.transport = transport } -// TearDownSuite runs once after all tests in the suite +// TearDownSuite runs once after all tests in the suite. func (suite *HTTPTransportTestSuite) TearDownSuite() { // Cleanup if needed } -// SetupTest runs before each test +// SetupTest runs before each test. func (suite *HTTPTransportTestSuite) SetupTest() { // Per-test setup if needed } -// TearDownTest runs after each test +// TearDownTest runs after each test. func (suite *HTTPTransportTestSuite) TearDownTest() { // Per-test cleanup if needed } -// TestNewHTTPTransport tests transport initialization +// TestNewHTTPTransport tests transport initialization. func (suite *HTTPTransportTestSuite) TestNewHTTPTransport() { tests := []struct { name string @@ -119,19 +119,19 @@ func (suite *HTTPTransportTestSuite) TestNewHTTPTransport() { }, } - for _, tt := range tests { - suite.Run(tt.name, func() { - transport, err := NewHTTPTransport(tt.endpoint, 30*time.Second) + for _, testCase := range tests { + suite.Run(testCase.name, func() { + transport, err := NewHTTPTransport(testCase.endpoint, 30*time.Second) - if tt.wantErr { + if testCase.wantErr { suite.Error(err, "expected error but got none") - if tt.errContains != "" && err != nil { - suite.Contains(err.Error(), tt.errContains, "error should contain expected text") + if testCase.errContains != "" && err != nil { + suite.Contains(err.Error(), testCase.errContains, "error should contain expected text") } return } - if tt.endpoint == testEndpoint && err != nil { + if testCase.endpoint == testEndpoint && err != nil { suite.T().Skipf("Anvil network not running: %v", err) } @@ -141,7 +141,7 @@ func (suite *HTTPTransportTestSuite) TestNewHTTPTransport() { } } -// TestGetBalance tests balance retrieval +// TestGetBalance tests balance retrieval. func (suite *HTTPTransportTestSuite) TestGetBalance() { tests := []struct { name string @@ -165,11 +165,11 @@ func (suite *HTTPTransportTestSuite) TestGetBalance() { }, } - for _, tt := range tests { - suite.Run(tt.name, func() { - balance, err := suite.transport.GetBalance(tt.address) + for _, testCase := range tests { + suite.Run(testCase.name, func() { + balance, err := suite.transport.GetBalance(testCase.address) - if tt.wantErr { + if testCase.wantErr { suite.Error(err, "expected error but got none") return } @@ -178,19 +178,19 @@ func (suite *HTTPTransportTestSuite) TestGetBalance() { suite.NotNil(balance, "expected balance but got nil") // Check minimum balance for test account - if tt.checkMinimum { + if testCase.checkMinimum { suite.GreaterOrEqual( - balance.Cmp(tt.minimumAmount), + balance.Cmp(testCase.minimumAmount), 0, - "balance should be >= %v, got %v", tt.minimumAmount, balance, + "balance should be >= %v, got %v", testCase.minimumAmount, balance, ) - suite.T().Logf("Account %s has balance: %v wei", tt.address.Hex(), balance) + suite.T().Logf("Account %s has balance: %v wei", testCase.address.Hex(), balance) } }) } } -// TestGetTransactionCount tests nonce retrieval +// TestGetTransactionCount tests nonce retrieval. func (suite *HTTPTransportTestSuite) TestGetTransactionCount() { tests := []struct { name string @@ -209,22 +209,22 @@ func (suite *HTTPTransportTestSuite) TestGetTransactionCount() { }, } - for _, tt := range tests { - suite.Run(tt.name, func() { - nonce, err := suite.transport.GetTransactionCount(tt.address) + for _, testCase := range tests { + suite.Run(testCase.name, func() { + nonce, err := suite.transport.GetTransactionCount(testCase.address) - if tt.wantErr { + if testCase.wantErr { suite.Error(err, "expected error but got none") return } suite.NoError(err, "GetTransactionCount should not return error") - suite.T().Logf("Account %s has nonce: %d", tt.address.Hex(), nonce) + suite.T().Logf("Account %s has nonce: %d", testCase.address.Hex(), nonce) }) } } -// TestCallContract tests contract call functionality +// TestCallContract tests contract call functionality. func (suite *HTTPTransportTestSuite) TestCallContract() { suite.Run("call non-existent contract", func() { // Calling a non-existent contract should either return empty data or an error @@ -244,12 +244,12 @@ func (suite *HTTPTransportTestSuite) TestCallContract() { }) } -// TestEstimateGas tests gas estimation +// TestEstimateGas tests gas estimation. func (suite *HTTPTransportTestSuite) TestEstimateGas() { suite.T().Skip("Gas estimation requires a properly signed transaction - skipping in basic e2e test") } -// TestSequentialOperations tests multiple operations in sequence +// TestSequentialOperations tests multiple operations in sequence. func (suite *HTTPTransportTestSuite) TestSequentialOperations() { // Get balance balance, err := suite.transport.GetBalance(suite.testAddr) @@ -270,7 +270,7 @@ func (suite *HTTPTransportTestSuite) TestSequentialOperations() { ) } -// TestConcurrentOperations tests multiple operations running concurrently +// TestConcurrentOperations tests multiple operations running concurrently. func (suite *HTTPTransportTestSuite) TestConcurrentOperations() { done := make(chan error, 2) @@ -299,7 +299,7 @@ func (suite *HTTPTransportTestSuite) TestConcurrentOperations() { } } -// TestHTTPTransportTestSuite runs the test suite +// TestHTTPTransportTestSuite runs the test suite. func TestHTTPTransportTestSuite(t *testing.T) { suite.Run(t, new(HTTPTransportTestSuite)) } diff --git a/internal/contract/evm/storage/models/evm/abi.go b/internal/contract/evm/storage/models/evm/abi.go index c35cc19..c176b25 100644 --- a/internal/contract/evm/storage/models/evm/abi.go +++ b/internal/contract/evm/storage/models/evm/abi.go @@ -46,7 +46,7 @@ func (a *AbiArrayType) Scan(value any) error { parsed, err := abi.ParseAbi(string(bytes)) if err != nil { - return err + return fmt.Errorf("failed to parse ABI: %w", err) } a.AbiArray = parsed diff --git a/internal/contract/evm/storage/models/evm/models_test.go b/internal/contract/evm/storage/models/evm/models_test.go index 68b1043..14e0101 100644 --- a/internal/contract/evm/storage/models/evm/models_test.go +++ b/internal/contract/evm/storage/models/evm/models_test.go @@ -10,13 +10,13 @@ import ( "gorm.io/gorm/logger" ) -// ModelsTestSuite is the test suite for all EVM models +// ModelsTestSuite is the test suite for all EVM models. type ModelsTestSuite struct { suite.Suite db *gorm.DB } -// SetupTest is called before each test +// SetupTest is called before each test. func (suite *ModelsTestSuite) SetupTest() { var err error // Use in-memory SQLite database with foreign key support enabled @@ -38,15 +38,19 @@ func (suite *ModelsTestSuite) SetupTest() { suite.Require().NoError(err) } -// TearDownTest is called after each test +// TearDownTest is called after each test. func (suite *ModelsTestSuite) TearDownTest() { sqlDB, err := suite.db.DB() if err == nil { - sqlDB.Close() + closeErr := sqlDB.Close() + if closeErr != nil { + // Log error but don't fail the test + _ = closeErr + } } } -// TestEvmAbi_CRUD tests CRUD operations for EvmAbi +// TestEvmAbi_CRUD tests CRUD operations for EvmAbi. func (suite *ModelsTestSuite) TestEvmAbi_CRUD() { // Create abiJSON := `[{"type":"function","name":"balanceOf","inputs":[{"name":"owner","type":"address"}],"outputs":[{"name":"balance","type":"uint256"}],"stateMutability":"view"}]` @@ -99,7 +103,7 @@ func (suite *ModelsTestSuite) TestEvmAbi_CRUD() { suite.Assert().ErrorIs(result.Error, gorm.ErrRecordNotFound) } -// TestEvmAbi_UniqueConstraint tests the unique constraint on Name +// TestEvmAbi_UniqueConstraint tests the unique constraint on Name. func (suite *ModelsTestSuite) TestEvmAbi_UniqueConstraint() { abiJSON := `[{"type":"function","name":"test","inputs":[],"outputs":[],"stateMutability":"view"}]` parsed, err := abi.ParseAbi(abiJSON) @@ -127,7 +131,7 @@ func (suite *ModelsTestSuite) TestEvmAbi_UniqueConstraint() { suite.Assert().Error(result.Error) } -// TestEVMEndpoint_CRUD tests CRUD operations for EVMEndpoint +// TestEVMEndpoint_CRUD tests CRUD operations for EVMEndpoint. func (suite *ModelsTestSuite) TestEVMEndpoint_CRUD() { // Create endpoint := &EVMEndpoint{ @@ -168,7 +172,7 @@ func (suite *ModelsTestSuite) TestEVMEndpoint_CRUD() { suite.Assert().ErrorIs(result.Error, gorm.ErrRecordNotFound) } -// TestEVMEndpoint_UniqueConstraint tests the unique constraint on Name +// TestEVMEndpoint_UniqueConstraint tests the unique constraint on Name. func (suite *ModelsTestSuite) TestEVMEndpoint_UniqueConstraint() { endpoint1 := &EVMEndpoint{ Name: "TestEndpoint", @@ -190,7 +194,7 @@ func (suite *ModelsTestSuite) TestEVMEndpoint_UniqueConstraint() { suite.Assert().Error(result.Error) } -// TestEVMContract_CRUD tests CRUD operations for EVMContract +// TestEVMContract_CRUD tests CRUD operations for EVMContract. func (suite *ModelsTestSuite) TestEVMContract_CRUD() { // Create endpoint first endpoint := &EVMEndpoint{ @@ -264,7 +268,7 @@ func (suite *ModelsTestSuite) TestEVMContract_CRUD() { suite.Assert().ErrorIs(result.Error, gorm.ErrRecordNotFound) } -// TestEVMContract_CompositeUniqueIndex tests the composite unique constraint +// TestEVMContract_CompositeUniqueIndex tests the composite unique constraint. func (suite *ModelsTestSuite) TestEVMContract_CompositeUniqueIndex() { // Create endpoint endpoint := &EVMEndpoint{ @@ -307,7 +311,7 @@ func (suite *ModelsTestSuite) TestEVMContract_CompositeUniqueIndex() { suite.Assert().NoError(result.Error) } -// TestEVMContract_CascadeDelete tests cascade delete when endpoint is deleted +// TestEVMContract_CascadeDelete tests cascade delete when endpoint is deleted. func (suite *ModelsTestSuite) TestEVMContract_CascadeDelete() { // Create endpoint endpoint := &EVMEndpoint{ @@ -342,7 +346,7 @@ func (suite *ModelsTestSuite) TestEVMContract_CascadeDelete() { suite.Assert().ErrorIs(result.Error, gorm.ErrRecordNotFound) } -// TestEVMContract_SetNullOnAbiDelete tests SET NULL when ABI is deleted +// TestEVMContract_SetNullOnAbiDelete tests SET NULL when ABI is deleted. func (suite *ModelsTestSuite) TestEVMContract_SetNullOnAbiDelete() { // Create endpoint endpoint := &EVMEndpoint{ @@ -390,7 +394,7 @@ func (suite *ModelsTestSuite) TestEVMContract_SetNullOnAbiDelete() { suite.Assert().Nil(updated.AbiId) } -// TestEVMConfig_CRUD tests CRUD operations for EVMConfig +// TestEVMConfig_CRUD tests CRUD operations for EVMConfig. func (suite *ModelsTestSuite) TestEVMConfig_CRUD() { // Create endpoint endpoint := &EVMEndpoint{ @@ -463,7 +467,7 @@ func (suite *ModelsTestSuite) TestEVMConfig_CRUD() { suite.Assert().ErrorIs(result.Error, gorm.ErrRecordNotFound) } -// TestEVMConfig_SetNullOnForeignKeyDelete tests SET NULL behavior +// TestEVMConfig_SetNullOnForeignKeyDelete tests SET NULL behavior. func (suite *ModelsTestSuite) TestEVMConfig_SetNullOnForeignKeyDelete() { // Create all required entities endpoint := &EVMEndpoint{ @@ -516,7 +520,7 @@ func (suite *ModelsTestSuite) TestEVMConfig_SetNullOnForeignKeyDelete() { suite.Assert().NotNil(updated.SelectedEVMContractId) } -// TestAbiArrayType_NullValue tests AbiArrayType with NULL value +// TestAbiArrayType_NullValue tests AbiArrayType with NULL value. func (suite *ModelsTestSuite) TestAbiArrayType_NullValue() { evmAbi := &EvmAbi{ Name: "EmptyABI", @@ -534,7 +538,7 @@ func (suite *ModelsTestSuite) TestAbiArrayType_NullValue() { suite.Assert().Nil(retrieved.Abi.AbiArray) } -// TestConcurrentOperations tests concurrent database operations +// TestConcurrentOperations tests concurrent database operations. func (suite *ModelsTestSuite) TestConcurrentOperations() { endpoint := &EVMEndpoint{ Name: "TestEndpoint", @@ -562,7 +566,7 @@ func (suite *ModelsTestSuite) TestConcurrentOperations() { suite.Assert().Equal(int64(10), count) } -// TestRunSuite runs the test suite +// TestRunSuite runs the test suite. func TestRunSuite(t *testing.T) { suite.Run(t, new(ModelsTestSuite)) } diff --git a/internal/contract/evm/storage/sql/sqlite.go b/internal/contract/evm/storage/sql/sqlite.go index e3ff69e..75253de 100644 --- a/internal/contract/evm/storage/sql/sqlite.go +++ b/internal/contract/evm/storage/sql/sqlite.go @@ -23,27 +23,34 @@ type SQLiteStorage struct { // CountABIs implements Storage. func (s *SQLiteStorage) CountABIs() (count int64, err error) { - return s.abiQueries.Count() + count, err = s.abiQueries.Count() + if err != nil { + return 0, fmt.Errorf("failed to count ABIs: %w", err) + } + return count, nil } // CreateABI implements Storage. func (s *SQLiteStorage) CreateABI(abi models.EvmAbi) (id uint, err error) { if err := s.abiQueries.Create(&abi); err != nil { - return 0, err + return 0, fmt.Errorf("failed to create ABI: %w", err) } return abi.ID, nil } // DeleteABI implements Storage. func (s *SQLiteStorage) DeleteABI(id uint) (err error) { - return s.abiQueries.Delete(id) + if err := s.abiQueries.Delete(id); err != nil { + return fmt.Errorf("failed to delete ABI: %w", err) + } + return nil } // GetABIByID implements Storage. func (s *SQLiteStorage) GetABIByID(id uint) (abi models.EvmAbi, err error) { result, err := s.abiQueries.GetByID(id) if err != nil { - return models.EvmAbi{}, err + return models.EvmAbi{}, fmt.Errorf("failed to get ABI by ID: %w", err) } return *result, nil } @@ -52,7 +59,7 @@ func (s *SQLiteStorage) GetABIByID(id uint) (abi models.EvmAbi, err error) { func (s *SQLiteStorage) ListABIs(page int64, pageSize int64) (abis types.Pagination[models.EvmAbi], err error) { result, err := s.abiQueries.List(page, pageSize) if err != nil { - return types.Pagination[models.EvmAbi]{}, err + return types.Pagination[models.EvmAbi]{}, fmt.Errorf("failed to list ABIs: %w", err) } return *result, nil } @@ -61,7 +68,7 @@ func (s *SQLiteStorage) ListABIs(page int64, pageSize int64) (abis types.Paginat func (s *SQLiteStorage) SearchABIs(query string) (abis types.Pagination[models.EvmAbi], err error) { result, err := s.abiQueries.Search(query) if err != nil { - return types.Pagination[models.EvmAbi]{}, err + return types.Pagination[models.EvmAbi]{}, fmt.Errorf("failed to search ABIs: %w", err) } return *result, nil } @@ -72,34 +79,44 @@ func (s *SQLiteStorage) UpdateABI(id uint, abi models.EvmAbi) (err error) { "name": abi.Name, "abi": abi.Abi, } - return s.abiQueries.Update(id, updates) + if err := s.abiQueries.Update(id, updates); err != nil { + return fmt.Errorf("failed to update ABI: %w", err) + } + return nil } // Endpoint Methods // CountEndpoints implements Storage. func (s *SQLiteStorage) CountEndpoints() (count int64, err error) { - return s.endpointQueries.Count() + count, err = s.endpointQueries.Count() + if err != nil { + return 0, fmt.Errorf("failed to count endpoints: %w", err) + } + return count, nil } // CreateEndpoint implements Storage. func (s *SQLiteStorage) CreateEndpoint(endpoint models.EVMEndpoint) (id uint, err error) { if err := s.endpointQueries.Create(&endpoint); err != nil { - return 0, err + return 0, fmt.Errorf("failed to create endpoint: %w", err) } return endpoint.ID, nil } // DeleteEndpoint implements Storage. func (s *SQLiteStorage) DeleteEndpoint(id uint) (err error) { - return s.endpointQueries.Delete(id) + if err := s.endpointQueries.Delete(id); err != nil { + return fmt.Errorf("failed to delete endpoint: %w", err) + } + return nil } // GetEndpointByID implements Storage. func (s *SQLiteStorage) GetEndpointByID(id uint) (endpoint models.EVMEndpoint, err error) { result, err := s.endpointQueries.GetByID(id) if err != nil { - return models.EVMEndpoint{}, err + return models.EVMEndpoint{}, fmt.Errorf("failed to get endpoint by ID: %w", err) } return *result, nil } @@ -108,7 +125,7 @@ func (s *SQLiteStorage) GetEndpointByID(id uint) (endpoint models.EVMEndpoint, e func (s *SQLiteStorage) ListEndpoints(page int64, pageSize int64) (endpoints types.Pagination[models.EVMEndpoint], err error) { result, err := s.endpointQueries.List(page, pageSize) if err != nil { - return types.Pagination[models.EVMEndpoint]{}, err + return types.Pagination[models.EVMEndpoint]{}, fmt.Errorf("failed to list endpoints: %w", err) } return *result, nil } @@ -117,46 +134,56 @@ func (s *SQLiteStorage) ListEndpoints(page int64, pageSize int64) (endpoints typ func (s *SQLiteStorage) SearchEndpoints(query string) (endpoints types.Pagination[models.EVMEndpoint], err error) { result, err := s.endpointQueries.Search(query) if err != nil { - return types.Pagination[models.EVMEndpoint]{}, err + return types.Pagination[models.EVMEndpoint]{}, fmt.Errorf("failed to search endpoints: %w", err) } return *result, nil } // UpdateEndpoint implements Storage. -func (s *SQLiteStorage) UpdateEndpoint(id uint, endpoint models.EVMEndpoint) (err error) { +func (s *SQLiteStorage) UpdateEndpoint(endpointID uint, endpoint models.EVMEndpoint) (err error) { updates := map[string]any{ "name": endpoint.Name, "url": endpoint.Url, "chain_id": endpoint.ChainId, } - return s.endpointQueries.Update(id, updates) + if err := s.endpointQueries.Update(endpointID, updates); err != nil { + return fmt.Errorf("failed to update endpoint: %w", err) + } + return nil } // Contract Methods // CountContracts implements Storage. func (s *SQLiteStorage) CountContracts() (count int64, err error) { - return s.contractQueries.Count() + count, err = s.contractQueries.Count() + if err != nil { + return 0, fmt.Errorf("failed to count contracts: %w", err) + } + return count, nil } // CreateContract implements Storage. func (s *SQLiteStorage) CreateContract(contract models.EVMContract) (id uint, err error) { if err := s.contractQueries.Create(&contract); err != nil { - return 0, err + return 0, fmt.Errorf("failed to create contract: %w", err) } return contract.ID, nil } // DeleteContract implements Storage. func (s *SQLiteStorage) DeleteContract(id uint) (err error) { - return s.contractQueries.Delete(id) + if err := s.contractQueries.Delete(id); err != nil { + return fmt.Errorf("failed to delete contract: %w", err) + } + return nil } // GetContractByID implements Storage. func (s *SQLiteStorage) GetContractByID(id uint) (contract models.EVMContract, err error) { result, err := s.contractQueries.GetByID(id) if err != nil { - return models.EVMContract{}, err + return models.EVMContract{}, fmt.Errorf("failed to get contract by ID: %w", err) } return *result, nil } @@ -165,7 +192,7 @@ func (s *SQLiteStorage) GetContractByID(id uint) (contract models.EVMContract, e func (s *SQLiteStorage) ListContracts(page int64, pageSize int64) (contracts types.Pagination[models.EVMContract], err error) { result, err := s.contractQueries.List(page, pageSize) if err != nil { - return types.Pagination[models.EVMContract]{}, err + return types.Pagination[models.EVMContract]{}, fmt.Errorf("failed to list contracts: %w", err) } return *result, nil } @@ -174,13 +201,13 @@ func (s *SQLiteStorage) ListContracts(page int64, pageSize int64) (contracts typ func (s *SQLiteStorage) SearchContracts(query string) (contracts types.Pagination[models.EVMContract], err error) { result, err := s.contractQueries.Search(query) if err != nil { - return types.Pagination[models.EVMContract]{}, err + return types.Pagination[models.EVMContract]{}, fmt.Errorf("failed to search contracts: %w", err) } return *result, nil } // UpdateContract implements Storage. -func (s *SQLiteStorage) UpdateContract(id uint, contract models.EVMContract) (err error) { +func (s *SQLiteStorage) UpdateContract(contractID uint, contract models.EVMContract) (err error) { updates := map[string]any{ "name": contract.Name, "address": contract.Address, @@ -190,34 +217,44 @@ func (s *SQLiteStorage) UpdateContract(id uint, contract models.EVMContract) (er "bytecode": contract.Bytecode, "endpoint_id": contract.EndpointId, } - return s.contractQueries.Update(id, updates) + if err := s.contractQueries.Update(contractID, updates); err != nil { + return fmt.Errorf("failed to update contract: %w", err) + } + return nil } // Config Methods // CountConfigs implements Storage. func (s *SQLiteStorage) CountConfigs() (count int64, err error) { - return s.configQueries.Count() + count, err = s.configQueries.Count() + if err != nil { + return 0, fmt.Errorf("failed to count configs: %w", err) + } + return count, nil } // CreateConfig implements Storage. func (s *SQLiteStorage) CreateConfig(config models.EVMConfig) (id uint, err error) { if err := s.configQueries.Create(&config); err != nil { - return 0, err + return 0, fmt.Errorf("failed to create config: %w", err) } return config.ID, nil } // DeleteConfig implements Storage. func (s *SQLiteStorage) DeleteConfig(id uint) (err error) { - return s.configQueries.Delete(id) + if err := s.configQueries.Delete(id); err != nil { + return fmt.Errorf("failed to delete config: %w", err) + } + return nil } // GetConfigByID implements Storage. func (s *SQLiteStorage) GetConfigByID(id uint) (config models.EVMConfig, err error) { result, err := s.configQueries.GetByID(id) if err != nil { - return models.EVMConfig{}, err + return models.EVMConfig{}, fmt.Errorf("failed to get config by ID: %w", err) } return *result, nil } @@ -226,7 +263,7 @@ func (s *SQLiteStorage) GetConfigByID(id uint) (config models.EVMConfig, err err func (s *SQLiteStorage) ListConfigs(page int64, pageSize int64) (configs types.Pagination[models.EVMConfig], err error) { result, err := s.configQueries.List(page, pageSize) if err != nil { - return types.Pagination[models.EVMConfig]{}, err + return types.Pagination[models.EVMConfig]{}, fmt.Errorf("failed to list configs: %w", err) } return *result, nil } @@ -235,19 +272,22 @@ func (s *SQLiteStorage) ListConfigs(page int64, pageSize int64) (configs types.P func (s *SQLiteStorage) SearchConfigs(query string) (configs types.Pagination[models.EVMConfig], err error) { result, err := s.configQueries.Search(query) if err != nil { - return types.Pagination[models.EVMConfig]{}, err + return types.Pagination[models.EVMConfig]{}, fmt.Errorf("failed to search configs: %w", err) } return *result, nil } // UpdateConfig implements Storage. -func (s *SQLiteStorage) UpdateConfig(id uint, config models.EVMConfig) (err error) { +func (s *SQLiteStorage) UpdateConfig(configID uint, config models.EVMConfig) (err error) { updates := map[string]any{ "endpoint_id": config.EndpointId, "selected_evm_contract_id": config.SelectedEVMContractId, "selected_evm_abi_id": config.SelectedEVMAbiId, } - return s.configQueries.Update(id, updates) + if err := s.configQueries.Update(configID, updates); err != nil { + return fmt.Errorf("failed to update config: %w", err) + } + return nil } // NewSQLiteDB creates a new SQLite database connection. @@ -264,18 +304,18 @@ func NewSQLiteDB(dbPath string) (Storage, error) { // Ensure the directory exists dir := filepath.Dir(dbPath) - if err := os.MkdirAll(dir, 0755); err != nil { + if err := os.MkdirAll(dir, 0750); err != nil { return nil, fmt.Errorf("failed to create database directory: %w", err) } // Open database connection with GORM - db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{}) + database, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{}) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } // Auto-migrate the schema - if err := db.AutoMigrate( + if err := database.AutoMigrate( &models.EvmAbi{}, &models.EVMEndpoint{}, &models.EVMContract{}, @@ -286,9 +326,9 @@ func NewSQLiteDB(dbPath string) (Storage, error) { // Initialize query helpers return &SQLiteStorage{ - abiQueries: queries.NewABIQueries(db), - endpointQueries: queries.NewEndpointQueries(db), - contractQueries: queries.NewContractQueries(db), - configQueries: queries.NewConfigQueries(db), + abiQueries: queries.NewABIQueries(database), + endpointQueries: queries.NewEndpointQueries(database), + contractQueries: queries.NewContractQueries(database), + configQueries: queries.NewConfigQueries(database), }, nil } diff --git a/internal/contract/evm/storage/sql/storage.go b/internal/contract/evm/storage/sql/storage.go index 10da5d6..8878fdf 100644 --- a/internal/contract/evm/storage/sql/storage.go +++ b/internal/contract/evm/storage/sql/storage.go @@ -1,6 +1,9 @@ package sql import ( + "fmt" + + "github.com/rxtech-lab/smart-contract-cli/internal/config" models "github.com/rxtech-lab/smart-contract-cli/internal/contract/evm/storage/models/evm" "github.com/rxtech-lab/smart-contract-cli/internal/contract/types" ) @@ -42,3 +45,19 @@ type Storage interface { UpdateConfig(id uint, config models.EVMConfig) (err error) DeleteConfig(id uint) (err error) } + +func GetStorage(storageType string, params ...any) (Storage, error) { + switch storageType { + case config.StorageClientTypeSQLite: + if len(params) == 0 || params[0] == nil { + return nil, fmt.Errorf("sqlite path is required") + } + sqlitePath, ok := params[0].(string) + if !ok { + return nil, fmt.Errorf("sqlite path must be a string") + } + return NewSQLiteDB(sqlitePath) + default: + return nil, fmt.Errorf("invalid storage type: %s", storageType) + } +} diff --git a/internal/storage/secure.go b/internal/storage/secure.go index 9d93b29..1dede99 100644 --- a/internal/storage/secure.go +++ b/internal/storage/secure.go @@ -336,14 +336,14 @@ func (s *SecureStorageWithEncryption) load() error { return err //nolint:wrapcheck // Need unwrapped error for os.IsNotExist() check } - var ed encryptedData - if err := json.Unmarshal(data, &ed); err != nil { + var encData encryptedData + if err := json.Unmarshal(data, &encData); err != nil { return fmt.Errorf("failed to unmarshal data: %w", err) } s.mu.Lock() - s.passwordHash = ed.PasswordHash - s.data = ed.Data + s.passwordHash = encData.PasswordHash + s.data = encData.Data s.mu.Unlock() return nil diff --git a/internal/storage/secure_test.go b/internal/storage/secure_test.go index 55fb700..c52b93e 100644 --- a/internal/storage/secure_test.go +++ b/internal/storage/secure_test.go @@ -34,12 +34,14 @@ func (s *SecureStorageTestSuite) SetupTest() { func (s *SecureStorageTestSuite) TearDownTest() { if s.storage != nil { - s.storage.Close() + err := s.storage.Close() + s.NoError(err, "Should close storage") } // Clean up temporary directory if s.tempDir != "" { - os.RemoveAll(s.tempDir) + err := os.RemoveAll(s.tempDir) + s.NoError(err, "Should clean up temp directory") } } @@ -47,7 +49,7 @@ func TestSecureStorageTestSuite(t *testing.T) { suite.Run(t, new(SecureStorageTestSuite)) } -// Test basic Set and Get operations +// Test basic Set and Get operations. func (s *SecureStorageTestSuite) TestSetAndGet() { err := s.storage.Set("key1", "value1") s.Require().NoError(err) @@ -57,7 +59,7 @@ func (s *SecureStorageTestSuite) TestSetAndGet() { s.Equal("value1", value) } -// Test setting multiple values +// Test setting multiple values. func (s *SecureStorageTestSuite) TestSetMultipleValues() { testData := map[string]string{ "username": "john_doe", @@ -78,7 +80,7 @@ func (s *SecureStorageTestSuite) TestSetMultipleValues() { } } -// Test getting non-existent key +// Test getting non-existent key. func (s *SecureStorageTestSuite) TestGetNonExistentKey() { value, err := s.storage.Get("non-existent-key") s.Error(err) @@ -86,7 +88,7 @@ func (s *SecureStorageTestSuite) TestGetNonExistentKey() { s.Contains(err.Error(), "key not found") } -// Test Delete operation +// Test Delete operation. func (s *SecureStorageTestSuite) TestDelete() { // Set a value err := s.storage.Set("key-to-delete", "value-to-delete") @@ -107,7 +109,7 @@ func (s *SecureStorageTestSuite) TestDelete() { s.Contains(err.Error(), "key not found") } -// Test List operation +// Test List operation. func (s *SecureStorageTestSuite) TestList() { // Set multiple values testKeys := []string{"key1", "key2", "key3"} @@ -127,7 +129,7 @@ func (s *SecureStorageTestSuite) TestList() { } } -// Test Clear operation +// Test Clear operation. func (s *SecureStorageTestSuite) TestClear() { // Set multiple values for i := 0; i < 5; i++ { @@ -151,7 +153,7 @@ func (s *SecureStorageTestSuite) TestClear() { s.Len(keys, 0) } -// Test persistence to file +// Test persistence to file. func (s *SecureStorageTestSuite) TestPersistence() { // Set some data testData := map[string]string{ @@ -175,7 +177,12 @@ func (s *SecureStorageTestSuite) TestPersistence() { // Create new storage instance with same file (it will auto-load existing data) newStorage, err := NewSecureStorageWithEncryption("test-encryption-key", s.tempFile) s.Require().NoError(err) - defer newStorage.Close() + defer func() { + if closeErr := newStorage.Close(); closeErr != nil { + // Log error but don't fail the test + _ = closeErr + } + }() // Verify data was loaded for key, expectedValue := range testData { @@ -185,7 +192,7 @@ func (s *SecureStorageTestSuite) TestPersistence() { } } -// Test encryption - verify data is actually encrypted on disk +// Test encryption - verify data is actually encrypted on disk. func (s *SecureStorageTestSuite) TestEncryptionOnDisk() { sensitiveData := "super_secret_password_12345" err := s.storage.Set("password", sensitiveData) @@ -206,17 +213,23 @@ func (s *SecureStorageTestSuite) TestEncryptionOnDisk() { s.Contains(string(fileContent), "data", "File should contain JSON structure") } -// Test wrong encryption key +// Test wrong encryption key. func (s *SecureStorageTestSuite) TestWrongEncryptionKey() { // Set data with first key err := s.storage.Set("secret", "value123") s.Require().NoError(err) - s.storage.Close() + err = s.storage.Close() + s.Require().NoError(err) // Try to load with different encryption key (but it will still load the file) wrongStorage, err := NewSecureStorageWithEncryption("wrong-key", s.tempFile) s.Require().NoError(err) - defer wrongStorage.Close() + defer func() { + if closeErr := wrongStorage.Close(); closeErr != nil { + // Log error but don't fail the test + _ = closeErr + } + }() // Should fail to decrypt because encryption key is different _, err = wrongStorage.Get("secret") @@ -224,14 +237,19 @@ func (s *SecureStorageTestSuite) TestWrongEncryptionKey() { s.Contains(err.Error(), "failed to decrypt") } -// Test in-memory storage (uses default path) +// Test in-memory storage (uses default path). func (s *SecureStorageTestSuite) TestDefaultPath() { // When no path is provided, it should use the default path // For testing, we'll provide an explicit path instead testPath := filepath.Join(s.tempDir, "default-path-test.json") storage, err := NewSecureStorageWithEncryption("memory-key", testPath) s.Require().NoError(err) - defer storage.Close() + defer func() { + if closeErr := storage.Close(); closeErr != nil { + // Log error but don't fail the test + _ = closeErr + } + }() // Create the storage err = storage.Create("test-password") @@ -250,29 +268,29 @@ func (s *SecureStorageTestSuite) TestDefaultPath() { s.NoError(err) } -// Test concurrent operations +// Test concurrent operations. func (s *SecureStorageTestSuite) TestConcurrentOperations() { numGoroutines := 10 done := make(chan bool, numGoroutines) // Concurrent writes - for i := 0; i < numGoroutines; i++ { + for index := 0; index < numGoroutines; index++ { go func(index int) { key := "concurrent-key-" + string(rune('0'+index)) value := "concurrent-value-" + string(rune('0'+index)) err := s.storage.Set(key, value) s.NoError(err) done <- true - }(i) + }(index) } // Wait for all writes to complete - for i := 0; i < numGoroutines; i++ { + for index := 0; index < numGoroutines; index++ { <-done } // Concurrent reads - for i := 0; i < numGoroutines; i++ { + for index := 0; index < numGoroutines; index++ { go func(index int) { key := "concurrent-key-" + string(rune('0'+index)) expectedValue := "concurrent-value-" + string(rune('0'+index)) @@ -280,16 +298,16 @@ func (s *SecureStorageTestSuite) TestConcurrentOperations() { s.NoError(err) s.Equal(expectedValue, value) done <- true - }(i) + }(index) } // Wait for all reads to complete - for i := 0; i < numGoroutines; i++ { + for index := 0; index < numGoroutines; index++ { <-done } } -// Test updating existing value +// Test updating existing value. func (s *SecureStorageTestSuite) TestUpdateValue() { key := "update-key" @@ -312,7 +330,7 @@ func (s *SecureStorageTestSuite) TestUpdateValue() { s.Equal("updated-value", value) } -// Test special characters and unicode +// Test special characters and unicode. func (s *SecureStorageTestSuite) TestSpecialCharacters() { testCases := map[string]string{ "unicode": "Hello 世界 🌍", @@ -332,7 +350,7 @@ func (s *SecureStorageTestSuite) TestSpecialCharacters() { } } -// Test empty values +// Test empty values. func (s *SecureStorageTestSuite) TestEmptyValues() { err := s.storage.Set("empty-key", "") s.Require().NoError(err) @@ -342,7 +360,7 @@ func (s *SecureStorageTestSuite) TestEmptyValues() { s.Equal("", value) } -// Test large values +// Test large values. func (s *SecureStorageTestSuite) TestLargeValues() { // Create a large value (1 MB) largeValue := make([]byte, 1024*1024) @@ -358,7 +376,7 @@ func (s *SecureStorageTestSuite) TestLargeValues() { s.Equal(string(largeValue), retrieved) } -// Test file permissions +// Test file permissions. func (s *SecureStorageTestSuite) TestFilePermissions() { err := s.storage.Set("key", "value") s.Require().NoError(err) @@ -376,13 +394,18 @@ func (s *SecureStorageTestSuite) TestFilePermissions() { s.Equal(expectedPerms, fileInfo.Mode().Perm()) } -// Test directory creation +// Test directory creation. func (s *SecureStorageTestSuite) TestDirectoryCreation() { nestedPath := filepath.Join(s.tempDir, "nested", "path", "storage.json") storage, err := NewSecureStorageWithEncryption("test-key", nestedPath) s.Require().NoError(err) - defer storage.Close() + defer func() { + if closeErr := storage.Close(); closeErr != nil { + // Log error but don't fail the test + _ = closeErr + } + }() err = storage.Create("test-password") s.Require().NoError(err) @@ -397,9 +420,13 @@ func (s *SecureStorageTestSuite) TestDirectoryCreation() { // Verify file was created _, err = os.Stat(nestedPath) s.NoError(err, "Storage file should be created in nested directory") + + // Close storage + err = storage.Close() + s.NoError(err, "Should close storage") } -// Test Exists method +// Test Exists method. func (s *SecureStorageTestSuite) TestExists() { // Test file should exist after Create() in SetupTest s.True(s.storage.Exists(), "Storage should exist after creation") @@ -408,17 +435,27 @@ func (s *SecureStorageTestSuite) TestExists() { newPath := filepath.Join(s.tempDir, "non-existent.json") newStorage, err := NewSecureStorageWithEncryption("test-key", newPath) s.Require().NoError(err) - defer newStorage.Close() + defer func() { + if closeErr := newStorage.Close(); closeErr != nil { + // Log error but don't fail the test + _ = closeErr + } + }() s.False(newStorage.Exists(), "Storage should not exist before creation") } -// Test Create method +// Test Create method. func (s *SecureStorageTestSuite) TestCreate() { newPath := filepath.Join(s.tempDir, "new-storage.json") storage, err := NewSecureStorageWithEncryption("test-key", newPath) s.Require().NoError(err) - defer storage.Close() + defer func() { + if closeErr := storage.Close(); closeErr != nil { + // Log error but don't fail the test + _ = closeErr + } + }() // Create should succeed err = storage.Create("my-password") @@ -436,12 +473,17 @@ func (s *SecureStorageTestSuite) TestCreate() { s.Equal("value", value) } -// Test Create with empty password +// Test Create with empty password. func (s *SecureStorageTestSuite) TestCreateEmptyPassword() { newPath := filepath.Join(s.tempDir, "empty-password-storage.json") storage, err := NewSecureStorageWithEncryption("test-key", newPath) s.Require().NoError(err) - defer storage.Close() + defer func() { + if closeErr := storage.Close(); closeErr != nil { + // Log error but don't fail the test + _ = closeErr + } + }() // Create should fail with empty password err = storage.Create("") @@ -449,7 +491,7 @@ func (s *SecureStorageTestSuite) TestCreateEmptyPassword() { s.Contains(err.Error(), "password cannot be empty") } -// Test Create when storage already exists +// Test Create when storage already exists. func (s *SecureStorageTestSuite) TestCreateAlreadyExists() { // Storage was already created in SetupTest err := s.storage.Create("another-password") @@ -457,7 +499,7 @@ func (s *SecureStorageTestSuite) TestCreateAlreadyExists() { s.Contains(err.Error(), "storage already exists") } -// Test Unlock with correct password +// Test Unlock with correct password. func (s *SecureStorageTestSuite) TestUnlockSuccess() { // Close and reload storage err := s.storage.Close() @@ -466,14 +508,19 @@ func (s *SecureStorageTestSuite) TestUnlockSuccess() { // Reload storage newStorage, err := NewSecureStorageWithEncryption("test-encryption-key", s.tempFile) s.Require().NoError(err) - defer newStorage.Close() + defer func() { + if closeErr := newStorage.Close(); closeErr != nil { + // Log error but don't fail the test + _ = closeErr + } + }() // Unlock with correct password should succeed err = newStorage.Unlock("test-password") s.NoError(err) } -// Test Unlock with wrong password +// Test Unlock with wrong password. func (s *SecureStorageTestSuite) TestUnlockWrongPassword() { // Close and reload storage err := s.storage.Close() @@ -482,7 +529,12 @@ func (s *SecureStorageTestSuite) TestUnlockWrongPassword() { // Reload storage newStorage, err := NewSecureStorageWithEncryption("test-encryption-key", s.tempFile) s.Require().NoError(err) - defer newStorage.Close() + defer func() { + if closeErr := newStorage.Close(); closeErr != nil { + // Log error but don't fail the test + _ = closeErr + } + }() // Unlock with wrong password should fail err = newStorage.Unlock("wrong-password") @@ -490,7 +542,7 @@ func (s *SecureStorageTestSuite) TestUnlockWrongPassword() { s.Contains(err.Error(), "incorrect password") } -// Test full Create-Unlock workflow +// Test full Create-Unlock workflow. func (s *SecureStorageTestSuite) TestCreateUnlockWorkflow() { newPath := filepath.Join(s.tempDir, "workflow-storage.json") password := "secure-password-123" @@ -521,7 +573,12 @@ func (s *SecureStorageTestSuite) TestCreateUnlockWorkflow() { // Step 4: Load storage again storage2, err := NewSecureStorageWithEncryption("encryption-key", newPath) s.Require().NoError(err) - defer storage2.Close() + defer func() { + if closeErr := storage2.Close(); closeErr != nil { + // Log error but don't fail the test + _ = closeErr + } + }() // Step 5: Verify password with Unlock err = storage2.Unlock(password) @@ -539,7 +596,7 @@ func (s *SecureStorageTestSuite) TestCreateUnlockWorkflow() { } } -// Test that storage operations work without calling Unlock +// Test that storage operations work without calling Unlock. func (s *SecureStorageTestSuite) TestOperationsWithoutUnlock() { // Storage was created but Unlock was never called // Operations should still work diff --git a/internal/storage/shared_memory_test.go b/internal/storage/shared_memory_test.go index d59c334..fd6ecc2 100644 --- a/internal/storage/shared_memory_test.go +++ b/internal/storage/shared_memory_test.go @@ -22,7 +22,7 @@ func (s *SharedMemoryTestSuite) TearDownTest() { s.Require().NoError(err) } -// TestNewSharedMemory tests the constructor +// TestNewSharedMemory tests the constructor. func (s *SharedMemoryTestSuite) TestNewSharedMemory() { storage := NewSharedMemory() s.NotNil(storage) @@ -32,7 +32,7 @@ func (s *SharedMemoryTestSuite) TestNewSharedMemory() { s.Empty(keys) } -// TestSetAndGet tests basic set and get operations +// TestSetAndGet tests basic set and get operations. func (s *SharedMemoryTestSuite) TestSetAndGet() { err := s.storage.Set("key1", "value1") s.NoError(err) @@ -42,14 +42,14 @@ func (s *SharedMemoryTestSuite) TestSetAndGet() { s.Equal("value1", value) } -// TestGetNonExistentKey tests getting a key that doesn't exist +// TestGetNonExistentKey tests getting a key that doesn't exist. func (s *SharedMemoryTestSuite) TestGetNonExistentKey() { value, err := s.storage.Get("nonexistent") s.NoError(err) s.Nil(value) } -// TestSetMultipleKeys tests setting multiple keys +// TestSetMultipleKeys tests setting multiple keys. func (s *SharedMemoryTestSuite) TestSetMultipleKeys() { err := s.storage.Set("key1", "value1") s.NoError(err) @@ -73,7 +73,7 @@ func (s *SharedMemoryTestSuite) TestSetMultipleKeys() { s.Equal(true, value3) } -// TestSetOverwriteExistingKey tests overwriting an existing key +// TestSetOverwriteExistingKey tests overwriting an existing key. func (s *SharedMemoryTestSuite) TestSetOverwriteExistingKey() { err := s.storage.Set("key1", "original") s.NoError(err) @@ -86,7 +86,7 @@ func (s *SharedMemoryTestSuite) TestSetOverwriteExistingKey() { s.Equal("updated", value) } -// TestDelete tests deleting a key +// TestDelete tests deleting a key. func (s *SharedMemoryTestSuite) TestDelete() { err := s.storage.Set("key1", "value1") s.NoError(err) @@ -99,13 +99,13 @@ func (s *SharedMemoryTestSuite) TestDelete() { s.Nil(value) } -// TestDeleteNonExistentKey tests deleting a key that doesn't exist +// TestDeleteNonExistentKey tests deleting a key that doesn't exist. func (s *SharedMemoryTestSuite) TestDeleteNonExistentKey() { err := s.storage.Delete("nonexistent") s.NoError(err) } -// TestList tests listing all keys +// TestList tests listing all keys. func (s *SharedMemoryTestSuite) TestList() { err := s.storage.Set("key1", "value1") s.NoError(err) @@ -124,14 +124,14 @@ func (s *SharedMemoryTestSuite) TestList() { s.Contains(keys, "key3") } -// TestListEmpty tests listing when storage is empty +// TestListEmpty tests listing when storage is empty. func (s *SharedMemoryTestSuite) TestListEmpty() { keys, err := s.storage.List() s.NoError(err) s.Empty(keys) } -// TestClear tests clearing all data +// TestClear tests clearing all data. func (s *SharedMemoryTestSuite) TestClear() { err := s.storage.Set("key1", "value1") s.NoError(err) @@ -151,7 +151,7 @@ func (s *SharedMemoryTestSuite) TestClear() { s.Nil(value) } -// TestClearEmpty tests clearing when already empty +// TestClearEmpty tests clearing when already empty. func (s *SharedMemoryTestSuite) TestClearEmpty() { err := s.storage.Clear() s.NoError(err) @@ -161,7 +161,7 @@ func (s *SharedMemoryTestSuite) TestClearEmpty() { s.Empty(keys) } -// TestDifferentValueTypes tests storing different types of values +// TestDifferentValueTypes tests storing different types of values. func (s *SharedMemoryTestSuite) TestDifferentValueTypes() { // String err := s.storage.Set("string", "hello") @@ -218,28 +218,28 @@ func (s *SharedMemoryTestSuite) TestDifferentValueTypes() { s.Equal(TestStruct{Name: "John", Age: 30}, structVal) } -// TestConcurrentAccess tests concurrent read/write operations +// TestConcurrentAccess tests concurrent read/write operations. func (s *SharedMemoryTestSuite) TestConcurrentAccess() { const numGoroutines = 100 const numOperations = 10 - var wg sync.WaitGroup + var waitGroup sync.WaitGroup // Concurrent writes - for i := 0; i < numGoroutines; i++ { - wg.Add(1) + for index := 0; index < numGoroutines; index++ { + waitGroup.Add(1) go func(id int) { - defer wg.Done() + defer waitGroup.Done() for j := 0; j < numOperations; j++ { key := "key_" + string(rune(id)) value := id*numOperations + j err := s.storage.Set(key, value) s.NoError(err) } - }(i) + }(index) } - wg.Wait() + waitGroup.Wait() // Verify some data was written keys, err := s.storage.List() @@ -247,7 +247,7 @@ func (s *SharedMemoryTestSuite) TestConcurrentAccess() { s.NotEmpty(keys) } -// TestConcurrentReadWrite tests concurrent reads and writes +// TestConcurrentReadWrite tests concurrent reads and writes. func (s *SharedMemoryTestSuite) TestConcurrentReadWrite() { // Pre-populate with some data for i := 0; i < 10; i++ { @@ -256,31 +256,31 @@ func (s *SharedMemoryTestSuite) TestConcurrentReadWrite() { } const numGoroutines = 50 - var wg sync.WaitGroup + var waitGroup sync.WaitGroup // Concurrent reads - for i := 0; i < numGoroutines; i++ { - wg.Add(1) + for index := 0; index < numGoroutines; index++ { + waitGroup.Add(1) go func(id int) { - defer wg.Done() + defer waitGroup.Done() key := "key_" + string(rune(id%10)) _, err := s.storage.Get(key) s.NoError(err) - }(i) + }(index) } // Concurrent writes - for i := 0; i < numGoroutines; i++ { - wg.Add(1) + for index := 0; index < numGoroutines; index++ { + waitGroup.Add(1) go func(id int) { - defer wg.Done() + defer waitGroup.Done() key := "new_key_" + string(rune(id)) err := s.storage.Set(key, id) s.NoError(err) - }(i) + }(index) } - wg.Wait() + waitGroup.Wait() // Verify storage is still functional keys, err := s.storage.List() @@ -288,7 +288,7 @@ func (s *SharedMemoryTestSuite) TestConcurrentReadWrite() { s.NotEmpty(keys) } -// TestConcurrentDelete tests concurrent delete operations +// TestConcurrentDelete tests concurrent delete operations. func (s *SharedMemoryTestSuite) TestConcurrentDelete() { // Pre-populate with data for i := 0; i < 100; i++ { @@ -296,20 +296,20 @@ func (s *SharedMemoryTestSuite) TestConcurrentDelete() { s.NoError(err) } - var wg sync.WaitGroup + var waitGroup sync.WaitGroup // Concurrent deletes - for i := 0; i < 100; i++ { - wg.Add(1) + for index := 0; index < 100; index++ { + waitGroup.Add(1) go func(id int) { - defer wg.Done() + defer waitGroup.Done() key := "key_" + string(rune(id)) err := s.storage.Delete(key) s.NoError(err) - }(i) + }(index) } - wg.Wait() + waitGroup.Wait() // Verify all keys are deleted keys, err := s.storage.List() @@ -317,7 +317,7 @@ func (s *SharedMemoryTestSuite) TestConcurrentDelete() { s.Empty(keys) } -// TestConcurrentList tests concurrent list operations +// TestConcurrentList tests concurrent list operations. func (s *SharedMemoryTestSuite) TestConcurrentList() { // Pre-populate with data for i := 0; i < 10; i++ { @@ -325,23 +325,23 @@ func (s *SharedMemoryTestSuite) TestConcurrentList() { s.NoError(err) } - var wg sync.WaitGroup + var waitGroup sync.WaitGroup // Concurrent list operations for i := 0; i < 50; i++ { - wg.Add(1) + waitGroup.Add(1) go func() { - defer wg.Done() + defer waitGroup.Done() keys, err := s.storage.List() s.NoError(err) s.NotEmpty(keys) }() } - wg.Wait() + waitGroup.Wait() } -// TestNilValue tests storing nil values +// TestNilValue tests storing nil values. func (s *SharedMemoryTestSuite) TestNilValue() { err := s.storage.Set("nil_key", nil) s.NoError(err) @@ -356,7 +356,7 @@ func (s *SharedMemoryTestSuite) TestNilValue() { s.Contains(keys, "nil_key") } -// TestEmptyStringKey tests using empty string as key +// TestEmptyStringKey tests using empty string as key. func (s *SharedMemoryTestSuite) TestEmptyStringKey() { err := s.storage.Set("", "value") s.NoError(err) diff --git a/internal/ui/component/component_test.go b/internal/ui/component/component_test.go index e48124c..72ee1aa 100644 --- a/internal/ui/component/component_test.go +++ b/internal/ui/component/component_test.go @@ -16,7 +16,7 @@ func TestComponentTestSuite(t *testing.T) { suite.Run(t, new(ComponentTestSuite)) } -// Test ComponentFunc +// Test ComponentFunc. func (s *ComponentTestSuite) TestComponentFunc() { comp := ComponentFunc(func() string { return "Hello, World!" @@ -25,19 +25,19 @@ func (s *ComponentTestSuite) TestComponentFunc() { s.Equal("Hello, World!", comp.Render()) } -// Test Empty +// Test Empty. func (s *ComponentTestSuite) TestEmpty() { comp := Empty() s.Equal("", comp.Render()) } -// Test Raw +// Test Raw. func (s *ComponentTestSuite) TestRaw() { comp := Raw("Raw content") s.Equal("Raw content", comp.Render()) } -// Test Join +// Test Join. func (s *ComponentTestSuite) TestJoin() { c1 := T("Hello") c2 := T("World") @@ -58,7 +58,7 @@ func (s *ComponentTestSuite) TestJoinSingle() { s.Equal("Hello", joined.Render()) } -// Test Text component +// Test Text component. func (s *ComponentTestSuite) TestText() { text := T("Hello") s.Contains(text.Render(), "Hello") @@ -79,7 +79,7 @@ func (s *ComponentTestSuite) TestTextPresets() { s.NotEmpty(T("Muted").Muted().Render()) } -// Test VStack +// Test VStack. func (s *ComponentTestSuite) TestVStack() { stack := NewVStack( T("Line 1"), @@ -110,7 +110,7 @@ func (s *ComponentTestSuite) TestVStackSpacing() { s.GreaterOrEqual(len(lines), 2) } -// Test HStack +// Test HStack. func (s *ComponentTestSuite) TestHStack() { stack := NewHStack( T("A"), @@ -139,7 +139,7 @@ func (s *ComponentTestSuite) TestHStackSpacing() { s.NotEmpty(rendered) } -// Test Spacer +// Test Spacer. func (s *ComponentTestSuite) TestSpacerVertical() { spacer := SpacerV(2) rendered := spacer.Render() @@ -157,7 +157,7 @@ func (s *ComponentTestSuite) TestSpacerZero() { s.Equal("", spacer.Render()) } -// Test Divider +// Test Divider. func (s *ComponentTestSuite) TestDivider() { div := NewDivider("─", 10) rendered := div.Render() @@ -170,7 +170,7 @@ func (s *ComponentTestSuite) TestDividerLine() { s.Contains(rendered, "─") } -// Test List +// Test List. func (s *ComponentTestSuite) TestList() { items := []ListItem{ Item("Item 1", "1", "Description 1"), @@ -356,7 +356,7 @@ func (s *ComponentTestSuite) TestBulletListCustomBullet() { s.Contains(rendered, "*") } -// Test Box +// Test Box. func (s *ComponentTestSuite) TestBox() { box := NewBox(T("Content")) rendered := box.Render() @@ -381,7 +381,7 @@ func (s *ComponentTestSuite) TestPanel() { s.Contains(rendered, "Panel content") } -// Test Padding +// Test Padding. func (s *ComponentTestSuite) TestPadding() { padded := NewPadding(T("Content")).All(1) rendered := padded.Render() @@ -400,14 +400,14 @@ func (s *ComponentTestSuite) TestPaddingHorizontal() { s.NotEmpty(rendered) } -// Test Center +// Test Center. func (s *ComponentTestSuite) TestCenter() { centered := NewCenter(T("Centered"), 20, 5) rendered := centered.Render() s.Contains(rendered, "Centered") } -// Test Conditional - If +// Test Conditional - If. func (s *ComponentTestSuite) TestIfTrue() { comp := IfC(true, T("True"), T("False")) s.Equal("True", comp.Render()) @@ -418,7 +418,7 @@ func (s *ComponentTestSuite) TestIfFalse() { s.Equal("False", comp.Render()) } -// Test Conditional - IfThen +// Test Conditional - IfThen. func (s *ComponentTestSuite) TestIfThenTrue() { comp := IfThenC(true, T("Shown")) s.Equal("Shown", comp.Render()) @@ -429,7 +429,7 @@ func (s *ComponentTestSuite) TestIfThenFalse() { s.Equal("", comp.Render()) } -// Test Conditional - Unless +// Test Conditional - Unless. func (s *ComponentTestSuite) TestUnlessTrue() { comp := UnlessC(true, T("Hidden")) s.Equal("", comp.Render()) @@ -440,7 +440,7 @@ func (s *ComponentTestSuite) TestUnlessFalse() { s.Equal("Shown", comp.Render()) } -// Test Conditional - Switch +// Test Conditional - Switch. func (s *ComponentTestSuite) TestSwitch() { value := 2 @@ -487,7 +487,7 @@ func (s *ComponentTestSuite) TestMatchRange() { s.Equal("Mid", comp.Render()) } -// Test Show/Hide helpers +// Test Show/Hide helpers. func (s *ComponentTestSuite) TestShow() { s.Equal("Shown", Show(true, T("Shown")).Render()) s.Equal("", Show(false, T("Hidden")).Render()) @@ -498,13 +498,13 @@ func (s *ComponentTestSuite) TestHide() { s.Equal("Shown", Hide(false, T("Shown")).Render()) } -// Test Toggle +// Test Toggle. func (s *ComponentTestSuite) TestToggle() { s.Equal("On", Toggle(true, T("On"), T("Off")).Render()) s.Equal("Off", Toggle(false, T("On"), T("Off")).Render()) } -// Integration tests - Complex compositions +// Integration tests - Complex compositions. func (s *ComponentTestSuite) TestComplexComposition() { comp := VStackC( T("Title").Bold(true), @@ -558,7 +558,7 @@ func (s *ComponentTestSuite) TestNestedConditionals() { s.NotContains(rendered, "Please log in") } -// Test style application +// Test style application. func (s *ComponentTestSuite) TestStyleApplication() { style := lipgloss.NewStyle(). Bold(true). @@ -569,7 +569,7 @@ func (s *ComponentTestSuite) TestStyleApplication() { s.NotEmpty(rendered) } -// Test List with highlighting +// Test List with highlighting. func (s *ComponentTestSuite) TestListWithHighlighting() { items := []ListItem{ Item("SQLite", "sqlite", "Local database"), diff --git a/internal/ui/component/layout.go b/internal/ui/component/layout.go index f0b55e4..2bcef14 100644 --- a/internal/ui/component/layout.go +++ b/internal/ui/component/layout.go @@ -191,13 +191,13 @@ func (z *ZStack) Render() string { // TODO: True overlay support would require more sophisticated rendering // For now, we'll just render them vertically as a fallback // In a real TUI, true overlay would need cursor positioning - for i := 1; i < len(z.children); i++ { + for index := 1; index < len(z.children); index++ { result = lipgloss.Place( lipgloss.Width(result), lipgloss.Height(result), z.alignment, z.alignment, - z.children[i].Render(), + z.children[index].Render(), lipgloss.WithWhitespaceChars(" "), lipgloss.WithWhitespaceForeground(lipgloss.NoColor{}), ) diff --git a/internal/view/router.go b/internal/view/router.go index 51a5f94..4619920 100644 --- a/internal/view/router.go +++ b/internal/view/router.go @@ -55,6 +55,12 @@ func (r *RouterImplementation) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return r, tea.Quit } + // handle esc key + if msg, ok := msg.(tea.KeyMsg); ok && msg.String() == "esc" { + r.Back() + return r, nil + } + if r.currentComponent == nil { return r, nil } diff --git a/internal/view/router_navigation_test.go b/internal/view/router_navigation_test.go index 6ecabce..f00b067 100644 --- a/internal/view/router_navigation_test.go +++ b/internal/view/router_navigation_test.go @@ -83,7 +83,7 @@ func (m HomeModel) View() string { return "Home Page" } -// RouterNavigationTestSuite tests router navigation using teatest +// RouterNavigationTestSuite tests router navigation using teatest. type RouterNavigationTestSuite struct { suite.Suite } @@ -98,14 +98,15 @@ func (s *RouterNavigationTestSuite) getOutput(tm *teatest.TestModel) string { return string(output) } -// TestEnterKeyNavigation tests that pressing Enter navigates to the sub page +// TestEnterKeyNavigation tests that pressing Enter navigates to the sub page. func (s *RouterNavigationTestSuite) TestEnterKeyNavigation() { router := NewRouter() router.AddRoute(Route{Path: "/", Component: func(r Router, sharedMemory storage.SharedMemory) View { return NewPage(r) }}) router.AddRoute(Route{Path: "/page2", Component: func(r Router, sharedMemory storage.SharedMemory) View { return NewSubPage(r) }}) - router.NavigateTo("/", nil) + err := router.NavigateTo("/", nil) + s.NoError(err, "Should navigate to root") - tm := teatest.NewTestModel( + testModel := teatest.NewTestModel( s.T(), router, teatest.WithInitialTermSize(300, 100), @@ -115,7 +116,7 @@ func (s *RouterNavigationTestSuite) TestEnterKeyNavigation() { s.Equal("/", router.GetPath(), "Should be on root path") // Send Enter key to navigate to page2 - tm.Send(tea.KeyMsg{Type: tea.KeyEnter}) + testModel.Send(tea.KeyMsg{Type: tea.KeyEnter}) // Give time for the update to process time.Sleep(200 * time.Millisecond) @@ -124,23 +125,24 @@ func (s *RouterNavigationTestSuite) TestEnterKeyNavigation() { // Note: We verify the navigation worked by checking the router state directly // since the teatest FinalOutput captures the entire terminal session s.Equal("/page2", router.GetPath(), "Should navigate to /page2 after Enter") - s.Contains(s.getOutput(tm), "Sub Page", "Should contain sub page content") + s.Contains(s.getOutput(testModel), "Sub Page", "Should contain sub page content") // Quit - tm.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) - tm.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) + testModel.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) + testModel.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) } -// TestEscKeyNavigation tests that pressing Esc navigates back to the previous page +// TestEscKeyNavigation tests that pressing Esc navigates back to the previous page. func (s *RouterNavigationTestSuite) TestEscKeyNavigation() { router := NewRouter() router.AddRoute(Route{Path: "/", Component: func(r Router, sharedMemory storage.SharedMemory) View { return NewPage(r) }}) router.AddRoute(Route{Path: "/page2", Component: func(r Router, sharedMemory storage.SharedMemory) View { return NewSubPage(r) }}) // Navigate to sub page first - router.NavigateTo("/page2", nil) + err := router.NavigateTo("/page2", nil) + s.NoError(err, "Should navigate to page2") - tm := teatest.NewTestModel( + testModel := teatest.NewTestModel( s.T(), router, teatest.WithInitialTermSize(300, 100), @@ -151,19 +153,21 @@ func (s *RouterNavigationTestSuite) TestEscKeyNavigation() { // Verify we're on sub page s.Equal("/page2", router.GetPath(), "Should be on /page2") - s.Contains(s.getOutput(tm), "Sub Page", "Should contain sub page content") + s.Contains(s.getOutput(testModel), "Sub Page", "Should contain sub page content") // Send Esc key to go back - tm.Send(tea.KeyMsg{Type: tea.KeyEsc}) + testModel.Send(tea.KeyMsg{Type: tea.KeyEsc}) // Give time for the update to process - time.Sleep(200 * time.Millisecond) + time.Sleep(500 * time.Millisecond) // Verify back navigation occurred by checking the router's current path s.Equal("/", router.GetPath(), "Should navigate back to / after Esc") - s.Contains(s.getOutput(tm), "Home Page", "Should contain home page content") + output := s.getOutput(testModel) + s.T().Logf("Output: %s", output) + s.Contains(output, "Home Page", "Should contain home page content") // Quit - tm.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) - tm.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) + testModel.Send(tea.KeyMsg{Type: tea.KeyCtrlC}) + testModel.WaitFinished(s.T(), teatest.WithFinalTimeout(time.Second)) } diff --git a/internal/view/router_test.go b/internal/view/router_test.go index a0cdb5a..4c57096 100644 --- a/internal/view/router_test.go +++ b/internal/view/router_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/suite" ) -// MockView is a mock implementation of View for testing +// MockView is a mock implementation of View for testing. type MockView struct { name string initCalled bool @@ -33,7 +33,7 @@ func (m *MockView) Help() (string, HelpDisplayOption) { return "", HelpDisplayOptionAppend } -// RouterTestSuite is the test suite for Router +// RouterTestSuite is the test suite for Router. type RouterTestSuite struct { suite.Suite router Router @@ -43,7 +43,7 @@ func (suite *RouterTestSuite) SetupTest() { suite.router = NewRouter() } -// TestNewRouter tests the router initialization +// TestNewRouter tests the router initialization. func (suite *RouterTestSuite) TestNewRouter() { router := NewRouter() assert.NotNil(suite.T(), router) @@ -51,7 +51,7 @@ func (suite *RouterTestSuite) TestNewRouter() { assert.Equal(suite.T(), "", router.GetPath()) } -// TestAddRoute tests adding routes +// TestAddRoute tests adding routes. func (suite *RouterTestSuite) TestAddRoute() { mockView := &MockView{name: "home", viewContent: "Home View"} route := Route{ @@ -66,7 +66,7 @@ func (suite *RouterTestSuite) TestAddRoute() { assert.Equal(suite.T(), "/", routes[0].Path) } -// TestSetRoutes tests setting multiple routes at once +// TestSetRoutes tests setting multiple routes at once. func (suite *RouterTestSuite) TestSetRoutes() { mockView1 := &MockView{name: "home", viewContent: "Home View"} mockView2 := &MockView{name: "about", viewContent: "About View"} @@ -84,7 +84,7 @@ func (suite *RouterTestSuite) TestSetRoutes() { assert.Equal(suite.T(), "/about", retrievedRoutes[1].Path) } -// TestRemoveRoute tests removing a route +// TestRemoveRoute tests removing a route. func (suite *RouterTestSuite) TestRemoveRoute() { mockView1 := &MockView{name: "home", viewContent: "Home View"} mockView2 := &MockView{name: "about", viewContent: "About View"} @@ -99,7 +99,7 @@ func (suite *RouterTestSuite) TestRemoveRoute() { assert.Equal(suite.T(), "/", routes[0].Path) } -// TestNavigateTo tests navigation to a route +// TestNavigateTo tests navigation to a route. func (suite *RouterTestSuite) TestNavigateTo() { mockView := &MockView{name: "home", viewContent: "Home View"} suite.router.AddRoute(Route{Path: "/", Component: func(r Router, sharedMemory storage.SharedMemory) View { return mockView }}) @@ -112,7 +112,7 @@ func (suite *RouterTestSuite) TestNavigateTo() { assert.Equal(suite.T(), "/", currentRoute.Path) } -// TestNavigateToWithQueryParams tests navigation with query parameters +// TestNavigateToWithQueryParams tests navigation with query parameters. func (suite *RouterTestSuite) TestNavigateToWithQueryParams() { mockView := &MockView{name: "users", viewContent: "Users View"} suite.router.AddRoute(Route{Path: "/users", Component: func(r Router, sharedMemory storage.SharedMemory) View { return mockView }}) @@ -129,14 +129,14 @@ func (suite *RouterTestSuite) TestNavigateToWithQueryParams() { assert.Equal(suite.T(), "john", suite.router.GetQueryParam("name")) } -// TestNavigateToInvalidRoute tests navigation to non-existent route +// TestNavigateToInvalidRoute tests navigation to non-existent route. func (suite *RouterTestSuite) TestNavigateToInvalidRoute() { err := suite.router.NavigateTo("/nonexistent", nil) assert.Error(suite.T(), err) assert.Contains(suite.T(), err.Error(), "no route found") } -// TestNavigateToWithPathParams tests navigation with path parameters +// TestNavigateToWithPathParams tests navigation with path parameters. func (suite *RouterTestSuite) TestNavigateToWithPathParams() { mockView := &MockView{name: "user", viewContent: "User View"} suite.router.AddRoute(Route{Path: "/users/:id", Component: func(r Router, sharedMemory storage.SharedMemory) View { return mockView }}) @@ -147,7 +147,7 @@ func (suite *RouterTestSuite) TestNavigateToWithPathParams() { assert.Equal(suite.T(), "/users/123", suite.router.GetPath()) } -// TestNavigateToWithMultiplePathParams tests navigation with multiple path parameters +// TestNavigateToWithMultiplePathParams tests navigation with multiple path parameters. func (suite *RouterTestSuite) TestNavigateToWithMultiplePathParams() { mockView := &MockView{name: "comment", viewContent: "Comment View"} suite.router.AddRoute(Route{Path: "/posts/:postId/comments/:commentId", Component: func(r Router, sharedMemory storage.SharedMemory) View { return mockView }}) @@ -158,7 +158,7 @@ func (suite *RouterTestSuite) TestNavigateToWithMultiplePathParams() { assert.Equal(suite.T(), "789", suite.router.GetParam("commentId")) } -// TestReplaceRoute tests replacing the current route +// TestReplaceRoute tests replacing the current route. func (suite *RouterTestSuite) TestReplaceRoute() { mockView1 := &MockView{name: "home", viewContent: "Home View"} mockView2 := &MockView{name: "about", viewContent: "About View"} @@ -178,7 +178,7 @@ func (suite *RouterTestSuite) TestReplaceRoute() { assert.False(suite.T(), suite.router.CanGoBack()) } -// TestBackNavigation tests navigating back +// TestBackNavigation tests navigating back. func (suite *RouterTestSuite) TestBackNavigation() { mockView1 := &MockView{name: "home", viewContent: "Home View"} mockView2 := &MockView{name: "about", viewContent: "About View"} @@ -189,9 +189,12 @@ func (suite *RouterTestSuite) TestBackNavigation() { suite.router.AddRoute(Route{Path: "/contact", Component: func(r Router, sharedMemory storage.SharedMemory) View { return mockView3 }}) // Navigate through routes - suite.router.NavigateTo("/", nil) - suite.router.NavigateTo("/about", nil) - suite.router.NavigateTo("/contact", nil) + err := suite.router.NavigateTo("/", nil) + suite.NoError(err) + err = suite.router.NavigateTo("/about", nil) + suite.NoError(err) + err = suite.router.NavigateTo("/contact", nil) + suite.NoError(err) assert.Equal(suite.T(), "/contact", suite.router.GetPath()) assert.True(suite.T(), suite.router.CanGoBack()) @@ -207,12 +210,13 @@ func (suite *RouterTestSuite) TestBackNavigation() { assert.False(suite.T(), suite.router.CanGoBack()) } -// TestBackWithEmptyStack tests back navigation with empty stack +// TestBackWithEmptyStack tests back navigation with empty stack. func (suite *RouterTestSuite) TestBackWithEmptyStack() { mockView := &MockView{name: "home", viewContent: "Home View"} suite.router.AddRoute(Route{Path: "/", Component: func(r Router, sharedMemory storage.SharedMemory) View { return mockView }}) - suite.router.NavigateTo("/", nil) + err := suite.router.NavigateTo("/", nil) + assert.NoError(suite.T(), err) assert.False(suite.T(), suite.router.CanGoBack()) // Should not panic @@ -220,7 +224,7 @@ func (suite *RouterTestSuite) TestBackWithEmptyStack() { assert.Equal(suite.T(), "/", suite.router.GetPath()) } -// TestCanGoBack tests the CanGoBack method +// TestCanGoBack tests the CanGoBack method. func (suite *RouterTestSuite) TestCanGoBack() { mockView1 := &MockView{name: "home", viewContent: "Home View"} mockView2 := &MockView{name: "about", viewContent: "About View"} @@ -230,65 +234,71 @@ func (suite *RouterTestSuite) TestCanGoBack() { assert.False(suite.T(), suite.router.CanGoBack()) - suite.router.NavigateTo("/", nil) + err := suite.router.NavigateTo("/", nil) + assert.NoError(suite.T(), err) assert.False(suite.T(), suite.router.CanGoBack()) - suite.router.NavigateTo("/about", nil) + err = suite.router.NavigateTo("/about", nil) + assert.NoError(suite.T(), err) assert.True(suite.T(), suite.router.CanGoBack()) } -// TestGetCurrentRoute tests getting the current route +// TestGetCurrentRoute tests getting the current route. func (suite *RouterTestSuite) TestGetCurrentRoute() { mockView := &MockView{name: "home", viewContent: "Home View"} componentFunc := func(r Router, sharedMemory storage.SharedMemory) View { return mockView } route := Route{Path: "/", Component: componentFunc} suite.router.AddRoute(route) - suite.router.NavigateTo("/", nil) + err := suite.router.NavigateTo("/", nil) + assert.NoError(suite.T(), err) currentRoute := suite.router.GetCurrentRoute() assert.Equal(suite.T(), "/", currentRoute.Path) assert.NotNil(suite.T(), currentRoute.Component) } -// TestGetCurrentRouteEmpty tests getting current route when none is set +// TestGetCurrentRouteEmpty tests getting current route when none is set. func (suite *RouterTestSuite) TestGetCurrentRouteEmpty() { currentRoute := suite.router.GetCurrentRoute() assert.Equal(suite.T(), "", currentRoute.Path) assert.Nil(suite.T(), currentRoute.Component) } -// TestGetQueryParamNotFound tests getting a non-existent query parameter +// TestGetQueryParamNotFound tests getting a non-existent query parameter. func (suite *RouterTestSuite) TestGetQueryParamNotFound() { mockView := &MockView{name: "home", viewContent: "Home View"} suite.router.AddRoute(Route{Path: "/", Component: func(r Router, sharedMemory storage.SharedMemory) View { return mockView }}) - suite.router.NavigateTo("/", nil) + err := suite.router.NavigateTo("/", nil) + assert.NoError(suite.T(), err) param := suite.router.GetQueryParam("nonexistent") assert.Equal(suite.T(), "", param) } -// TestGetParamNotFound tests getting a non-existent path parameter +// TestGetParamNotFound tests getting a non-existent path parameter. func (suite *RouterTestSuite) TestGetParamNotFound() { mockView := &MockView{name: "home", viewContent: "Home View"} suite.router.AddRoute(Route{Path: "/", Component: func(r Router, sharedMemory storage.SharedMemory) View { return mockView }}) - suite.router.NavigateTo("/", nil) + err := suite.router.NavigateTo("/", nil) + assert.NoError(suite.T(), err) param := suite.router.GetParam("nonexistent") assert.Equal(suite.T(), "", param) } -// TestGetPathEmpty tests getting path when no route is active +// TestGetPathEmpty tests getting path when no route is active. func (suite *RouterTestSuite) TestGetPathEmpty() { path := suite.router.GetPath() assert.Equal(suite.T(), "", path) } -// TestRefresh tests refreshing the current route +// TestRefresh tests refreshing the current route. func (suite *RouterTestSuite) TestRefresh() { mockView := &MockView{name: "home", viewContent: "Home View"} suite.router.AddRoute(Route{Path: "/", Component: func(r Router, sharedMemory storage.SharedMemory) View { return mockView }}) - suite.router.NavigateTo("/", nil) + err := suite.router.NavigateTo("/", nil) + assert.NoError(suite.T(), err) mockView.initCalled = false suite.router.Refresh() @@ -296,11 +306,12 @@ func (suite *RouterTestSuite) TestRefresh() { assert.True(suite.T(), mockView.initCalled) } -// TestViewMethod tests the View method +// TestViewMethod tests the View method. func (suite *RouterTestSuite) TestViewMethod() { mockView := &MockView{name: "home", viewContent: "Home View Content"} suite.router.AddRoute(Route{Path: "/", Component: func(r Router, sharedMemory storage.SharedMemory) View { return mockView }}) - suite.router.NavigateTo("/", nil) + err := suite.router.NavigateTo("/", nil) + assert.NoError(suite.T(), err) view := suite.router.View() // The view should contain the content wrapped in a box with helper text @@ -308,17 +319,18 @@ func (suite *RouterTestSuite) TestViewMethod() { assert.Contains(suite.T(), view, "Ctrl + c to exit") } -// TestViewMethodNoRoute tests View when no route is active +// TestViewMethodNoRoute tests View when no route is active. func (suite *RouterTestSuite) TestViewMethodNoRoute() { view := suite.router.View() assert.Equal(suite.T(), "No route selected", view) } -// TestInitMethod tests the Init method +// TestInitMethod tests the Init method. func (suite *RouterTestSuite) TestInitMethod() { mockView := &MockView{name: "home", viewContent: "Home View"} suite.router.AddRoute(Route{Path: "/", Component: func(r Router, sharedMemory storage.SharedMemory) View { return mockView }}) - suite.router.NavigateTo("/", nil) + err := suite.router.NavigateTo("/", nil) + assert.NoError(suite.T(), err) mockView.initCalled = false cmd := suite.router.Init() @@ -327,11 +339,12 @@ func (suite *RouterTestSuite) TestInitMethod() { assert.True(suite.T(), mockView.initCalled) } -// TestUpdateMethod tests the Update method +// TestUpdateMethod tests the Update method. func (suite *RouterTestSuite) TestUpdateMethod() { mockView := &MockView{name: "home", viewContent: "Home View"} suite.router.AddRoute(Route{Path: "/", Component: func(r Router, sharedMemory storage.SharedMemory) View { return mockView }}) - suite.router.NavigateTo("/", nil) + err := suite.router.NavigateTo("/", nil) + assert.NoError(suite.T(), err) msg := tea.KeyMsg{Type: tea.KeyEnter} model, cmd := suite.router.Update(msg) @@ -340,7 +353,7 @@ func (suite *RouterTestSuite) TestUpdateMethod() { assert.Nil(suite.T(), cmd) } -// TestMatchPatternExactMatch tests exact route matching +// TestMatchPatternExactMatch tests exact route matching. func (suite *RouterTestSuite) TestMatchPatternExactMatch() { mockView := &MockView{name: "home", viewContent: "Home View"} suite.router.AddRoute(Route{Path: "/exact", Component: func(r Router, sharedMemory storage.SharedMemory) View { return mockView }}) @@ -350,7 +363,7 @@ func (suite *RouterTestSuite) TestMatchPatternExactMatch() { assert.Equal(suite.T(), "/exact", suite.router.GetPath()) } -// TestMatchPatternComplexParams tests complex parameterized routes +// TestMatchPatternComplexParams tests complex parameterized routes. func (suite *RouterTestSuite) TestMatchPatternComplexParams() { mockView := &MockView{name: "complex", viewContent: "Complex View"} suite.router.AddRoute(Route{Path: "/api/:version/users/:userId/posts/:postId", Component: func(r Router, sharedMemory storage.SharedMemory) View { return mockView }}) @@ -362,7 +375,7 @@ func (suite *RouterTestSuite) TestMatchPatternComplexParams() { assert.Equal(suite.T(), "200", suite.router.GetParam("postId")) } -// TestNavigationStackIntegrity tests that navigation stack maintains integrity +// TestNavigationStackIntegrity tests that navigation stack maintains integrity. func (suite *RouterTestSuite) TestNavigationStackIntegrity() { mockView1 := &MockView{name: "view1", viewContent: "View 1"} mockView2 := &MockView{name: "view2", viewContent: "View 2"} @@ -373,9 +386,9 @@ func (suite *RouterTestSuite) TestNavigationStackIntegrity() { suite.router.AddRoute(Route{Path: "/view3", Component: func(r Router, sharedMemory storage.SharedMemory) View { return mockView3 }}) // Navigate forward - suite.router.NavigateTo("/view1", nil) - suite.router.NavigateTo("/view2", nil) - suite.router.NavigateTo("/view3", nil) + _ = suite.router.NavigateTo("/view1", nil) + _ = suite.router.NavigateTo("/view2", nil) + _ = suite.router.NavigateTo("/view3", nil) // Go back twice suite.router.Back() @@ -384,12 +397,12 @@ func (suite *RouterTestSuite) TestNavigationStackIntegrity() { assert.Equal(suite.T(), "/view1", suite.router.GetPath()) // Navigate forward again - suite.router.NavigateTo("/view3", nil) + _ = suite.router.NavigateTo("/view3", nil) assert.Equal(suite.T(), "/view3", suite.router.GetPath()) assert.True(suite.T(), suite.router.CanGoBack()) } -// TestParameterPersistenceAcrossNavigation tests that parameters are preserved during navigation +// TestParameterPersistenceAcrossNavigation tests that parameters are preserved during navigation. func (suite *RouterTestSuite) TestParameterPersistenceAcrossNavigation() { mockView1 := &MockView{name: "users", viewContent: "Users View"} mockView2 := &MockView{name: "posts", viewContent: "Posts View"} @@ -398,13 +411,13 @@ func (suite *RouterTestSuite) TestParameterPersistenceAcrossNavigation() { suite.router.AddRoute(Route{Path: "/posts/:postId", Component: func(r Router, sharedMemory storage.SharedMemory) View { return mockView2 }}) queryParams := map[string]string{"filter": "active"} - suite.router.NavigateTo("/users/123", queryParams) + _ = suite.router.NavigateTo("/users/123", queryParams) assert.Equal(suite.T(), "123", suite.router.GetParam("id")) assert.Equal(suite.T(), "active", suite.router.GetQueryParam("filter")) // Navigate to another route - suite.router.NavigateTo("/posts/456", nil) + _ = suite.router.NavigateTo("/posts/456", nil) assert.Equal(suite.T(), "456", suite.router.GetParam("postId")) assert.Equal(suite.T(), "", suite.router.GetQueryParam("filter")) // Should be cleared @@ -414,7 +427,7 @@ func (suite *RouterTestSuite) TestParameterPersistenceAcrossNavigation() { assert.Equal(suite.T(), "active", suite.router.GetQueryParam("filter")) // Should be restored } -// Run the test suite +// Run the test suite. func TestRouterTestSuite(t *testing.T) { suite.Run(t, new(RouterTestSuite)) } diff --git a/tools/routergen/generator.go b/tools/routergen/generator.go index 3a8f1f7..d2b6b2e 100644 --- a/tools/routergen/generator.go +++ b/tools/routergen/generator.go @@ -104,45 +104,45 @@ func generatePackageAlias(fsPath string) string { // GenerateRoutesFile generates the Go source code for the routes.go file. func GenerateRoutesFile(routes []RouteDefinition, moduleName string) string { - var sb strings.Builder + var strBuilder strings.Builder // Package declaration - sb.WriteString("package app\n\n") + strBuilder.WriteString("package app\n\n") // Imports - sb.WriteString("import (\n") - sb.WriteString("\t\"github.com/rxtech-lab/smart-contract-cli/internal/view\"\n") - sb.WriteString("\t\"github.com/rxtech-lab/smart-contract-cli/internal/storage\"\n") + strBuilder.WriteString("import (\n") + strBuilder.WriteString("\t\"github.com/rxtech-lab/smart-contract-cli/internal/view\"\n") + strBuilder.WriteString("\t\"github.com/rxtech-lab/smart-contract-cli/internal/storage\"\n") // Import each page package (skip root package to avoid import cycle) appPackagePath := moduleName + "/app" for _, route := range routes { // Skip importing the app package itself to avoid circular import if route.PackagePath != appPackagePath { - sb.WriteString(fmt.Sprintf("\t%s \"%s\"\n", route.PackageAlias, route.PackagePath)) + strBuilder.WriteString(fmt.Sprintf("\t%s \"%s\"\n", route.PackageAlias, route.PackagePath)) } } - sb.WriteString(")\n\n") + strBuilder.WriteString(")\n\n") // GetRoutes function - sb.WriteString("// GetRoutes returns all routes generated from the app folder structure.\n") - sb.WriteString("func GetRoutes() []view.Route {\n") - sb.WriteString("\treturn []view.Route{\n") + strBuilder.WriteString("// GetRoutes returns all routes generated from the app folder structure.\n") + strBuilder.WriteString("func GetRoutes() []view.Route {\n") + strBuilder.WriteString("\treturn []view.Route{\n") for _, route := range routes { // For root package, call NewPage() directly without package prefix if route.PackagePath == appPackagePath { - sb.WriteString(fmt.Sprintf("\t\t{Path: %q, Component: func(r view.Router, sharedMemory storage.SharedMemory) view.View { return NewPage(r, sharedMemory) }},\n", route.Path)) + strBuilder.WriteString(fmt.Sprintf("\t\t{Path: %q, Component: func(r view.Router, sharedMemory storage.SharedMemory) view.View { return NewPage(r, sharedMemory) }},\n", route.Path)) } else { - sb.WriteString(fmt.Sprintf("\t\t{Path: %q, Component: func(r view.Router, sharedMemory storage.SharedMemory) view.View { return %s.NewPage(r, sharedMemory) }},\n", + strBuilder.WriteString(fmt.Sprintf("\t\t{Path: %q, Component: func(r view.Router, sharedMemory storage.SharedMemory) view.View { return %s.NewPage(r, sharedMemory) }},\n", route.Path, route.PackageAlias)) } } - sb.WriteString("\t}\n") - sb.WriteString("}\n") + strBuilder.WriteString("\t}\n") + strBuilder.WriteString("}\n") - return sb.String() + return strBuilder.String() } // ConvertAbsoluteToModulePath converts an absolute file path to a module-relative import path. diff --git a/tools/routergen/generator_test.go b/tools/routergen/generator_test.go index 4ad8047..9a1816f 100644 --- a/tools/routergen/generator_test.go +++ b/tools/routergen/generator_test.go @@ -23,7 +23,8 @@ func (s *GeneratorTestSuite) SetupTest() { func (s *GeneratorTestSuite) TearDownTest() { // Clean up temporary directory if s.tempDir != "" { - os.RemoveAll(s.tempDir) + err := os.RemoveAll(s.tempDir) + s.NoError(err, "Should clean up temp directory") } } @@ -31,10 +32,10 @@ func (s *GeneratorTestSuite) createPageFile(relPath string) string { fullPath := filepath.Join(s.tempDir, relPath, "page.go") dir := filepath.Dir(fullPath) - err := os.MkdirAll(dir, 0755) + err := os.MkdirAll(dir, 0750) s.Require().NoError(err) - err = os.WriteFile(fullPath, []byte("package page\n"), 0644) + err = os.WriteFile(fullPath, []byte("package page\n"), 0600) s.Require().NoError(err) return fullPath @@ -182,7 +183,7 @@ func (s *GeneratorTestSuite) TestScanAppFolder_IgnoresNonPageFiles() { // Create non-page.go files that should be ignored otherFile := filepath.Join(s.tempDir, "users", "helper.go") - err := os.WriteFile(otherFile, []byte("package users\n"), 0644) + err := os.WriteFile(otherFile, []byte("package users\n"), 0600) s.Require().NoError(err) routes, err := ScanAppFolder(s.tempDir) diff --git a/tools/routergen/main.go b/tools/routergen/main.go index 4b97488..d68d6c6 100644 --- a/tools/routergen/main.go +++ b/tools/routergen/main.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path/filepath" + "strings" ) func main() { @@ -126,6 +127,12 @@ func printSuccess(outputFile string, routes []RouteDefinition) { func detectModuleName(moduleRoot string) (string, error) { goModPath := filepath.Join(moduleRoot, "go.mod") + // Validate path to prevent directory traversal + cleaned := filepath.Clean(goModPath) + if strings.Contains(cleaned, "..") { + return "", fmt.Errorf("invalid file path: %s", goModPath) + } + data, err := os.ReadFile(goModPath) if err != nil { return "", fmt.Errorf("failed to read go.mod: %w", err)