diff --git a/app.go b/app.go index 8c854b9..61b1ce6 100644 --- a/app.go +++ b/app.go @@ -6,6 +6,7 @@ import ( "os" "strings" + "github.com/patrickdappollonio/kubectl-slice/pkg/template" "github.com/patrickdappollonio/kubectl-slice/slice" "github.com/spf13/cobra" "github.com/spf13/pflag" @@ -101,7 +102,7 @@ func root() *cobra.Command { rootCommand.Flags().StringSliceVar(&opts.InputFolderExt, "extensions", []string{".yaml", ".yml"}, "the extensions to look for in the input folder") rootCommand.Flags().BoolVarP(&opts.Recurse, "recurse", "r", false, "if true, the input folder will be read recursively (has no effect unless used with --input-folder)") rootCommand.Flags().StringVarP(&opts.OutputDirectory, "output-dir", "o", "", "the output directory used to output the splitted files") - rootCommand.Flags().StringVarP(&opts.GoTemplate, "template", "t", slice.DefaultTemplateName, "go template used to generate the file name when creating the resource files in the output directory") + rootCommand.Flags().StringVarP(&opts.GoTemplate, "template", "t", template.DefaultTemplateName, "go template used to generate the file name when creating the resource files in the output directory") rootCommand.Flags().BoolVar(&opts.DryRun, "dry-run", false, "if true, no files are created, but the potentially generated files will be printed as the command output") rootCommand.Flags().BoolVar(&opts.DebugMode, "debug", false, "enable debug mode") rootCommand.Flags().BoolVarP(&opts.Quiet, "quiet", "q", false, "if true, no output is written to stdout/err") diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go new file mode 100644 index 0000000..cd4cf3c --- /dev/null +++ b/pkg/errors/errors.go @@ -0,0 +1,74 @@ +package errors + +import ( + "fmt" + "strings" +) + +// StrictModeSkipErr represents an error when a Kubernetes resource is skipped +// in strict mode because a required field is missing or empty +type StrictModeSkipErr struct { + FieldName string +} + +func (s *StrictModeSkipErr) Error() string { + return fmt.Sprintf( + "resource does not have a Kubernetes %q field or the field is invalid or empty", s.FieldName, + ) +} + +// SkipErr represents an error when a Kubernetes resource is intentionally skipped +// based on user-provided include/exclude filter configuration +type SkipErr struct { + Name string + Kind string + Group string + Reason string +} + +func (e *SkipErr) Error() string { + if e.Name == "" && e.Kind == "" { + if e.Group != "" { + if e.Reason != "" { + return fmt.Sprintf("resource with API group %q is skipped: %s", e.Group, e.Reason) + } + return fmt.Sprintf("resource with API group %q is configured to be skipped", e.Group) + } + return "resource is configured to be skipped" + } + + if e.Reason != "" { + return fmt.Sprintf("resource %s %q is skipped: %s", e.Kind, e.Name, e.Reason) + } + return fmt.Sprintf("resource %s %q is configured to be skipped", e.Kind, e.Name) +} + +// nonKubernetesMessage provides a standard error message for YAML files that don't contain +// standard Kubernetes metadata and are likely not Kubernetes resources +const nonKubernetesMessage = `the file has no Kubernetes metadata: it is most likely a non-Kubernetes YAML file, you can skip it with --skip-non-k8s` + +// CantFindFieldErr represents an error when a required field is missing in a Kubernetes +// resource. It includes contextual information about the file and resource. +type CantFindFieldErr struct { + FieldName string + FileCount int + Meta interface{} +} + +func (e *CantFindFieldErr) Error() string { + var sb strings.Builder + + sb.WriteString(fmt.Sprintf( + "unable to find Kubernetes %q field in file %d", + e.FieldName, e.FileCount, + )) + + // Type assertion to check if Meta has an empty() method + if metaWithEmpty, ok := e.Meta.(interface{ empty() bool }); ok && metaWithEmpty.empty() { + sb.WriteString(": " + nonKubernetesMessage) + } else if meta, ok := e.Meta.(fmt.Stringer); ok { + sb.WriteString(fmt.Sprintf(": %s", meta.String())) + } + + return sb.String() +} diff --git a/pkg/errors/errors_test.go b/pkg/errors/errors_test.go new file mode 100644 index 0000000..e2ad82a --- /dev/null +++ b/pkg/errors/errors_test.go @@ -0,0 +1,174 @@ +package errors + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestStrictModeSkipErr_Error(t *testing.T) { + tests := []struct { + name string + fieldName string + want string + }{ + { + name: "with metadata.name field", + fieldName: "metadata.name", + want: "resource does not have a Kubernetes \"metadata.name\" field or the field is invalid or empty", + }, + { + name: "with kind field", + fieldName: "kind", + want: "resource does not have a Kubernetes \"kind\" field or the field is invalid or empty", + }, + { + name: "with empty field", + fieldName: "", + want: "resource does not have a Kubernetes \"\" field or the field is invalid or empty", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &StrictModeSkipErr{ + FieldName: tt.fieldName, + } + + require.Equal(t, tt.want, s.Error()) + }) + } +} + +func TestSkipErr_Error(t *testing.T) { + tests := []struct { + name string + err SkipErr + want string + }{ + { + name: "with name and kind", + err: SkipErr{ + Name: "my-pod", + Kind: "Pod", + }, + want: "resource Pod \"my-pod\" is configured to be skipped", + }, + { + name: "with name, kind and reason", + err: SkipErr{ + Name: "my-pod", + Kind: "Pod", + Reason: "matched exclusion filter", + }, + want: "resource Pod \"my-pod\" is skipped: matched exclusion filter", + }, + { + name: "with group only", + err: SkipErr{ + Group: "apps/v1", + }, + want: "resource with API group \"apps/v1\" is configured to be skipped", + }, + { + name: "with group and reason", + err: SkipErr{ + Group: "apps/v1", + Reason: "matched exclusion filter", + }, + want: "resource with API group \"apps/v1\" is skipped: matched exclusion filter", + }, + { + name: "empty fields", + err: SkipErr{}, + want: "resource is configured to be skipped", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, tt.err.Error()) + }) + } +} + +// mockMeta implements the empty() method for testing CantFindFieldErr +type mockMeta struct { + isEmpty bool + str string +} + +func (m mockMeta) empty() bool { + return m.isEmpty +} + +func (m mockMeta) String() string { + return m.str +} + +// mockMetaStringOnly implements just the String() method without empty() +type mockMetaStringOnly struct { + str string +} + +func (m mockMetaStringOnly) String() string { + return m.str +} + +func TestErrorsInterface(t *testing.T) { + require.Implementsf(t, (*error)(nil), &StrictModeSkipErr{}, "StrictModeSkipErr should implement error") + require.Implementsf(t, (*error)(nil), &SkipErr{}, "SkipErr should implement error") + require.Implementsf(t, (*error)(nil), &CantFindFieldErr{}, "CantFindFieldErr should implement error") +} + +func TestCantFindFieldErr_Error(t *testing.T) { + tests := []struct { + name string + fieldName string + fileCount int + meta interface{} + want string + }{ + { + name: "with empty meta", + fieldName: "metadata.name", + fileCount: 1, + meta: mockMeta{isEmpty: true}, + want: "unable to find Kubernetes \"metadata.name\" field in file 1: " + nonKubernetesMessage, + }, + { + name: "with non-empty meta with stringer", + fieldName: "metadata.name", + fileCount: 2, + meta: mockMeta{isEmpty: false, str: "Pod/my-pod"}, + want: "unable to find Kubernetes \"metadata.name\" field in file 2: Pod/my-pod", + }, + { + name: "with meta implementing only String", + fieldName: "kind", + fileCount: 3, + meta: mockMetaStringOnly{str: "Kind/Deployment"}, + want: "unable to find Kubernetes \"kind\" field in file 3: Kind/Deployment", + }, + { + name: "with nil meta", + fieldName: "kind", + fileCount: 4, + meta: nil, + want: "unable to find Kubernetes \"kind\" field in file 4", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &CantFindFieldErr{ + FieldName: tt.fieldName, + FileCount: tt.fileCount, + Meta: tt.meta, + } + if got := e.Error(); got != tt.want { + t.Errorf("CantFindFieldErr.Error() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/slice/utils.go b/pkg/files/io.go similarity index 59% rename from slice/utils.go rename to pkg/files/io.go index 739a1a8..e8656bd 100644 --- a/slice/utils.go +++ b/pkg/files/io.go @@ -1,4 +1,4 @@ -package slice +package files import ( "bytes" @@ -6,23 +6,14 @@ import ( "io" "os" "path/filepath" + "slices" "strings" ) -func inarray[T comparable](needle T, haystack []T) bool { - for _, v := range haystack { - if v == needle { - return true - } - } - - return false -} - -// loadfolder reads the folder contents recursively for `.yaml` and `.yml` files -// and returns a buffer with the contents of all files found; returns the buffer -// with all the files separated by `---` and the number of files found -func loadfolder(extensions []string, folderPath string, recurse bool) (*bytes.Buffer, int, error) { +// LoadFolder reads contents from files with matching extensions in the specified folder. +// Returns a buffer with all file contents concatenated with "---" separators between them, +// a count of files processed, and any error encountered. +func LoadFolder(extensions []string, folderPath string, recurse bool) (*bytes.Buffer, int, error) { var buffer bytes.Buffer var count int @@ -39,7 +30,7 @@ func loadfolder(extensions []string, folderPath string, recurse bool) (*bytes.Bu } ext := strings.ToLower(filepath.Ext(path)) - if inarray(ext, extensions) { + if inArray(ext, extensions) { count++ data, err := os.ReadFile(path) @@ -67,8 +58,10 @@ func loadfolder(extensions []string, folderPath string, recurse bool) (*bytes.Bu return &buffer, count, nil } -func loadfile(fp string) (*bytes.Buffer, error) { - f, err := openFile(fp) +// LoadFile reads a file from the filesystem and returns its contents as a buffer. +// Handles errors for file access issues. +func LoadFile(fp string) (*bytes.Buffer, error) { + f, err := OpenFile(fp) if err != nil { return nil, err } @@ -83,14 +76,13 @@ func loadfile(fp string) (*bytes.Buffer, error) { return &buf, nil } -func openFile(fp string) (*os.File, error) { - if fp == os.Stdin.Name() { - // On Windows, the name in Go for stdin is `/dev/stdin` which doesn't - // exist. It must use the syscall to point to the file and open it +// OpenFile opens a file for reading with special handling for stdin. +// When the filename is "-", it returns os.Stdin instead of attempting to open a file. +func OpenFile(fp string) (*os.File, error) { + if fp == os.Stdin.Name() || fp == "-" { return os.Stdin, nil } - // Any other file that's not stdin can be opened normally f, err := os.Open(fp) if err != nil { return nil, fmt.Errorf("unable to open file %q: %s", fp, err.Error()) @@ -99,7 +91,9 @@ func openFile(fp string) (*os.File, error) { return f, nil } -func deleteFolderContents(location string) error { +// DeleteFolderContents removes all files and subdirectories within the specified directory. +// The directory itself is preserved. +func DeleteFolderContents(location string) error { f, err := os.Open(location) if err != nil { return fmt.Errorf("unable to open folder %q: %s", location, err.Error()) @@ -119,3 +113,8 @@ func deleteFolderContents(location string) error { return nil } + +// inArray checks if an element exists in a slice +func inArray[T comparable](needle T, haystack []T) bool { + return slices.Contains(haystack, needle) +} diff --git a/pkg/files/io_test.go b/pkg/files/io_test.go new file mode 100644 index 0000000..8e43a27 --- /dev/null +++ b/pkg/files/io_test.go @@ -0,0 +1,226 @@ +package files + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInArray(t *testing.T) { + tests := []struct { + name string + needle string + haystack []string + expected bool + }{ + { + name: "found in array", + needle: ".yaml", + haystack: []string{".yml", ".yaml", ".json"}, + expected: true, + }, + { + name: "not found in array", + needle: ".txt", + haystack: []string{".yml", ".yaml", ".json"}, + expected: false, + }, + { + name: "empty array", + needle: ".yaml", + haystack: []string{}, + expected: false, + }, + { + name: "empty needle", + needle: "", + haystack: []string{".yml", ".yaml", ".json"}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := inArray(tt.needle, tt.haystack) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestOpenFile(t *testing.T) { + t.Run("open stdin when filename matches stdin name", func(t *testing.T) { + f, err := OpenFile(os.Stdin.Name()) + require.NoError(t, err) + assert.Equal(t, os.Stdin, f) + }) + + t.Run("open stdin when filename is dash", func(t *testing.T) { + f, err := OpenFile("-") + require.NoError(t, err) + assert.Equal(t, os.Stdin, f) + }) + + t.Run("open existing file", func(t *testing.T) { + // Create a temporary file + tmpFile, err := os.CreateTemp("", "test-open-file-*.txt") + require.NoError(t, err) + defer os.Remove(tmpFile.Name()) + tmpFile.Close() + + // Open the file using our function + f, err := OpenFile(tmpFile.Name()) + require.NoError(t, err) + defer f.Close() + + assert.NotNil(t, f) + }) + + t.Run("error opening non-existent file", func(t *testing.T) { + _, err := OpenFile("/non/existent/file.txt") + assert.Error(t, err) + assert.Contains(t, err.Error(), "unable to open file") + }) +} + +func TestLoadFile(t *testing.T) { + t.Run("load existing file", func(t *testing.T) { + // Create a temporary file with content + content := []byte("test content") + tmpFile, err := os.CreateTemp("", "test-load-file-*.txt") + require.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + _, err = tmpFile.Write(content) + require.NoError(t, err) + tmpFile.Close() + + // Load the file + buf, err := LoadFile(tmpFile.Name()) + require.NoError(t, err) + assert.Equal(t, "test content", buf.String()) + }) + + t.Run("error loading non-existent file", func(t *testing.T) { + _, err := LoadFile("/non/existent/file.txt") + assert.Error(t, err) + }) +} + +func TestLoadFolder(t *testing.T) { + t.Run("load files from folder with matching extensions", func(t *testing.T) { + // Create a temporary directory + tmpDir, err := os.MkdirTemp("", "test-load-folder-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + // Create files with different extensions + files := map[string]string{ + "file1.yaml": "content1", + "file2.yml": "content2", + "file3.json": "content3", + "file4.txt": "content4", + } + + for name, content := range files { + err := os.WriteFile(filepath.Join(tmpDir, name), []byte(content), 0644) + require.NoError(t, err) + } + + // Test loading only yaml/yml files + extensions := []string{".yaml", ".yml"} + buf, count, err := LoadFolder(extensions, tmpDir, false) + require.NoError(t, err) + assert.Equal(t, 2, count) + + // The buffer should contain content1 and content2 with separator + // Order is not guaranteed, so we check for both possibilities + expected1 := "content1\n---\ncontent2" + expected2 := "content2\n---\ncontent1" + bufStr := buf.String() + assert.True(t, bufStr == expected1 || bufStr == expected2, "Expected buffer to contain concatenated yaml/yml file contents") + }) + + t.Run("load files recursively", func(t *testing.T) { + // Create a temporary directory structure + tmpDir, err := os.MkdirTemp("", "test-load-folder-recursive-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + // Create subdirectory + subDir := filepath.Join(tmpDir, "subdir") + require.NoError(t, os.Mkdir(subDir, 0755)) + + // Create files in main and sub directory + require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "main.yaml"), []byte("main-content"), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(subDir, "sub.yaml"), []byte("sub-content"), 0644)) + + // Test with recursion enabled + extensions := []string{".yaml"} + buf, count, err := LoadFolder(extensions, tmpDir, true) + require.NoError(t, err) + assert.Equal(t, 2, count) + assert.Contains(t, buf.String(), "main-content") + assert.Contains(t, buf.String(), "sub-content") + + // Test with recursion disabled + buf, count, err = LoadFolder(extensions, tmpDir, false) + require.NoError(t, err) + assert.Equal(t, 1, count) + assert.Contains(t, buf.String(), "main-content") + assert.NotContains(t, buf.String(), "sub-content") + }) + + t.Run("error when no matching files found", func(t *testing.T) { + // Create a temporary directory + tmpDir, err := os.MkdirTemp("", "test-load-folder-empty-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + // Test with extensions that don't match any files + extensions := []string{".xyz"} + _, _, err = LoadFolder(extensions, tmpDir, false) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no files found") + }) + + t.Run("error with non-existent folder", func(t *testing.T) { + extensions := []string{".yaml"} + _, _, err := LoadFolder(extensions, "/non/existent/folder", false) + assert.Error(t, err) + }) +} + +func TestDeleteFolderContents(t *testing.T) { + t.Run("delete folder contents", func(t *testing.T) { + // Create a temporary directory with contents + tmpDir, err := os.MkdirTemp("", "test-delete-folder-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + // Create files and subdirectory + require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("content"), 0644)) + subDir := filepath.Join(tmpDir, "subdir") + require.NoError(t, os.Mkdir(subDir, 0755)) + require.NoError(t, os.WriteFile(filepath.Join(subDir, "file2.txt"), []byte("content"), 0644)) + + // Delete contents + err = DeleteFolderContents(tmpDir) + require.NoError(t, err) + + // Verify contents are deleted but directory still exists + entries, err := os.ReadDir(tmpDir) + require.NoError(t, err) + assert.Empty(t, entries) + _, err = os.Stat(tmpDir) + assert.NoError(t, err) + }) + + t.Run("error with non-existent folder", func(t *testing.T) { + err := DeleteFolderContents("/non/existent/folder") + assert.Error(t, err) + assert.Contains(t, err.Error(), "unable to open folder") + }) +} diff --git a/pkg/kubernetes/metadata.go b/pkg/kubernetes/metadata.go new file mode 100644 index 0000000..fb0eb15 --- /dev/null +++ b/pkg/kubernetes/metadata.go @@ -0,0 +1,71 @@ +package kubernetes + +import ( + "strings" +) + +// ObjectMeta represents the metadata for a Kubernetes object +type ObjectMeta struct { + APIVersion string + Kind string + Name string + Namespace string + Group string +} + +// GetGroupFromAPIVersion extracts the group from the APIVersion field +func (k *ObjectMeta) GetGroupFromAPIVersion() string { + fields := strings.Split(k.APIVersion, "/") + if len(fields) == 2 { + return strings.ToLower(fields[0]) + } + + return "" +} + +// Empty checks if all fields in the metadata are empty +func (k *ObjectMeta) Empty() bool { + return k.APIVersion == "" && k.Kind == "" && k.Name == "" && k.Namespace == "" +} + +// String returns a string representation of the metadata +func (k *ObjectMeta) String() string { + return strings.TrimSpace(strings.Join([]string{ + "kind " + k.Kind, + "name " + k.Name, + "apiVersion " + k.APIVersion, + }, ", ")) +} + +// CheckStringInMap checks if a string is in a map, and returns its value if found +func CheckStringInMap(local map[string]interface{}, key string) string { + iface, found := local[key] + + if !found { + return "" + } + + str, ok := iface.(string) + if !ok { + return "" + } + + return str +} + +// ExtractMetadata extracts Kubernetes metadata from a YAML manifest +func ExtractMetadata(manifest map[string]interface{}) *ObjectMeta { + metadata := &ObjectMeta{ + APIVersion: CheckStringInMap(manifest, "apiVersion"), + Kind: CheckStringInMap(manifest, "kind"), + } + + if md, found := manifest["metadata"]; found { + if mdMap, ok := md.(map[string]interface{}); ok { + metadata.Name = CheckStringInMap(mdMap, "name") + metadata.Namespace = CheckStringInMap(mdMap, "namespace") + } + } + + return metadata +} diff --git a/pkg/kubernetes/metadata_test.go b/pkg/kubernetes/metadata_test.go new file mode 100644 index 0000000..796b4a1 --- /dev/null +++ b/pkg/kubernetes/metadata_test.go @@ -0,0 +1,276 @@ +package kubernetes + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetGroupFromAPIVersion(t *testing.T) { + tests := []struct { + name string + apiVersion string + want string + }{ + { + name: "with group", + apiVersion: "apps/v1", + want: "apps", + }, + { + name: "core group", + apiVersion: "v1", + want: "", + }, + { + name: "empty string", + apiVersion: "", + want: "", + }, + { + name: "multiple slashes", + apiVersion: "example.com/group/v1", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &ObjectMeta{ + APIVersion: tt.apiVersion, + } + require.Equal(t, tt.want, k.GetGroupFromAPIVersion()) + }) + } +} + +func TestEmpty(t *testing.T) { + tests := []struct { + name string + metadata ObjectMeta + want bool + }{ + { + name: "empty metadata", + metadata: ObjectMeta{}, + want: true, + }, + { + name: "only APIVersion", + metadata: ObjectMeta{ + APIVersion: "v1", + }, + want: false, + }, + { + name: "only Kind", + metadata: ObjectMeta{ + Kind: "Pod", + }, + want: false, + }, + { + name: "only Name", + metadata: ObjectMeta{ + Name: "test-pod", + }, + want: false, + }, + { + name: "only Namespace", + metadata: ObjectMeta{ + Namespace: "default", + }, + want: false, + }, + { + name: "fully populated", + metadata: ObjectMeta{ + APIVersion: "v1", + Kind: "Pod", + Name: "test-pod", + Namespace: "default", + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, tt.metadata.Empty()) + }) + } +} + +func TestString(t *testing.T) { + tests := []struct { + name string + metadata ObjectMeta + want string + }{ + { + name: "empty metadata", + metadata: ObjectMeta{}, + want: "kind , name , apiVersion", + }, + { + name: "fully populated", + metadata: ObjectMeta{ + APIVersion: "v1", + Kind: "Pod", + Name: "test-pod", + Namespace: "default", + }, + want: "kind Pod, name test-pod, apiVersion v1", + }, + { + name: "partial data", + metadata: ObjectMeta{ + APIVersion: "apps/v1", + Kind: "Deployment", + }, + want: "kind Deployment, name , apiVersion apps/v1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, tt.metadata.String()) + }) + } +} + +func TestCheckStringInMap(t *testing.T) { + tests := []struct { + name string + local map[string]interface{} + key string + want string + }{ + { + name: "empty map", + local: map[string]interface{}{}, + key: "test", + want: "", + }, + { + name: "key exists with string value", + local: map[string]interface{}{ + "test": "value", + }, + key: "test", + want: "value", + }, + { + name: "key exists with non-string value", + local: map[string]interface{}{ + "test": 123, + }, + key: "test", + want: "", + }, + { + name: "key doesn't exist", + local: map[string]interface{}{ + "other": "value", + }, + key: "test", + want: "", + }, + { + name: "multiple keys", + local: map[string]interface{}{ + "key1": "value1", + "key2": "value2", + }, + key: "key2", + want: "value2", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, CheckStringInMap(tt.local, tt.key)) + }) + } +} + +func TestExtractMetadata(t *testing.T) { + tests := []struct { + name string + manifest map[string]interface{} + want ObjectMeta + }{ + { + name: "empty manifest", + manifest: map[string]interface{}{}, + want: ObjectMeta{}, + }, + { + name: "minimal manifest", + manifest: map[string]interface{}{ + "apiVersion": "v1", + "kind": "Pod", + }, + want: ObjectMeta{ + APIVersion: "v1", + Kind: "Pod", + }, + }, + { + name: "full manifest", + manifest: map[string]interface{}{ + "apiVersion": "apps/v1", + "kind": "Deployment", + "metadata": map[string]interface{}{ + "name": "test-deployment", + "namespace": "default", + }, + }, + want: ObjectMeta{ + APIVersion: "apps/v1", + Kind: "Deployment", + Name: "test-deployment", + Namespace: "default", + }, + }, + { + name: "metadata is not a map", + manifest: map[string]interface{}{ + "apiVersion": "v1", + "kind": "Pod", + "metadata": "invalid", + }, + want: ObjectMeta{ + APIVersion: "v1", + Kind: "Pod", + }, + }, + { + name: "metadata with non-string values", + manifest: map[string]interface{}{ + "apiVersion": "v1", + "kind": "Pod", + "metadata": map[string]interface{}{ + "name": 123, + "namespace": true, + }, + }, + want: ObjectMeta{ + APIVersion: "v1", + Kind: "Pod", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ExtractMetadata(tt.manifest) + + require.Equal(t, tt.want.APIVersion, got.APIVersion) + require.Equal(t, tt.want.Kind, got.Kind) + require.Equal(t, tt.want.Name, got.Name) + require.Equal(t, tt.want.Namespace, got.Namespace) + }) + } +} diff --git a/slice/kube.go b/pkg/kubernetes/sorting.go similarity index 53% rename from slice/kube.go rename to pkg/kubernetes/sorting.go index 723a0ee..95a2419 100644 --- a/slice/kube.go +++ b/pkg/kubernetes/sorting.go @@ -1,39 +1,20 @@ -package slice +package kubernetes import ( "sort" - "strings" ) -type yamlFile struct { - filename string - meta kubeObjectMeta - data []byte -} - -type kubeObjectMeta struct { - APIVersion string - Kind string - Name string - Namespace string - Group string -} - -func (objectMeta *kubeObjectMeta) GetGroupFromAPIVersion() string { - fields := strings.Split(objectMeta.APIVersion, "/") - if len(fields) == 2 { - return strings.ToLower(fields[0]) - } - - return "" -} - -func (k kubeObjectMeta) empty() bool { - return k.APIVersion == "" && k.Kind == "" && k.Name == "" && k.Namespace == "" +// YAMLFile represents a Kubernetes YAML file with associated metadata and content. +// It's used throughout the application for storing and processing YAML resources. +type YAMLFile struct { + Filename string + Meta *ObjectMeta + Data []byte } +// HelmInstallOrder defines the order in which Kubernetes resources should be installed // from: https://github.com/helm/helm/blob/v3.11.1/pkg/releaseutil/kind_sorter.go#LL31-L67C2 -var helmInstallOrder = []string{ +var HelmInstallOrder = []string{ "Namespace", "NetworkPolicy", "ResourceQuota", @@ -71,19 +52,24 @@ var helmInstallOrder = []string{ "APIService", } -// from: https://github.com/helm/helm/blob/v3.11.1/pkg/releaseutil/kind_sorter.go#L113-L119 -func sortYAMLsByKind(manifests []yamlFile) []yamlFile { +// SortByKind sorts a slice of YAMLFile according to Kubernetes resource kind ordering +func SortByKind(manifests []YAMLFile) []YAMLFile { sort.SliceStable(manifests, func(i, j int) bool { - return lessByKind(manifests[i], manifests[j], manifests[i].meta.Kind, manifests[j].meta.Kind, helmInstallOrder) + return lessByKind( + manifests[i].Meta.Kind, + manifests[j].Meta.Kind, + HelmInstallOrder, + ) }) return manifests } +// lessByKind compares two kinds and determines their relative order // from: https://github.com/helm/helm/blob/v3.11.1/pkg/releaseutil/kind_sorter.go#L133-L158 -func lessByKind(_ interface{}, _ interface{}, kindA string, kindB string, o []string) bool { - ordering := make(map[string]int, len(o)) - for v, k := range o { +func lessByKind(kindA, kindB string, order []string) bool { + ordering := make(map[string]int, len(order)) + for v, k := range order { ordering[k] = v } @@ -91,7 +77,7 @@ func lessByKind(_ interface{}, _ interface{}, kindA string, kindB string, o []st second, bok := ordering[kindB] if !aok && !bok { - // if both are unknown then sort alphabetically by kind, keep original order if same kind + // if both are unknown then sort alphabetically by kind if kindA != kindB { return kindA < kindB } diff --git a/pkg/kubernetes/sorting_test.go b/pkg/kubernetes/sorting_test.go new file mode 100644 index 0000000..d1671f2 --- /dev/null +++ b/pkg/kubernetes/sorting_test.go @@ -0,0 +1,265 @@ +package kubernetes + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestLessByKind(t *testing.T) { + tests := []struct { + name string + kindA string + kindB string + order []string + want bool + }{ + { + name: "both kinds in order, kindA before kindB", + kindA: "Namespace", + kindB: "Service", + order: HelmInstallOrder, + want: true, // Namespace comes before Service + }, + { + name: "both kinds in order, kindB before kindA", + kindA: "Service", + kindB: "Namespace", + order: HelmInstallOrder, + want: false, // Service comes after Namespace + }, + { + name: "same kinds", + kindA: "Service", + kindB: "Service", + order: HelmInstallOrder, + want: false, // Same kinds should maintain order + }, + { + name: "kindA in order, kindB not in order", + kindA: "Namespace", + kindB: "UnknownKind", + order: HelmInstallOrder, + want: true, // Known kinds come before unknown + }, + { + name: "kindA not in order, kindB in order", + kindA: "UnknownKind", + kindB: "Namespace", + order: HelmInstallOrder, + want: false, // Unknown kinds come after known + }, + { + name: "neither kind in order, alphabetical first", + kindA: "AAA", + kindB: "ZZZ", + order: HelmInstallOrder, + want: true, // Alphabetical ordering for unknown kinds + }, + { + name: "neither kind in order, alphabetical second", + kindA: "ZZZ", + kindB: "AAA", + order: HelmInstallOrder, + want: false, // Alphabetical ordering for unknown kinds + }, + { + name: "both kinds unknown and equal", + kindA: "Unknown", + kindB: "Unknown", + order: HelmInstallOrder, + want: false, // Same kinds should maintain order + }, + { + name: "empty order slice", + kindA: "Service", + kindB: "Namespace", + order: []string{}, + want: false, // Alphabetical ordering when no order specified (N before S) + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := lessByKind(tt.kindA, tt.kindB, tt.order) + require.Equal(t, tt.want, result) + }) + } +} + +func TestSortByKind(t *testing.T) { + tests := []struct { + name string + manifests []YAMLFile + want []string // Expected order of kinds after sorting + }{ + { + name: "already sorted manifests", + manifests: []YAMLFile{ + { + Filename: "namespace.yaml", + Meta: &ObjectMeta{ + Kind: "Namespace", + }, + }, + { + Filename: "service.yaml", + Meta: &ObjectMeta{ + Kind: "Service", + }, + }, + { + Filename: "deployment.yaml", + Meta: &ObjectMeta{ + Kind: "Deployment", + }, + }, + }, + want: []string{"Namespace", "Service", "Deployment"}, + }, + { + name: "unsorted manifests", + manifests: []YAMLFile{ + { + Filename: "deployment.yaml", + Meta: &ObjectMeta{ + Kind: "Deployment", + }, + }, + { + Filename: "namespace.yaml", + Meta: &ObjectMeta{ + Kind: "Namespace", + }, + }, + { + Filename: "service.yaml", + Meta: &ObjectMeta{ + Kind: "Service", + }, + }, + }, + want: []string{"Namespace", "Service", "Deployment"}, + }, + { + name: "with unknown kinds", + manifests: []YAMLFile{ + { + Filename: "deployment.yaml", + Meta: &ObjectMeta{ + Kind: "Deployment", + }, + }, + { + Filename: "unknown.yaml", + Meta: &ObjectMeta{ + Kind: "UnknownKind", + }, + }, + { + Filename: "namespace.yaml", + Meta: &ObjectMeta{ + Kind: "Namespace", + }, + }, + }, + want: []string{"Namespace", "Deployment", "UnknownKind"}, + }, + { + name: "multiple of same kind", + manifests: []YAMLFile{ + { + Filename: "deployment1.yaml", + Meta: &ObjectMeta{ + Kind: "Deployment", + }, + }, + { + Filename: "deployment2.yaml", + Meta: &ObjectMeta{ + Kind: "Deployment", + }, + }, + { + Filename: "namespace.yaml", + Meta: &ObjectMeta{ + Kind: "Namespace", + }, + }, + }, + want: []string{"Namespace", "Deployment", "Deployment"}, + }, + { + name: "all unknown kinds - alphabetical", + manifests: []YAMLFile{ + { + Filename: "c.yaml", + Meta: &ObjectMeta{ + Kind: "C", + }, + }, + { + Filename: "a.yaml", + Meta: &ObjectMeta{ + Kind: "A", + }, + }, + { + Filename: "b.yaml", + Meta: &ObjectMeta{ + Kind: "B", + }, + }, + }, + want: []string{"A", "B", "C"}, + }, + { + name: "empty manifests", + manifests: []YAMLFile{}, + want: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sorted := SortByKind(tt.manifests) + + // Check if the length matches + require.Equal(t, len(tt.want), len(sorted)) + + // Verify the order of kinds + if len(sorted) > 0 { + kinds := make([]string, len(sorted)) + for i, manifest := range sorted { + kinds[i] = manifest.Meta.Kind + } + require.Equal(t, tt.want, kinds) + } + }) + } +} + +// TestHelmInstallOrder verifies that the predefined HelmInstallOrder slice is correctly defined +func TestHelmInstallOrder(t *testing.T) { + // Verify some key ordering relationships from the Helm install order + require.Contains(t, HelmInstallOrder, "Namespace") + require.Contains(t, HelmInstallOrder, "Deployment") + require.Contains(t, HelmInstallOrder, "Service") + + // Namespace should come before Service + namespaceIndex := -1 + serviceIndex := -1 + + for i, kind := range HelmInstallOrder { + if kind == "Namespace" { + namespaceIndex = i + } + if kind == "Service" { + serviceIndex = i + } + } + + require.NotEqual(t, -1, namespaceIndex) + require.NotEqual(t, -1, serviceIndex) + require.Less(t, namespaceIndex, serviceIndex) +} diff --git a/pkg/kubernetes/validation.go b/pkg/kubernetes/validation.go new file mode 100644 index 0000000..7dc3bb4 --- /dev/null +++ b/pkg/kubernetes/validation.go @@ -0,0 +1,71 @@ +package kubernetes + +import ( + "fmt" + "strings" + + "github.com/patrickdappollonio/kubectl-slice/pkg/errors" +) + +// CheckGroupInclusion validates if a resource belongs to any of the specified groups +// Returns nil if the resource should be included, or an error if it should be skipped +func CheckGroupInclusion(objmeta *ObjectMeta, groupNames []string, included bool) error { + resourceGroup := objmeta.GetGroupFromAPIVersion() + + for _, group := range groupNames { + if included { + if resourceGroup == strings.ToLower(group) { + return nil + } + } else { + if resourceGroup == strings.ToLower(group) { + return &errors.SkipErr{ + Name: objmeta.Name, + Kind: objmeta.Kind, + Group: resourceGroup, + Reason: fmt.Sprintf("matches excluded group %q", group), + } + } + } + } + + if included { + var reason string + if len(groupNames) > 0 { + reason = fmt.Sprintf("does not match any included groups %v", groupNames) + } else { + reason = "no included groups specified" + } + + return &errors.SkipErr{ + Name: objmeta.Name, + Kind: objmeta.Kind, + Group: resourceGroup, + Reason: reason, + } + } + + return nil +} + +// ValidateRequiredFields verifies if a resource has all required Kubernetes fields +// when operating in strict mode +func ValidateRequiredFields(meta *ObjectMeta, strictMode bool) error { + if !strictMode { + return nil + } + + if meta.APIVersion == "" { + return &errors.StrictModeSkipErr{FieldName: "apiVersion"} + } + + if meta.Kind == "" { + return &errors.StrictModeSkipErr{FieldName: "kind"} + } + + if meta.Name == "" { + return &errors.StrictModeSkipErr{FieldName: "metadata.name"} + } + + return nil +} diff --git a/pkg/kubernetes/validation_test.go b/pkg/kubernetes/validation_test.go new file mode 100644 index 0000000..e730d70 --- /dev/null +++ b/pkg/kubernetes/validation_test.go @@ -0,0 +1,260 @@ +package kubernetes + +import ( + "errors" + "testing" + + apperrors "github.com/patrickdappollonio/kubectl-slice/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestCheckGroupInclusion(t *testing.T) { + tests := []struct { + name string + objmeta *ObjectMeta + groupNames []string + included bool + expectSkipErr bool + expectedGroup string + expectedReason string + }{ + { + name: "included mode - group matches", + objmeta: &ObjectMeta{ + APIVersion: "apps/v1", + Kind: "Deployment", + Name: "test-app", + }, + groupNames: []string{"apps"}, + included: true, + expectSkipErr: false, + }, + { + name: "included mode - group doesn't match", + objmeta: &ObjectMeta{ + APIVersion: "apps/v1", + Kind: "Deployment", + Name: "test-app", + }, + groupNames: []string{"networking.k8s.io"}, + included: true, + expectSkipErr: true, + expectedGroup: "apps", + expectedReason: "does not match any included groups [networking.k8s.io]", + }, + { + name: "included mode - core group", + objmeta: &ObjectMeta{ + APIVersion: "v1", + Kind: "Pod", + Name: "test-pod", + }, + groupNames: []string{""}, + included: true, + expectSkipErr: false, + }, + { + name: "excluded mode - group matches", + objmeta: &ObjectMeta{ + APIVersion: "apps/v1", + Kind: "Deployment", + Name: "test-app", + }, + groupNames: []string{"apps"}, + included: false, + expectSkipErr: true, + expectedGroup: "apps", + expectedReason: "matches excluded group \"apps\"", + }, + { + name: "excluded mode - group doesn't match", + objmeta: &ObjectMeta{ + APIVersion: "apps/v1", + Kind: "Deployment", + Name: "test-app", + }, + groupNames: []string{"networking.k8s.io"}, + included: false, + expectSkipErr: false, + }, + { + name: "multiple groups - included mode - one matches", + objmeta: &ObjectMeta{ + APIVersion: "apps/v1", + Kind: "Deployment", + Name: "test-app", + }, + groupNames: []string{"networking.k8s.io", "apps", "batch"}, + included: true, + expectSkipErr: false, + }, + { + name: "multiple groups - excluded mode - one matches", + objmeta: &ObjectMeta{ + APIVersion: "apps/v1", + Kind: "Deployment", + Name: "test-app", + }, + groupNames: []string{"networking.k8s.io", "apps", "batch"}, + included: false, + expectSkipErr: true, + expectedGroup: "apps", + expectedReason: "matches excluded group \"apps\"", + }, + { + name: "case insensitive comparison", + objmeta: &ObjectMeta{ + APIVersion: "Apps/v1", + Kind: "Deployment", + Name: "test-app", + }, + groupNames: []string{"apps"}, + included: true, + expectSkipErr: false, + }, + { + name: "empty group list - included mode", + objmeta: &ObjectMeta{ + APIVersion: "apps/v1", + Kind: "Deployment", + Name: "test-app", + }, + groupNames: []string{}, + included: true, + expectSkipErr: true, + expectedGroup: "apps", + expectedReason: "no included groups specified", + }, + { + name: "empty group list - excluded mode", + objmeta: &ObjectMeta{ + APIVersion: "apps/v1", + Kind: "Deployment", + Name: "test-app", + }, + groupNames: []string{}, + included: false, + expectSkipErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckGroupInclusion(tt.objmeta, tt.groupNames, tt.included) + + if tt.expectSkipErr { + require.Error(t, err) + + // Use errors.As for type checking instead of type assertion + var skipErr *apperrors.SkipErr + require.True(t, errors.As(err, &skipErr), "Expected error of type *errors.SkipErr") + + // Additional checks for the error details when we have expected values + if tt.expectedGroup != "" { + require.Equal(t, tt.expectedGroup, skipErr.Group) + + if tt.expectedReason != "" { + require.Equal(t, tt.expectedReason, skipErr.Reason) + } + + require.Equal(t, tt.objmeta.Kind, skipErr.Kind) + require.Equal(t, tt.objmeta.Name, skipErr.Name) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestValidateRequiredFields(t *testing.T) { + tests := []struct { + name string + meta *ObjectMeta + strictMode bool + expectErr bool + expectedField string + }{ + { + name: "strict mode - all fields present", + meta: &ObjectMeta{ + APIVersion: "apps/v1", + Kind: "Deployment", + Name: "test-deployment", + }, + strictMode: true, + expectErr: false, + }, + { + name: "non-strict mode - all fields present", + meta: &ObjectMeta{ + APIVersion: "apps/v1", + Kind: "Deployment", + Name: "test-deployment", + }, + strictMode: false, + expectErr: false, + }, + { + name: "non-strict mode - missing fields", + meta: &ObjectMeta{}, + strictMode: false, + expectErr: false, + }, + { + name: "strict mode - missing apiVersion", + meta: &ObjectMeta{ + Kind: "Deployment", + Name: "test-deployment", + }, + strictMode: true, + expectErr: true, + expectedField: "apiVersion", + }, + { + name: "strict mode - missing kind", + meta: &ObjectMeta{ + APIVersion: "apps/v1", + Name: "test-deployment", + }, + strictMode: true, + expectErr: true, + expectedField: "kind", + }, + { + name: "strict mode - missing name", + meta: &ObjectMeta{ + APIVersion: "apps/v1", + Kind: "Deployment", + }, + strictMode: true, + expectErr: true, + expectedField: "metadata.name", + }, + { + name: "strict mode - namespace optional", + meta: &ObjectMeta{ + APIVersion: "apps/v1", + Kind: "Deployment", + Name: "test-deployment", + }, + strictMode: true, + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateRequiredFields(tt.meta, tt.strictMode) + + if tt.expectErr { + require.Error(t, err) + var strictErr *apperrors.StrictModeSkipErr + require.True(t, errors.As(err, &strictErr), "Expected error of type *errors.StrictModeSkipErr") + require.Equal(t, tt.expectedField, strictErr.FieldName) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go new file mode 100644 index 0000000..433b005 --- /dev/null +++ b/pkg/logger/logger.go @@ -0,0 +1,21 @@ +package logger + +import "io" + +// Logger is the interface used by Split to log debug messages +// and it's satisfied by Go's log.Logger +type Logger interface { + Printf(format string, v ...any) + SetOutput(w io.Writer) + Println(v ...any) +} + +// NOOP implements the Logger interface but performs no operations +type NOOP struct{} + +func (NOOP) Println(...any) {} +func (NOOP) SetOutput(_ io.Writer) {} +func (NOOP) Printf(_ string, _ ...any) {} + +// NOOPLogger is a pre-initialized no-operation logger instance that can be used for testing or when logging is disabled +var NOOPLogger Logger = &NOOP{} diff --git a/pkg/logger/logger_test.go b/pkg/logger/logger_test.go new file mode 100644 index 0000000..02ba731 --- /dev/null +++ b/pkg/logger/logger_test.go @@ -0,0 +1,52 @@ +package logger + +import ( + "io" + "testing" + + "github.com/stretchr/testify/require" +) + +type testLogger struct { + FnPrintln func(...any) + FnSetOutput func(w io.Writer) + FnPrintf func(format string, v ...any) + PrintlnCallCount int + SetOutputCallCount int + PrintfCallCount int +} + +func (tl *testLogger) Println(v ...any) { + tl.PrintlnCallCount++ + if tl.FnPrintln != nil { + tl.FnPrintln(v...) + } +} + +func (tl *testLogger) SetOutput(w io.Writer) { + tl.SetOutputCallCount++ + if tl.FnSetOutput != nil { + tl.FnSetOutput(w) + } +} + +func (tl *testLogger) Printf(format string, v ...any) { + tl.PrintfCallCount++ + if tl.FnPrintf != nil { + tl.FnPrintf(format, v...) + } +} + +func TestLogger(t *testing.T) { + var _ Logger = (*NOOP)(nil) // ensure NOOP satisfies Logger + + t.Run("call count", func(t *testing.T) { + tl := &testLogger{} + tl.Println("foo") + tl.SetOutput(io.Discard) + tl.Printf("foo", "bar") + require.Equal(t, 1, tl.PrintlnCallCount) + require.Equal(t, 1, tl.SetOutputCallCount) + require.Equal(t, 1, tl.PrintfCallCount) + }) +} diff --git a/pkg/template/funcs.go b/pkg/template/funcs.go new file mode 100644 index 0000000..477ac32 --- /dev/null +++ b/pkg/template/funcs.go @@ -0,0 +1,232 @@ +package template + +import ( + "crypto/sha1" + "crypto/sha256" + "encoding/hex" + "fmt" + "os" + "regexp" + "strings" + "text/template" + + "golang.org/x/text/cases" + "golang.org/x/text/language" + "gopkg.in/yaml.v3" +) + +// GetTemplateFunctions returns a map of functions that can be used in Go templates. +// These functions provide string manipulation, conversion, and other utilities +// for customizing the output filenames during the slice operation. +func GetTemplateFunctions() template.FuncMap { + return template.FuncMap{ + "pluralize": Pluralize, + "lower": jsonLower, + "lowercase": jsonLower, + "uppercase": jsonUpper, + "upper": jsonUpper, + "title": jsonTitle, + "sprintf": fmt.Sprintf, + "printf": fmt.Sprintf, + "trim": jsonTrimSpace, + "trimPrefix": jsonTrimPrefix, + "trimSuffix": jsonTrimSuffix, + "default": fnDefault, + "sha1sum": sha1sum, + "sha256sum": sha256sum, + "str": strJSON, + "required": jsonRequired, + "env": env, + "replace": jsonReplace, + "alphanumify": jsonAlphanumify, + "alphanumdash": jsonAlphanumdash, + "dottodash": jsonDotToDash, + "dottounder": jsonDotToUnder, + "index": mapValueByIndex, + "indexOrEmpty": mapValueByIndexOrEmpty, + } +} + +// Pluralize adds an "s" to the end of a string if n is not 1. +// This is useful for generating grammatically correct output when dealing with counts. +func Pluralize(s string, n int) string { + if n == 1 { + return s + } + return s + "s" +} + +// mapValueByIndexOrEmpty retrieves a value from a map without returning an error if the key is not found. +func mapValueByIndexOrEmpty(index string, m map[string]interface{}) interface{} { + if m == nil { + return "" + } + + if index == "" { + return "" + } + + v, ok := m[index] + if !ok { + return "" + } + + return v +} + +// mapValueByIndex retrieves a value from a map and returns an error if the key is not found. +func mapValueByIndex(index string, m map[string]interface{}) (interface{}, error) { + if m == nil { + return nil, fmt.Errorf("map is nil") + } + + if index == "" { + return nil, fmt.Errorf("map key is empty") + } + + v, ok := m[index] + if !ok { + return nil, fmt.Errorf("key %q not found", index) + } + + return v, nil +} + +// jsonLower converts string input to lowercase. It handles various input types +// by converting them to strings first. +func jsonLower(s interface{}) string { + return strings.ToLower(toString(s)) +} + +// jsonUpper converts string input to uppercase. It handles various input types +// by converting them to strings first. +func jsonUpper(s interface{}) string { + return strings.ToUpper(toString(s)) +} + +// jsonTitle converts string input to title case. It handles various input types +// by converting them to strings first, then applies proper title casing. +func jsonTitle(s interface{}) string { + return cases.Title(language.Und).String(toString(s)) +} + +// jsonTrimSpace trims whitespace from a string +func jsonTrimSpace(s interface{}) string { + return strings.TrimSpace(toString(s)) +} + +// jsonTrimPrefix trims a prefix from a string +func jsonTrimPrefix(prefix, s interface{}) string { + return strings.TrimPrefix(toString(s), toString(prefix)) +} + +// jsonTrimSuffix trims a suffix from a string +func jsonTrimSuffix(suffix, s interface{}) string { + return strings.TrimSuffix(toString(s), toString(suffix)) +} + +// fnDefault returns the default value if the original is empty +func fnDefault(def, orig interface{}) interface{} { + s := toString(orig) + if s != "" { + return s + } + + return def +} + +// sha1sum returns the SHA-1 hash of a string +func sha1sum(s interface{}) string { + sum := sha1.Sum([]byte(toString(s))) + return hex.EncodeToString(sum[:]) +} + +// sha256sum returns the SHA-256 hash of a string +func sha256sum(s interface{}) string { + sum := sha256.Sum256([]byte(toString(s))) + return hex.EncodeToString(sum[:]) +} + +// strJSON converts an object to a JSON string +func strJSON(v interface{}) string { + b, err := yaml.Marshal(v) + if err != nil { + return "" + } + + return string(b) +} + +// jsonRequired returns an error if the value is empty +func jsonRequired(warn string, val interface{}) (interface{}, error) { + if val == nil { + return val, fmt.Errorf("%s", warn) + } + + s := toString(val) + if s == "" { + return val, fmt.Errorf("%s", warn) + } + + return val, nil +} + +// env returns the value of an environment variable +func env(key interface{}) string { + return os.Getenv(toString(key)) +} + +// jsonReplace replaces all occurrences of a substring +func jsonReplace(old, new string, src interface{}) string { + return strings.ReplaceAll(toString(src), old, new) +} + +// alphanumRegex is a regular expression that matches only alphanumeric characters +var alphanumRegex = regexp.MustCompile(`[^a-zA-Z0-9]+`) + +// jsonAlphanumify returns only alphanumeric characters +func jsonAlphanumify(src interface{}) string { + return alphanumRegex.ReplaceAllString(toString(src), "") +} + +// alphanumDashRegex is a regular expression that matches alphanumeric characters and dashes +var alphanumDashRegex = regexp.MustCompile(`[^a-zA-Z0-9-]+`) + +// jsonAlphanumdash filters string input to contain only alphanumeric characters and dashes. +// All other characters are removed from the string. Useful for generating safe filenames. +func jsonAlphanumdash(src interface{}) string { + s := toString(src) + s = strings.ReplaceAll(s, "_", "-") + s = strings.ReplaceAll(s, ".", "-") + return alphanumDashRegex.ReplaceAllString(s, "") +} + +// jsonDotToDash replaces dots with dashes +func jsonDotToDash(src interface{}) string { + return strings.ReplaceAll(toString(src), ".", "-") +} + +// jsonDotToUnder replaces dots with underscores +func jsonDotToUnder(src interface{}) string { + return strings.ReplaceAll(toString(src), ".", "_") +} + +// toString converts an interface to a string +func toString(s interface{}) string { + if s == nil { + return "" + } + + switch v := s.(type) { + case string: + return v + case []byte: + return string(v) + case error: + return v.Error() + case fmt.Stringer: + return v.String() + default: + return fmt.Sprintf("%v", s) + } +} diff --git a/pkg/template/funcs_test.go b/pkg/template/funcs_test.go new file mode 100644 index 0000000..1439190 --- /dev/null +++ b/pkg/template/funcs_test.go @@ -0,0 +1,270 @@ +package template + +import ( + "math/rand/v2" + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_mapValueByIndexEmpty(t *testing.T) { + tests := []struct { + name string + index string + m map[string]any + want any + }{ + { + name: "nil map", + index: "foo", + m: nil, + want: "", + }, + { + name: "empty index", + index: "", + m: map[string]any{}, + want: "", + }, + { + name: "key not found", + index: "foo", + m: map[string]any{"bar": "baz"}, + want: "", + }, + { + name: "key found", + index: "foo", + m: map[string]any{"foo": "bar"}, + want: "bar", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := mapValueByIndexOrEmpty(tt.index, tt.m) + require.Equal(t, tt.want, got) + }) + } +} + +func Test_mapValueByIndex(t *testing.T) { + tests := []struct { + name string + index string + m map[string]any + want any + wantErr bool + }{ + { + name: "nil map", + index: "foo", + m: nil, + want: nil, + wantErr: true, + }, + { + name: "empty index", + index: "", + m: map[string]any{}, + want: nil, + wantErr: true, + }, + { + name: "key not found", + index: "foo", + m: map[string]any{"bar": "baz"}, + want: nil, + wantErr: true, + }, + { + name: "key found", + index: "foo", + m: map[string]any{"foo": "bar"}, + want: "bar", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := mapValueByIndex(tt.index, tt.m) + if tt.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.Equal(t, tt.want, got) + }) + } +} + +func Test_jsonLower(t *testing.T) { + require.Equal(t, "foo", jsonLower("FOO")) + require.Equal(t, "foo", jsonLower("foo")) + require.Equal(t, "123", jsonLower(123)) + require.Equal(t, "", jsonLower(nil)) +} + +func Test_jsonUpper(t *testing.T) { + require.Equal(t, "FOO", jsonUpper("foo")) + require.Equal(t, "FOO", jsonUpper("FOO")) + require.Equal(t, "123", jsonUpper(123)) + require.Equal(t, "", jsonUpper(nil)) +} + +func Test_jsonTitle(t *testing.T) { + require.Equal(t, "Foo", jsonTitle("foo")) + require.Equal(t, "Foo", jsonTitle("FOO")) + require.Equal(t, "123", jsonTitle(123)) + require.Equal(t, "", jsonTitle(nil)) +} + +func Test_jsonTrimSpace(t *testing.T) { + require.Equal(t, "foo", jsonTrimSpace(" foo ")) + require.Equal(t, "foo", jsonTrimSpace("foo")) + require.Equal(t, "123", jsonTrimSpace(123)) + require.Equal(t, "", jsonTrimSpace(nil)) +} + +func Test_jsonTrimPrefix(t *testing.T) { + require.Equal(t, "bar", jsonTrimPrefix("foo", "foobar")) + require.Equal(t, "bar", jsonTrimPrefix("foo", "bar")) + require.Equal(t, "123", jsonTrimPrefix("foo", 123)) + require.Equal(t, "", jsonTrimPrefix("foo", nil)) +} + +func Test_jsonTrimSuffix(t *testing.T) { + require.Equal(t, "foo", jsonTrimSuffix("bar", "foobar")) + require.Equal(t, "foo", jsonTrimSuffix("bar", "foo")) + require.Equal(t, "123", jsonTrimSuffix("bar", 123)) + require.Equal(t, "", jsonTrimSuffix("bar", nil)) +} + +func Test_fnDefault(t *testing.T) { + require.Equal(t, "foo", fnDefault("foo", "")) + require.Equal(t, "bar", fnDefault("foo", "bar")) + require.Equal(t, "123", fnDefault("foo", 123)) + require.Equal(t, "foo", fnDefault("foo", nil)) +} + +func Test_sha1sum(t *testing.T) { + require.Equal(t, "0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33", sha1sum("foo")) + require.Equal(t, "da39a3ee5e6b4b0d3255bfef95601890afd80709", sha1sum("")) + require.Equal(t, "da39a3ee5e6b4b0d3255bfef95601890afd80709", sha1sum(nil)) +} + +func Test_sha256sum(t *testing.T) { + require.Equal(t, "2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae", sha256sum("foo")) + require.Equal(t, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", sha256sum("")) + require.Equal(t, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", sha256sum(nil)) +} + +func Test_strJSON(t *testing.T) { + require.Equal(t, "foo\n", strJSON("foo")) + require.Equal(t, "123\n", strJSON(123)) + require.Equal(t, "null\n", strJSON(nil)) +} + +func Test_jsonRequired(t *testing.T) { + v, err := jsonRequired("value is required", "foo") + require.NoError(t, err) + require.Equal(t, "foo", v) + + v, err = jsonRequired("value is required", "") + require.Error(t, err) + require.Equal(t, "", v) + + v, err = jsonRequired("value is required", nil) + require.Error(t, err) + require.Equal(t, nil, v) +} + +func Test_env(t *testing.T) { + os.Setenv("TEST_ENV_VAR", "foo") + defer os.Unsetenv("TEST_ENV_VAR") + + require.Equal(t, "foo", env("TEST_ENV_VAR")) + require.Equal(t, "", env("NONEXISTENT_ENV_VAR")) + require.Equal(t, "", env("")) +} + +func Test_jsonReplace(t *testing.T) { + require.Equal(t, "foobaz", jsonReplace("bar", "baz", "foobar")) + require.Equal(t, "123", jsonReplace("bar", "baz", 123)) + require.Equal(t, "", jsonReplace("bar", "baz", nil)) +} + +func Test_jsonAlphanumify(t *testing.T) { + require.Equal(t, "foobar123", jsonAlphanumify("foo-bar-123")) + require.Equal(t, "foobar123", jsonAlphanumify("foo_bar_123")) + require.Equal(t, "foobar123", jsonAlphanumify("foo.bar.123")) + require.Equal(t, "123", jsonAlphanumify(123)) + require.Equal(t, "", jsonAlphanumify(nil)) +} + +func Test_jsonAlphanumdash(t *testing.T) { + require.Equal(t, "foo-bar-123", jsonAlphanumdash("foo-bar-123")) + require.Equal(t, "foo-bar-123", jsonAlphanumdash("foo_bar_123")) + require.Equal(t, "foo-bar-123", jsonAlphanumdash("foo.bar.123")) + require.Equal(t, "123", jsonAlphanumdash(123)) + require.Equal(t, "", jsonAlphanumdash(nil)) +} + +func Test_jsonDotToDash(t *testing.T) { + require.Equal(t, "foo-bar-123", jsonDotToDash("foo.bar.123")) + require.Equal(t, "foo_bar_123", jsonDotToDash("foo_bar_123")) + require.Equal(t, "123", jsonDotToDash(123)) + require.Equal(t, "", jsonDotToDash(nil)) +} + +func Test_jsonDotToUnder(t *testing.T) { + require.Equal(t, "foo_bar_123", jsonDotToUnder("foo.bar.123")) + require.Equal(t, "foo-bar-123", jsonDotToUnder("foo-bar-123")) + require.Equal(t, "123", jsonDotToUnder(123)) + require.Equal(t, "", jsonDotToUnder(nil)) +} + +func Test_Pluralize(t *testing.T) { + require.Equal(t, "foo", Pluralize("foo", 1)) + require.Equal(t, "foos", Pluralize("foo", 0)) + require.Equal(t, "foos", Pluralize("foo", 2)) +} + +func Test_toString(t *testing.T) { + require.Equal(t, "foo", toString("foo")) + require.Equal(t, "foo", toString([]byte("foo"))) + require.Equal(t, "123", toString(123)) + require.Equal(t, "", toString(nil)) + + // Test error type + errMsg := randomString(10) + require.Equal(t, errMsg, toString(errorString(errMsg))) + + // Test fmt.Stringer + stringerMsg := randomString(10) + require.Equal(t, stringerMsg, toString(stringStringer(stringerMsg))) +} + +type errorString string + +func (e errorString) Error() string { + return string(e) +} + +type stringStringer string + +func (s stringStringer) String() string { + return string(s) +} + +func randomString(n int) string { + const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + b := make([]byte, n) + for i := range b { + b[i] = letterBytes[int(rand.Int32N(int32(len(letterBytes))))] + } + return string(b) +} diff --git a/pkg/template/renderer.go b/pkg/template/renderer.go new file mode 100644 index 0000000..faa60ee --- /dev/null +++ b/pkg/template/renderer.go @@ -0,0 +1,59 @@ +package template + +import ( + "bytes" + "fmt" + "strings" + "text/template" +) + +// DefaultTemplateName is the default template for file naming +const DefaultTemplateName = "{{.kind | lower}}-{{.metadata.name}}.yaml" + +// Renderer handles template rendering for file names +type Renderer struct { + tmpl *template.Template +} + +// New creates a new template renderer +func New(templateString string) (*Renderer, error) { + if templateString == "" { + templateString = DefaultTemplateName + } + + tmpl, err := template.New("filename").Funcs(GetTemplateFunctions()).Parse(templateString) + if err != nil { + return nil, fmt.Errorf("unable to parse template: %w", err) + } + + return &Renderer{ + tmpl: tmpl, + }, nil +} + +// Execute renders a template with the given data +func (r *Renderer) Execute(data any) (string, error) { + var buf bytes.Buffer + if err := r.tmpl.Execute(&buf, data); err != nil { + return "", improveExecError(err) + } + + // Get the rendered filename + name := strings.TrimSpace(buf.String()) + + // Fix for text/template Go issue #24963, as well as removing any linebreaks + name = strings.NewReplacer("", "", "\n", "").Replace(name) + + return name, nil +} + +// improveExecError enhances template execution error messages. +// This uses string comparisons since the Go template engine does not return +// typed error messages. +func improveExecError(err error) error { + if strings.Contains(err.Error(), "can't evaluate field") { + return fmt.Errorf("%w (this usually means the field does not exist in the YAML)", err) + } + + return err +} diff --git a/pkg/template/renderer_test.go b/pkg/template/renderer_test.go new file mode 100644 index 0000000..54cfee0 --- /dev/null +++ b/pkg/template/renderer_test.go @@ -0,0 +1,179 @@ +package template + +import ( + "testing" + "text/template" + + "github.com/stretchr/testify/require" +) + +func TestNew(t *testing.T) { + tests := []struct { + name string + templateString string + wantErr bool + }{ + { + name: "empty template string uses default", + templateString: "", + wantErr: false, + }, + { + name: "valid template string", + templateString: "{{.kind}}-{{.metadata.name}}", + wantErr: false, + }, + { + name: "invalid template string", + templateString: "{{.kind", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + renderer, err := New(tt.templateString) + if tt.wantErr { + require.Error(t, err) + require.Nil(t, renderer) + return + } + + require.NoError(t, err) + require.NotNil(t, renderer) + + if tt.templateString == "" { + require.NotNil(t, renderer.tmpl) + } else { + require.NotNil(t, renderer.tmpl) + } + }) + } +} + +func TestRenderer_Execute(t *testing.T) { + tests := []struct { + name string + templateString string + data map[string]any + want string + wantErr bool + }{ + { + name: "simple template", + templateString: "{{.kind}}-{{.metadata.name}}", + data: map[string]any{ + "kind": "Deployment", + "metadata": map[string]any{ + "name": "nginx", + }, + }, + want: "Deployment-nginx", + wantErr: false, + }, + { + name: "handle missing fields", + templateString: "{{.kind}}-{{.metadata.name}}", + data: map[string]any{ + "kind": "Deployment", + // metadata is missing + }, + want: "Deployment-", + wantErr: false, + }, + { + name: "error accessing non-existent field", + templateString: "{{.kind}}-{{.missing.field}}", + data: map[string]any{ + "kind": "Deployment", + }, + want: "Deployment-", + wantErr: false, + }, + { + name: "handle replacement", + templateString: "{{.kind}}-{{.nonexistent}}", + data: map[string]any{ + "kind": "Deployment", + }, + want: "Deployment-", + wantErr: false, + }, + { + name: "trim spaces", + templateString: "{{ .kind }}-{{ .metadata.name }}", + data: map[string]any{ + "kind": "Deployment", + "metadata": map[string]any{ + "name": "nginx", + }, + }, + want: "Deployment-nginx", + wantErr: false, + }, + { + name: "handle line breaks", + templateString: "{{.kind}}\n{{.metadata.name}}", + data: map[string]any{ + "kind": "Deployment", + "metadata": map[string]any{ + "name": "nginx", + }, + }, + want: "Deploymentnginx", + wantErr: false, + }, + { + name: "with template functions", + templateString: "{{.kind | lower}}-{{.metadata.name}}", + data: map[string]any{ + "kind": "Deployment", + "metadata": map[string]any{ + "name": "nginx", + }, + }, + want: "deployment-nginx", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + renderer, err := New(tt.templateString) + require.NoError(t, err) + + got, err := renderer.Execute(tt.data) + if tt.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.Equal(t, tt.want, got) + }) + } +} + +func TestRenderer_ExecuteErrorHandling(t *testing.T) { + // Create a template that will trigger an error when a non-map is treated as a map + tmpl, err := template.New("filename").Funcs(GetTemplateFunctions()).Parse("{{ .metadata.name.invalid }}") + require.NoError(t, err) + + renderer := &Renderer{tmpl: tmpl} + + // Create data with metadata.name as a string (not a map) + data := map[string]any{ + "metadata": map[string]any{ + "name": "test-name", // This is a string, not a map + }, + } + + // Execute should fail because we're trying to access .invalid on a string value + result, err := renderer.Execute(data) + + // Verify the error behavior + require.Equal(t, "", result) + require.Error(t, err) + require.Contains(t, err.Error(), "can't evaluate field invalid") + require.Contains(t, err.Error(), "this usually means the field does not exist in the YAML") +} diff --git a/slice/errors.go b/slice/errors.go deleted file mode 100644 index 5f6ed10..0000000 --- a/slice/errors.go +++ /dev/null @@ -1,53 +0,0 @@ -package slice - -import ( - "fmt" - "strings" -) - -type strictModeSkipErr struct { - fieldName string -} - -func (s *strictModeSkipErr) Error() string { - return fmt.Sprintf( - "resource does not have a Kubernetes %q field or the field is invalid or empty", s.fieldName, - ) -} - -type skipErr struct { - name string - kind string -} - -func (e *skipErr) Error() string { - return fmt.Sprintf("resource %s %q is configured to be skipped", e.kind, e.name) -} - -const nonK8sHelper = `the file has no Kubernetes metadata: it is most likely a non-Kubernetes YAML file, you can skip it with --skip-non-k8s` - -type cantFindFieldErr struct { - fieldName string - fileCount int - meta kubeObjectMeta -} - -func (e *cantFindFieldErr) Error() string { - var sb strings.Builder - - sb.WriteString(fmt.Sprintf( - "unable to find Kubernetes %q field in file %d", - e.fieldName, e.fileCount, - )) - - if e.meta.empty() { - sb.WriteString(": " + nonK8sHelper) - } else { - sb.WriteString(fmt.Sprintf( - ": processed details: kind %q, name %q, apiVersion %q", - e.meta.Kind, e.meta.Name, e.meta.APIVersion, - )) - } - - return sb.String() -} diff --git a/slice/errors_test.go b/slice/errors_test.go deleted file mode 100644 index b3faf1d..0000000 --- a/slice/errors_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package slice - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestErrorsInterface(t *testing.T) { - require.Implementsf(t, (*error)(nil), &strictModeSkipErr{}, "strictModeSkipErr should implement error") - require.Implementsf(t, (*error)(nil), &skipErr{}, "skipErr should implement error") - require.Implementsf(t, (*error)(nil), &cantFindFieldErr{}, "cantFindFieldErr should implement error") -} - -func requireErrorIf(t *testing.T, wantErr bool, err error) { - if wantErr { - require.Error(t, err) - } else { - require.NoError(t, err) - } -} diff --git a/slice/execute.go b/slice/execute.go index 584c865..112fb7f 100644 --- a/slice/execute.go +++ b/slice/execute.go @@ -7,6 +7,11 @@ import ( "io" "os" "path/filepath" + + "github.com/patrickdappollonio/kubectl-slice/pkg/errors" + "github.com/patrickdappollonio/kubectl-slice/pkg/files" + "github.com/patrickdappollonio/kubectl-slice/pkg/kubernetes" + "github.com/patrickdappollonio/kubectl-slice/pkg/template" ) const ( @@ -37,11 +42,11 @@ func (s *Split) processSingleFile(file []byte) error { meta, err := s.parseYAMLManifest(file) if err != nil { switch err.(type) { - case *skipErr: + case *errors.SkipErr: s.log.Printf("Skipping file %d: %s", s.fileCount, err.Error()) return nil - case *strictModeSkipErr: + case *errors.StrictModeSkipErr: s.log.Printf("Skipping file %d: %s", s.fileCount, err.Error()) return nil @@ -51,29 +56,29 @@ func (s *Split) processSingleFile(file []byte) error { } existentData, position := []byte(nil), -1 - for pos := 0; pos < len(s.filesFound); pos++ { - if s.filesFound[pos].filename == meta.filename { - existentData = s.filesFound[pos].data + for pos := range s.filesFound { + if s.filesFound[pos].Filename == meta.Filename { + existentData = s.filesFound[pos].Data position = pos break } } if position == -1 { - s.log.Printf("Got nonexistent file. Adding it to the list: %s", meta.filename) - s.filesFound = append(s.filesFound, yamlFile{ - filename: meta.filename, - meta: meta.meta, - data: file, + s.log.Printf("Got nonexistent file. Adding it to the list: %s", meta.Filename) + s.filesFound = append(s.filesFound, kubernetes.YAMLFile{ + Filename: meta.Filename, + Meta: meta.Meta, + Data: file, }) } else { - s.log.Printf("Got existent file. Appending to original buffer: %s", meta.filename) + s.log.Printf("Got existent file. Appending to original buffer: %s", meta.Filename) existentData = append(existentData, []byte("\n---\n")...) existentData = append(existentData, file...) - s.filesFound[position] = yamlFile{ - filename: meta.filename, - meta: meta.meta, - data: existentData, + s.filesFound[position] = kubernetes.YAMLFile{ + Filename: meta.Filename, + Meta: meta.Meta, + Data: existentData, } } @@ -85,7 +90,7 @@ func (s *Split) scan() error { // duplicated files, we need to store them somewhere to, later, save them // to files s.fileCount = 0 - s.filesFound = make([]yamlFile, 0) + s.filesFound = make([]kubernetes.YAMLFile, 0) // We can totally create a single decoder then decode using that, however, // we want to maintain 1:1 exactly the same declaration as the YAML originally @@ -159,7 +164,7 @@ func (s *Split) store() error { // Check if the directory exists and if it does, prune it if _, err := os.Stat(s.opts.OutputDirectory); !os.IsNotExist(err) { s.log.Printf("Pruning output directory %q", s.opts.OutputDirectory) - if err := deleteFolderContents(s.opts.OutputDirectory); err != nil { + if err := files.DeleteFolderContents(s.opts.OutputDirectory); err != nil { return fmt.Errorf("unable to prune output directory %q: %w", s.opts.OutputDirectory, err) } s.log.Printf("Output directory %q pruned", s.opts.OutputDirectory) @@ -172,8 +177,8 @@ func (s *Split) store() error { for _, v := range s.filesFound { s.fileCount++ - fullpath := filepath.Join(s.opts.OutputDirectory, v.filename) - fileLength := len(v.data) + fullpath := filepath.Join(s.opts.OutputDirectory, v.Filename) + fileLength := len(v.Data) s.log.Printf("Handling file %q: %d bytes long.", fullpath, fileLength) @@ -191,18 +196,18 @@ func (s *Split) store() error { s.WriteStdout("# File: %s (%d bytes)", fullpath, fileLength) } - s.WriteStdout("%s", v.data) + s.WriteStdout("%s", v.Data) continue default: - local := make([]byte, 0, len(v.data)+4) + local := make([]byte, 0, len(v.Data)+4) // If the user wants to include the triple dash, add it // at the beginning of the file - if s.opts.IncludeTripleDash && !bytes.Equal(v.data, []byte("---")) { - local = append([]byte("---\n"), v.data...) + if s.opts.IncludeTripleDash && !bytes.Equal(v.Data, []byte("---")) { + local = append([]byte("---\n"), v.Data...) } else { - local = append(local, v.data...) + local = append(local, v.Data...) } // do nothing, handling below @@ -217,13 +222,13 @@ func (s *Split) store() error { switch { case s.opts.DryRun: - s.WriteStderr("%d %s generated (dry-run)", s.fileCount, pluralize("file", s.fileCount)) + s.WriteStderr("%d %s generated (dry-run)", s.fileCount, template.Pluralize("file", s.fileCount)) case s.opts.OutputToStdout: - s.WriteStderr("%d %s parsed to stdout.", s.fileCount, pluralize("file", s.fileCount)) + s.WriteStderr("%d %s parsed to stdout.", s.fileCount, template.Pluralize("file", s.fileCount)) default: - s.WriteStderr("%d %s generated.", s.fileCount, pluralize("file", s.fileCount)) + s.WriteStderr("%d %s generated.", s.fileCount, template.Pluralize("file", s.fileCount)) } return nil @@ -231,12 +236,14 @@ func (s *Split) store() error { func (s *Split) sort() { if s.opts.SortByKind { - s.filesFound = sortYAMLsByKind(s.filesFound) + s.filesFound = kubernetes.SortByKind(s.filesFound) } } -// Execute runs the process according to the split.Options provided. This will -// generate the files in the given directory. +// Execute processes YAML files containing Kubernetes resources and splits them into +// individual files according to the configured Options. It handles the complete workflow +// from scanning input sources, filtering resources based on criteria, to saving the +// resulting files in the specified output location. func (s *Split) Execute() error { if err := s.scan(); err != nil { return err @@ -251,7 +258,7 @@ func (s *Split) writeToFile(path string, data []byte) error { // Since a single Go Template File Name might render different folder prefixes, // we need to ensure they're all created. if err := os.MkdirAll(filepath.Dir(path), folderChmod); err != nil { - return fmt.Errorf("unable to create output folder for file %q: %w", path, err) + return fmt.Errorf("unable to create directory for file %q: %w", path, err) } // Open the file as read/write, create the file if it doesn't exist, and if diff --git a/slice/execute_test.go b/slice/execute_test.go index cafb6d0..f009e0a 100644 --- a/slice/execute_test.go +++ b/slice/execute_test.go @@ -2,23 +2,30 @@ package slice import ( "io" + "log" "os" "path/filepath" "testing" - "text/template" - local "github.com/patrickdappollonio/kubectl-slice/slice/template" + "github.com/patrickdappollonio/kubectl-slice/pkg/kubernetes" + "github.com/patrickdappollonio/kubectl-slice/pkg/logger" + "github.com/patrickdappollonio/kubectl-slice/pkg/template" "github.com/stretchr/testify/require" ) func TestExecute_processSingleFile(t *testing.T) { + type yamlFileOutputTest struct { + filename string + meta kubernetes.ObjectMeta + } + tests := []struct { name string fields Options fileInput string wantErr bool wantFilterErr bool - fileOutput *yamlFile + fileOutput *yamlFileOutputTest }{ { name: "basic pod", @@ -29,25 +36,23 @@ kind: Pod metadata: name: nginx-ingress `, - fileOutput: &yamlFile{ + fileOutput: &yamlFileOutputTest{ filename: "pod-nginx-ingress.yaml", - meta: kubeObjectMeta{ + meta: kubernetes.ObjectMeta{ APIVersion: "v1", Kind: "Pod", Name: "nginx-ingress", }, }, }, - // ---------------------------------------------------------------- { name: "empty file", fields: Options{}, fileInput: `---`, - fileOutput: &yamlFile{ + fileOutput: &yamlFileOutputTest{ filename: "-.yaml", }, }, - // ---------------------------------------------------------------- { name: "include kind", fields: Options{ @@ -59,9 +64,9 @@ kind: Pod metadata: name: nginx-ingress `, - fileOutput: &yamlFile{ + fileOutput: &yamlFileOutputTest{ filename: "pod-nginx-ingress.yaml", - meta: kubeObjectMeta{ + meta: kubernetes.ObjectMeta{ APIVersion: "v1", Kind: "Pod", Name: "nginx-ingress", @@ -79,16 +84,15 @@ kind: Pod metadata: name: nginx-ingress `, - fileOutput: &yamlFile{ + fileOutput: &yamlFileOutputTest{ filename: "pod-nginx-ingress.yaml", - meta: kubeObjectMeta{ + meta: kubernetes.ObjectMeta{ APIVersion: "v1", Kind: "Pod", Name: "nginx-ingress", }, }, }, - // ---------------------------------------------------------------- { name: "non kubernetes files skipped using strict kubernetes", fields: Options{ @@ -100,7 +104,6 @@ metadata: # `, }, - // ---------------------------------------------------------------- { name: "non kubernetes file", fields: Options{}, @@ -109,18 +112,16 @@ metadata: # This is a comment # `, - fileOutput: &yamlFile{ + fileOutput: &yamlFileOutputTest{ filename: "-.yaml", }, }, - // ---------------------------------------------------------------- { name: "file with only spaces", fields: Options{}, fileInput: ` `, }, - // ---------------------------------------------------------------- { name: "skipping kind", fields: Options{ @@ -133,7 +134,6 @@ metadata: name: foobar `, }, - // ---------------------------------------------------------------- { name: "skipping name", fields: Options{ @@ -146,14 +146,12 @@ metadata: name: foobar `, }, - // ---------------------------------------------------------------- { name: "invalid YAML", fields: Options{}, fileInput: `kind: "Namespace`, wantErr: true, }, - // ---------------------------------------------------------------- { name: "invalid YAML", fields: Options{}, @@ -183,24 +181,40 @@ kind: "Namespace t.Parallel() s := &Split{ - opts: tt.fields, - log: nolog, - template: template.Must(template.New("split").Funcs(local.Functions).Parse(DefaultTemplateName)), + opts: tt.fields, + log: log.New(io.Discard, "", 0), + template: func() *template.Renderer { + tmpl, err := template.New(template.DefaultTemplateName) + if err != nil { + t.Fatalf("unable to create template: %s", err) + } + + return tmpl + }(), fileCount: 1, } - requireErrorIf(t, tt.wantFilterErr, s.validateFilters()) - requireErrorIf(t, tt.wantErr, s.processSingleFile([]byte(tt.fileInput))) + if tt.wantFilterErr { + require.Error(t, s.validateFilters()) + return + } + require.NoError(t, s.validateFilters()) + + if tt.wantErr { + require.Error(t, s.processSingleFile([]byte(tt.fileInput))) + return + } + require.NoError(t, s.processSingleFile([]byte(tt.fileInput))) if tt.fileOutput != nil { require.Lenf(t, s.filesFound, 1, "expected 1 file from list, got %d", len(s.filesFound)) current := s.filesFound[0] - require.Equal(t, tt.fileOutput.filename, current.filename) - require.Equal(t, tt.fileOutput.meta.APIVersion, current.meta.APIVersion) - require.Equal(t, tt.fileOutput.meta.Kind, current.meta.Kind) - require.Equal(t, tt.fileOutput.meta.Name, current.meta.Name) - require.Equal(t, tt.fileOutput.meta.Namespace, current.meta.Namespace) + require.Equal(t, tt.fileOutput.filename, current.Filename) + require.Equal(t, tt.fileOutput.meta.APIVersion, current.Meta.APIVersion) + require.Equal(t, tt.fileOutput.meta.Kind, current.Meta.Kind) + require.Equal(t, tt.fileOutput.meta.Name, current.Meta.Name) + require.Equal(t, tt.fileOutput.meta.Namespace, current.Meta.Namespace) } else { require.Lenf(t, s.filesFound, 0, "expected 0 files from list, got %d", len(s.filesFound)) } @@ -210,7 +224,7 @@ kind: "Namespace func TestExecute_writeToFileCases(t *testing.T) { tempdir := t.TempDir() - s := &Split{log: nolog} + s := &Split{log: logger.NOOPLogger} t.Run("write new file", func(tt *testing.T) { t.Parallel() @@ -321,7 +335,7 @@ metadata: require.NoError(t, err, "error found while writing input file") s, err := New(Options{ - GoTemplate: DefaultTemplateName, + GoTemplate: template.DefaultTemplateName, IncludeTripleDash: tt.includeDashes, InputFile: filepath.Join(tdinput, "input.yaml"), OutputDirectory: tdoutput, diff --git a/slice/options.go b/slice/options.go new file mode 100644 index 0000000..aea8816 --- /dev/null +++ b/slice/options.go @@ -0,0 +1,56 @@ +package slice + +import "io" + +// Options configures how the Split operation processes and outputs Kubernetes resources. +// It controls input sources, output destinations, filtering criteria, and formatting options. +type Options struct { + // Stdout is the writer used for standard output + Stdout io.Writer + // Stderr is the writer used for error and debug output + Stderr io.Writer + + InputFile string // the name of the input file to be read + InputFolder string // the name of the input folder to be read + InputFolderExt []string // the extensions of the files to be read + Recurse bool // if true, the input folder will be read recursively + OutputDirectory string // the path to the directory where the files will be stored + PruneOutputDir bool // if true, the output directory will be pruned before writing the files + OutputToStdout bool // if true, the output will be written to stdout instead of a file + GoTemplate string // the go template code to render the file names + DryRun bool // if true, no files are created + DebugMode bool // enables debug mode + Quiet bool // disables all writing to stdout/stderr + IncludeTripleDash bool // include the "---" separator on resources sliced + + // IncludedKinds is a list of Kubernetes kinds to include (all others will be excluded) + IncludedKinds []string + // ExcludedKinds is a list of Kubernetes kinds to exclude (all others will be included) + ExcludedKinds []string + // IncludedNames is a list of resource names to include (all others will be excluded) + IncludedNames []string + // ExcludedNames is a list of resource names to exclude (all others will be included) + ExcludedNames []string + // Included is a list of "kind/name" combinations to include + Included []string + // Excluded is a list of "kind/name" combinations to exclude + Excluded []string + + // StrictKubernetes when enabled, any YAMLs that don't contain at least an "apiVersion", "kind" and "metadata.name" are excluded + StrictKubernetes bool + + // SortByKind enables sorting of resources by Kubernetes kind importance (follows Helm install order) + SortByKind bool + // RemoveFileComments removes auto-generated comments from output files + RemoveFileComments bool + + // AllowEmptyNames permits resources without a metadata.name field + AllowEmptyNames bool + // AllowEmptyKinds permits resources without a kind field + AllowEmptyKinds bool + + // IncludedGroups is a list of API groups to include (all others will be excluded) + IncludedGroups []string + // ExcludedGroups is a list of API groups to exclude (all others will be included) + ExcludedGroups []string +} diff --git a/slice/output.go b/slice/output.go index 622e17d..c67cda6 100644 --- a/slice/output.go +++ b/slice/output.go @@ -2,6 +2,7 @@ package slice import "fmt" +// WriteStderr writes formatted output to stderr unless quiet mode is enabled func (s *Split) WriteStderr(format string, args ...interface{}) { if s.opts.Quiet { return @@ -10,6 +11,7 @@ func (s *Split) WriteStderr(format string, args ...interface{}) { fmt.Fprintf(s.opts.Stderr, format+"\n", args...) } +// WriteStdout writes formatted output to stdout func (s *Split) WriteStdout(format string, args ...interface{}) { fmt.Fprintf(s.opts.Stdout, format+"\n", args...) } diff --git a/slice/plural.go b/slice/plural.go deleted file mode 100644 index a73d10e..0000000 --- a/slice/plural.go +++ /dev/null @@ -1,8 +0,0 @@ -package slice - -func pluralize(s string, n int) string { - if n == 1 { - return s - } - return s + "s" -} diff --git a/slice/process.go b/slice/process.go index 5d0db82..b51760b 100644 --- a/slice/process.go +++ b/slice/process.go @@ -1,50 +1,42 @@ package slice import ( - "bytes" "fmt" "path/filepath" "strings" "github.com/mb0/glob" "gopkg.in/yaml.v3" + + "github.com/patrickdappollonio/kubectl-slice/pkg/errors" + "github.com/patrickdappollonio/kubectl-slice/pkg/kubernetes" ) // parseYAMLManifest parses a single YAML file as received by contents. It also renders the // template needed to generate its name -func (s *Split) parseYAMLManifest(contents []byte) (yamlFile, error) { +func (s *Split) parseYAMLManifest(contents []byte) (*kubernetes.YAMLFile, error) { // All resources we'll handle are Kubernetes manifest, and even those who are lists, // they're still Kubernetes Objects of type List, so we can use a map manifest := make(map[string]interface{}) s.log.Println("Parsing YAML from buffer up to this point") if err := yaml.Unmarshal(contents, &manifest); err != nil { - return yamlFile{}, fmt.Errorf("unable to parse YAML file number %d: %w", s.fileCount, err) + return nil, fmt.Errorf("unable to parse YAML file number %d: %w", s.fileCount, err) } // Render the name to a buffer using the Go Template s.log.Println("Rendering filename template from Go Template") - var buf bytes.Buffer - if err := s.template.Execute(&buf, manifest); err != nil { - return yamlFile{}, fmt.Errorf("unable to render file name for YAML file number %d: %w", s.fileCount, improveExecError(err)) + name, err := s.template.Execute(manifest) + if err != nil { + return nil, fmt.Errorf("unable to render file name for YAML file number %d: %w", s.fileCount, err) } // Check if file contains the required Kubernetes metadata - k8smeta := checkKubernetesBasics(manifest) + k8smeta := kubernetes.ExtractMetadata(manifest) // Check if at least the three fields are not empty - if s.opts.StrictKubernetes { - if k8smeta.APIVersion == "" { - return yamlFile{}, &strictModeSkipErr{fieldName: "apiVersion"} - } - - if k8smeta.Kind == "" { - return yamlFile{}, &strictModeSkipErr{fieldName: "kind"} - } - - if k8smeta.Name == "" { - return yamlFile{}, &strictModeSkipErr{fieldName: "metadata.name"} - } + if err := kubernetes.ValidateRequiredFields(k8smeta, s.opts.StrictKubernetes); err != nil { + return nil, err } // Check before handling if we're about to filter resources @@ -58,55 +50,53 @@ func (s *Split) parseYAMLManifest(contents []byte) (yamlFile, error) { // Check if we have a Kubernetes kind and we're requesting inclusion or exclusion if k8smeta.Kind == "" && !s.opts.AllowEmptyKinds && (hasIncluded || hasExcluded) { - return yamlFile{}, &cantFindFieldErr{fieldName: "kind", fileCount: s.fileCount, meta: k8smeta} + return nil, &errors.CantFindFieldErr{FieldName: "kind", FileCount: s.fileCount, Meta: k8smeta} } // Check if we have a Kubernetes name and we're requesting inclusion or exclusion if k8smeta.Name == "" && !s.opts.AllowEmptyNames && (hasIncluded || hasExcluded) { - return yamlFile{}, &cantFindFieldErr{fieldName: "metadata.name", fileCount: s.fileCount, meta: k8smeta} + return nil, &errors.CantFindFieldErr{FieldName: "metadata.name", FileCount: s.fileCount, Meta: k8smeta} } // We need to check if the file should be skipped if hasExcluded || hasIncluded { // If we're working with including only specific resources, then filter by them if hasIncluded && !inSliceIgnoreCaseGlob(s.opts.Included, fmt.Sprintf("%s/%s", k8smeta.Kind, k8smeta.Name)) { - return yamlFile{}, &skipErr{kind: "kind/name", name: fmt.Sprintf("%s/%s", k8smeta.Kind, k8smeta.Name)} + return nil, &errors.SkipErr{Kind: "kind/name", Name: fmt.Sprintf("%s/%s", k8smeta.Kind, k8smeta.Name)} } // Otherwise exclude resources based on the parameter received if hasExcluded && inSliceIgnoreCaseGlob(s.opts.Excluded, fmt.Sprintf("%s/%s", k8smeta.Kind, k8smeta.Name)) { - return yamlFile{}, &skipErr{kind: "kind/name", name: fmt.Sprintf("%s/%s", k8smeta.Kind, k8smeta.Name)} + return nil, &errors.SkipErr{Kind: "kind/name", Name: fmt.Sprintf("%s/%s", k8smeta.Kind, k8smeta.Name)} } } if len(s.opts.IncludedGroups) > 0 || len(s.opts.ExcludedGroups) > 0 { if k8smeta.APIVersion == "" { - return yamlFile{}, &cantFindFieldErr{fieldName: "apiVersion", fileCount: s.fileCount, meta: k8smeta} + return nil, &errors.CantFindFieldErr{FieldName: "apiVersion", FileCount: s.fileCount, Meta: k8smeta} } var groups []string - if len(s.opts.IncludedGroups) > 0 { + included := len(s.opts.IncludedGroups) > 0 + if included { groups = s.opts.IncludedGroups - } else if len(s.opts.ExcludedGroups) > 0 { + } else { groups = s.opts.ExcludedGroups } - if err := checkGroup(k8smeta, groups, len(s.opts.IncludedGroups) > 0); err != nil { - return yamlFile{}, &skipErr{} + if err := kubernetes.CheckGroupInclusion(k8smeta, groups, included); err != nil { + return nil, err } } - // Trim the file name - name := strings.TrimSpace(buf.String()) - - // Fix for text/template Go issue #24963, as well as removing any linebreaks - name = strings.NewReplacer("", "", "\n", "").Replace(name) - if str := strings.TrimSuffix(name, filepath.Ext(name)); str == "" { - return yamlFile{}, fmt.Errorf("file name rendered will yield no file name for YAML file number %d (original name: %q, metadata: %v)", s.fileCount, name, k8smeta) + return nil, fmt.Errorf("file name rendered will yield no file name for YAML file number %d (original name: %q, metadata: %v)", s.fileCount, name, k8smeta) } - return yamlFile{filename: name, meta: k8smeta}, nil + return &kubernetes.YAMLFile{ + Filename: name, + Meta: k8smeta, + }, nil } // inSliceIgnoreCase checks if a string is in a slice, ignoring case @@ -137,55 +127,3 @@ func inSliceIgnoreCaseGlob(slice []string, expected string) bool { return false } - -// checkStringInMap checks if a string is in a map, and if not, returns an error -func checkStringInMap(local map[string]interface{}, key string) string { - iface, found := local[key] - - if !found { - return "" - } - - str, ok := iface.(string) - if !ok { - return "" - } - - return str -} - -// checkKubernetesBasics check if the minimum required keys are there for a Kubernetes Object -func checkKubernetesBasics(manifest map[string]interface{}) kubeObjectMeta { - var metadata kubeObjectMeta - - metadata.APIVersion = checkStringInMap(manifest, "apiVersion") - metadata.Kind = checkStringInMap(manifest, "kind") - - if md, found := manifest["metadata"]; found { - metadata.Name = checkStringInMap(md.(map[string]interface{}), "name") - metadata.Namespace = checkStringInMap(md.(map[string]interface{}), "namespace") - } - - return metadata -} - -func checkGroup(objmeta kubeObjectMeta, groupName []string, included bool) error { - - for _, group := range groupName { - if included { - if objmeta.GetGroupFromAPIVersion() == strings.ToLower(group) { - return nil - } - } else { - if objmeta.GetGroupFromAPIVersion() == strings.ToLower(group) { - return &skipErr{} - } - } - } - - if included { - return &skipErr{} - } else { - return nil - } -} diff --git a/slice/process_test.go b/slice/process_test.go index d4296c1..5486b51 100644 --- a/slice/process_test.go +++ b/slice/process_test.go @@ -2,9 +2,10 @@ package slice import ( "testing" - "text/template" - local "github.com/patrickdappollonio/kubectl-slice/slice/template" + "github.com/patrickdappollonio/kubectl-slice/pkg/kubernetes" + "github.com/patrickdappollonio/kubectl-slice/pkg/logger" + "github.com/patrickdappollonio/kubectl-slice/pkg/template" "github.com/stretchr/testify/require" ) @@ -56,10 +57,7 @@ func Test_inSliceIgnoreCase(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - - if got := inSliceIgnoreCase(tt.args.slice, tt.args.expected); got != tt.want { - t.Errorf("inSliceIgnoreCase() = %v, want %v", got, tt.want) - } + require.Equal(t, tt.want, inSliceIgnoreCase(tt.args.slice, tt.args.expected)) }) } } @@ -174,7 +172,7 @@ func Test_checkStringInMap(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - require.Equal(t, tt.want, checkStringInMap(tt.args.local, tt.args.key)) + require.Equal(t, tt.want, kubernetes.CheckStringInMap(tt.args.local, tt.args.key)) }) } } @@ -187,7 +185,7 @@ func Test_checkKubernetesBasics(t *testing.T) { tests := []struct { name string args args - want kubeObjectMeta + want kubernetes.ObjectMeta }{ { name: "all fields found", @@ -200,7 +198,7 @@ func Test_checkKubernetesBasics(t *testing.T) { }, }, }, - want: kubeObjectMeta{ + want: kubernetes.ObjectMeta{ Kind: "Deployment", APIVersion: "apps/v1", Name: "foo", @@ -211,7 +209,7 @@ func Test_checkKubernetesBasics(t *testing.T) { args: args{ manifest: map[string]interface{}{}, }, - want: kubeObjectMeta{}, + want: kubernetes.ObjectMeta{}, }, { name: "missing metadata fields", @@ -221,7 +219,7 @@ func Test_checkKubernetesBasics(t *testing.T) { "apiVersion": "apps/v1", }, }, - want: kubeObjectMeta{ + want: kubernetes.ObjectMeta{ Kind: "Deployment", APIVersion: "apps/v1", }, @@ -231,7 +229,7 @@ func Test_checkKubernetesBasics(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - require.Equal(t, tt.want, checkKubernetesBasics(tt.args.manifest)) + require.Equal(t, tt.want, *kubernetes.ExtractMetadata(tt.args.manifest)) }) } } @@ -241,7 +239,7 @@ func TestSplit_parseYAMLManifest(t *testing.T) { name string contents []byte strictKube bool - want yamlFile + want *kubernetes.YAMLFile wantErr bool }{ { @@ -253,9 +251,9 @@ metadata: name: foo namespace: bar `), - want: yamlFile{ - filename: "service-foo.yaml", - meta: kubeObjectMeta{ + want: &kubernetes.YAMLFile{ + Filename: "service-foo.yaml", + Meta: &kubernetes.ObjectMeta{ APIVersion: "v1", Kind: "Service", Name: "foo", @@ -271,9 +269,9 @@ kind: Service metadata: name: foo `), - want: yamlFile{ - filename: "service-foo.yaml", - meta: kubeObjectMeta{ + want: &kubernetes.YAMLFile{ + Filename: "service-foo.yaml", + Meta: &kubernetes.ObjectMeta{ APIVersion: "v1", Kind: "Service", Name: "foo", @@ -291,9 +289,9 @@ metadata: namespace: bar `), strictKube: true, - want: yamlFile{ - filename: "service-foo.yaml", - meta: kubeObjectMeta{ + want: &kubernetes.YAMLFile{ + Filename: "service-foo.yaml", + Meta: &kubernetes.ObjectMeta{ APIVersion: "v1", Kind: "Service", Name: "foo", @@ -310,9 +308,9 @@ metadata: name: foo `), strictKube: true, - want: yamlFile{ - filename: "service-foo.yaml", - meta: kubeObjectMeta{ + want: &kubernetes.YAMLFile{ + Filename: "service-foo.yaml", + Meta: &kubernetes.ObjectMeta{ APIVersion: "v1", Kind: "Service", Name: "foo", @@ -343,13 +341,25 @@ kind: Foo t.Parallel() s := &Split{ - log: nolog, - template: template.Must(template.New(DefaultTemplateName).Funcs(local.Functions).Parse(DefaultTemplateName)), + log: logger.NOOPLogger, + template: func() *template.Renderer { + tmpl, err := template.New(template.DefaultTemplateName) + if err != nil { + t.Fatalf("unable to create template: %s", err) + } + + return tmpl + }(), } s.opts.StrictKubernetes = tt.strictKube got, err := s.parseYAMLManifest(tt.contents) - requireErrorIf(t, tt.wantErr, err) + + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) require.Equal(t, tt.want, got) }) } @@ -363,7 +373,7 @@ func TestSplit_parseYamlManifestAllowingEmpties(t *testing.T) { skipEmptyKind bool includeKind string includeName string - want yamlFile + want *kubernetes.YAMLFile wantErr bool }{ { @@ -374,9 +384,9 @@ kind: Foo metadata: name: bar `), - want: yamlFile{ - filename: "foo-bar.yaml", - meta: kubeObjectMeta{APIVersion: "v1", Kind: "Foo", Name: "bar"}, + want: &kubernetes.YAMLFile{ + Filename: "foo-bar.yaml", + Meta: &kubernetes.ObjectMeta{APIVersion: "v1", Kind: "Foo", Name: "bar"}, }, includeKind: "Foo", skipEmptyName: false, @@ -390,9 +400,9 @@ kind: "" metadata: name: bar `), - want: yamlFile{ - filename: "-bar.yaml", - meta: kubeObjectMeta{APIVersion: "v1", Kind: "", Name: "bar"}, + want: &kubernetes.YAMLFile{ + Filename: "-bar.yaml", + Meta: &kubernetes.ObjectMeta{APIVersion: "v1", Kind: "", Name: "bar"}, }, includeName: "bar", skipEmptyName: false, @@ -418,9 +428,9 @@ kind: Foo metadata: name: "" `), - want: yamlFile{ - filename: "foo-.yaml", - meta: kubeObjectMeta{APIVersion: "v1", Kind: "Foo", Name: ""}, + want: &kubernetes.YAMLFile{ + Filename: "foo-.yaml", + Meta: &kubernetes.ObjectMeta{APIVersion: "v1", Kind: "Foo", Name: ""}, }, includeKind: "Foo", skipEmptyName: true, @@ -444,8 +454,11 @@ kind: Foo t.Parallel() s := &Split{ - log: nolog, - template: template.Must(template.New(DefaultTemplateName).Funcs(local.Functions).Parse(DefaultTemplateName)), + log: logger.NOOPLogger, + template: func() *template.Renderer { + tmpl, _ := template.New(template.DefaultTemplateName) + return tmpl + }(), } if len(tt.includeKind) > 0 { @@ -464,8 +477,11 @@ kind: Foo } got, err := s.parseYAMLManifest(tt.contents) - requireErrorIf(t, tt.wantErr, err) - t.Logf("got: %#v", got) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) require.Equal(t, tt.want, got) }) } diff --git a/slice/split.go b/slice/split.go index 4569778..7479a04 100644 --- a/slice/split.go +++ b/slice/split.go @@ -5,29 +5,22 @@ import ( "io" "log" "os" - "text/template" -) - -const DefaultTemplateName = "{{.kind | lower}}-{{.metadata.name}}.yaml" -// Logger is the interface used by Split to log debug messages -// and it's satisfied by Go's log.Logger -type Logger interface { - Printf(format string, v ...interface{}) - SetOutput(w io.Writer) - Println(v ...interface{}) -} + "github.com/patrickdappollonio/kubectl-slice/pkg/kubernetes" + "github.com/patrickdappollonio/kubectl-slice/pkg/logger" + "github.com/patrickdappollonio/kubectl-slice/pkg/template" +) // Split is a Kubernetes Split instance. Each instance has its own template // used to generate the resource names when saving to disk. Because of this, // avoid reusing the same instance of Split type Split struct { opts Options - log Logger - template *template.Template + log logger.Logger + template *template.Renderer data *bytes.Buffer - filesFound []yamlFile + filesFound []kubernetes.YAMLFile fileCount int } @@ -57,39 +50,3 @@ func New(opts Options) (*Split, error) { return s, nil } - -// Options holds the Split options used when splitting Kubernetes resources -type Options struct { - Stdout io.Writer - Stderr io.Writer - - InputFile string // the name of the input file to be read - InputFolder string // the name of the input folder to be read - InputFolderExt []string // the extensions of the files to be read - Recurse bool // if true, the input folder will be read recursively - OutputDirectory string // the path to the directory where the files will be stored - PruneOutputDir bool // if true, the output directory will be pruned before writing the files - OutputToStdout bool // if true, the output will be written to stdout instead of a file - GoTemplate string // the go template code to render the file names - DryRun bool // if true, no files are created - DebugMode bool // enables debug mode - Quiet bool // disables all writing to stdout/stderr - IncludeTripleDash bool // include the "---" separator on resources sliced - - IncludedKinds []string - ExcludedKinds []string - IncludedNames []string - ExcludedNames []string - Included []string - Excluded []string - StrictKubernetes bool // if true, any YAMLs that don't contain at least an "apiVersion", "kind" and "metadata.name" will be excluded - - SortByKind bool // if true, it will sort the resources by kind - RemoveFileComments bool // if true, it will remove comments generated by the app from the generated files - - AllowEmptyNames bool - AllowEmptyKinds bool - - IncludedGroups []string - ExcludedGroups []string -} diff --git a/slice/split_test.go b/slice/split_test.go index 117c98f..4e46675 100644 --- a/slice/split_test.go +++ b/slice/split_test.go @@ -7,17 +7,11 @@ import ( "strings" "testing" + "github.com/patrickdappollonio/kubectl-slice/pkg/logger" + "github.com/patrickdappollonio/kubectl-slice/pkg/template" "github.com/stretchr/testify/require" ) -type noopLogger struct{} - -func (noopLogger) Println(...interface{}) {} -func (noopLogger) SetOutput(_ io.Writer) {} -func (noopLogger) Printf(_ string, _ ...interface{}) {} - -var nolog = &noopLogger{} - func TestEndToEnd(t *testing.T) { cases := []struct { name string @@ -28,7 +22,7 @@ func TestEndToEnd(t *testing.T) { { name: "end to end", inputFile: "full.yaml", - template: DefaultTemplateName, + template: template.DefaultTemplateName, expectedFiles: []string{ "full/-.yaml", "full/deployment-hello-docker.yaml", @@ -66,7 +60,7 @@ func TestEndToEnd(t *testing.T) { require.NoError(tt, err, "not expecting an error") require.NoError(tt, slice.Execute(), "not expecting an error on Execute()") - slice.log = nolog + slice.log = logger.NOOPLogger files, err := os.ReadDir(dir) require.NoError(tt, err, "not expecting an error on ReadDir()") diff --git a/slice/template.go b/slice/template.go index b65142c..3842fcc 100644 --- a/slice/template.go +++ b/slice/template.go @@ -1,45 +1,26 @@ package slice import ( - "errors" "fmt" - "strings" - "text/template" - local "github.com/patrickdappollonio/kubectl-slice/slice/template" + "github.com/patrickdappollonio/kubectl-slice/pkg/template" ) +// compileTemplate creates and caches a template renderer using the template string +// from Options.GoTemplate. The method is idempotent - if a template is already compiled, +// it will skip compilation and return nil. func (s *Split) compileTemplate() error { + if s.template != nil { + s.log.Println("Template already compiled, skipping") + return nil + } + s.log.Printf("About to compile template: %q", s.opts.GoTemplate) - t, err := template.New("split").Funcs(local.Functions).Parse(s.opts.GoTemplate) + tmpl, err := template.New(s.opts.GoTemplate) if err != nil { - return fmt.Errorf("file name template parse failed: %w", improveExecError(err)) + return fmt.Errorf("file name template parse failed: %w", err) } - s.template = t + s.template = tmpl return nil } - -func improveExecError(err error) error { - // Before you start screaming because I'm handling an error using strings, - // consider that there's a longstanding open TODO to improve template.ExecError - // to be more meaningful: - // https://github.com/golang/go/blob/go1.17/src/text/template/exec.go#L107-L109 - - if _, ok := err.(template.ExecError); !ok { - if !strings.HasPrefix(err.Error(), "template:") { - return err - } - } - - s := err.Error() - - if pos := strings.LastIndex(s, ":"); pos >= 0 { - return template.ExecError{ - Name: "", - Err: errors.New(strings.TrimSpace(s[pos+1:])), - } - } - - return err -} diff --git a/slice/template/funcs.go b/slice/template/funcs.go deleted file mode 100644 index cc23d54..0000000 --- a/slice/template/funcs.go +++ /dev/null @@ -1,275 +0,0 @@ -package template - -import ( - "bytes" - "crypto/sha1" - "crypto/sha256" - "encoding/hex" - "fmt" - "html/template" - "os" - "regexp" - "strings" - - "golang.org/x/text/cases" - "golang.org/x/text/language" - "gopkg.in/yaml.v3" -) - -var Functions = template.FuncMap{ - "lower": jsonLower, - "lowercase": jsonLower, - "uppercase": jsonUpper, - "upper": jsonUpper, - "title": jsonTitle, - "sprintf": fmt.Sprintf, - "printf": fmt.Sprintf, - "trim": jsonTrimSpace, - "trimPrefix": jsonTrimPrefix, - "trimSuffix": jsonTrimSuffix, - "default": fnDefault, - "sha1sum": sha1sum, - "sha256sum": sha256sum, - "str": strJSON, - "required": jsonRequired, - "env": env, - "replace": jsonReplace, - "alphanumify": jsonAlphanumify, - "alphanumdash": jsonAlphanumdash, - "dottodash": jsonDotToDash, - "dottounder": jsonDotToUnder, - "index": mapValueByIndex, - "indexOrEmpty": mapValueByIndexOrEmpty, -} - -// mapValueByIndexOrEmpty retrieves a value from a map without returning an error if the key is not found. -func mapValueByIndexOrEmpty(index string, m map[string]interface{}) interface{} { - if m == nil { - return "" - } - - if index == "" { - return "" - } - - v, ok := m[index] - if !ok { - return "" - } - - return v -} - -// mapValueByIndex returns the value of the map at the given index -func mapValueByIndex(index string, m map[string]interface{}) (interface{}, error) { - if m == nil { - return nil, fmt.Errorf("map is nil") - } - - if index == "" { - return nil, fmt.Errorf("index is empty") - } - - v, ok := m[index] - if !ok { - return nil, fmt.Errorf("map does not contain index %q", index) - } - - return v, nil -} - -// strJSON converts a value received from JSON/YAML to string. Since not all data -// types are supported for JSON, we can limit to just the primitives that are -// not arrays, objects or null; see: -// https://pkg.go.dev/encoding/json#Unmarshal -func strJSON(val interface{}) (string, error) { - if val == nil { - return "", nil - } - - switch a := val.(type) { - case string: - return a, nil - - case bool: - return fmt.Sprintf("%v", a), nil - - case float64: - return fmt.Sprintf("%v", a), nil - - default: - return "", fmt.Errorf("unexpected data type %T -- can't convert to string", val) - } -} - -var ( - reAlphaNum = regexp.MustCompile(`[^a-zA-Z0-9]+`) - reSlugify = regexp.MustCompile(`[^a-zA-Z0-9-]+`) -) - -func jsonAlphanumify(val interface{}) (string, error) { - s, err := strJSON(val) - if err != nil { - return "", err - } - - return reAlphaNum.ReplaceAllString(s, ""), nil -} - -func jsonAlphanumdash(val interface{}) (string, error) { - s, err := strJSON(val) - if err != nil { - return "", err - } - - return reSlugify.ReplaceAllString(s, ""), nil -} - -func jsonDotToDash(val interface{}) (string, error) { - s, err := strJSON(val) - if err != nil { - return "", err - } - - return strings.ReplaceAll(s, ".", "-"), nil -} - -func jsonDotToUnder(val interface{}) (string, error) { - s, err := strJSON(val) - if err != nil { - return "", err - } - - return strings.ReplaceAll(s, ".", "_"), nil -} - -func jsonReplace(search, replace string, val interface{}) (string, error) { - s, err := strJSON(val) - if err != nil { - return "", err - } - - return strings.NewReplacer(search, replace).Replace(s), nil -} - -func env(key string) string { - return os.Getenv(strings.ToUpper(key)) -} - -func jsonRequired(val interface{}) (interface{}, error) { - if val == nil { - return nil, fmt.Errorf("argument is marked as required, but it renders to empty") - } - - s, err := strJSON(val) - if err != nil { - return nil, err - } - - if s == "" { - return nil, fmt.Errorf("argument is marked as required, but it renders to empty or it's an object or an unsupported type") - } - - return val, nil -} - -func jsonLower(val interface{}) (string, error) { - s, err := strJSON(val) - if err != nil { - return "", err - } - - return strings.ToLower(s), nil -} - -func jsonUpper(val interface{}) (string, error) { - s, err := strJSON(val) - if err != nil { - return "", err - } - - return strings.ToUpper(s), nil -} - -func jsonTitle(val interface{}) (string, error) { - s, err := strJSON(val) - if err != nil { - return "", err - } - - return cases.Title(language.Und).String(s), nil -} - -func jsonTrimSpace(val interface{}) (string, error) { - s, err := strJSON(val) - if err != nil { - return "", err - } - - return strings.TrimSpace(s), nil -} - -func jsonTrimPrefix(prefix string, val interface{}) (string, error) { - s, err := strJSON(val) - if err != nil { - return "", err - } - - return strings.TrimPrefix(s, prefix), nil -} - -func jsonTrimSuffix(suffix string, val interface{}) (string, error) { - s, err := strJSON(val) - if err != nil { - return "", err - } - - return strings.TrimSuffix(s, suffix), nil -} - -func fnDefault(defval, val interface{}) (string, error) { - v, err := strJSON(val) - if err != nil { - return "", err - } - - dv, err := strJSON(defval) - if err != nil { - return "", err - } - - if v != "" { - return v, nil - } - - return dv, nil -} - -func altStrJSON(val interface{}) (string, error) { - var buf bytes.Buffer - if err := yaml.NewEncoder(&buf).Encode(val); err != nil { - return "", fmt.Errorf("unable to encode object to YAML: %w", err) - } - - return buf.String(), nil -} - -func sha256sum(input interface{}) (string, error) { - s, err := altStrJSON(input) - if err != nil { - return "", err - } - - hash := sha256.Sum256([]byte(s)) - return hex.EncodeToString(hash[:]), nil -} - -func sha1sum(input interface{}) (string, error) { - s, err := altStrJSON(input) - if err != nil { - return "", err - } - - hash := sha1.Sum([]byte(s)) - return hex.EncodeToString(hash[:]), nil -} diff --git a/slice/template/funcs_test.go b/slice/template/funcs_test.go deleted file mode 100644 index 2847675..0000000 --- a/slice/template/funcs_test.go +++ /dev/null @@ -1,659 +0,0 @@ -package template - -import ( - "math/rand" - "os" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -func Test_mapValueByIndexEmpty(t *testing.T) { - tests := []struct { - name string - index string - m map[string]interface{} - want interface{} - }{ - { - name: "nil map", - index: "foo", - m: nil, - want: "", - }, - { - name: "empty index", - index: "", - m: map[string]interface{}{}, - want: "", - }, - { - name: "fetch existent field", - index: "foo", - m: map[string]interface{}{ - "foo": "bar", - }, - want: "bar", - }, - { - name: "fetch nonexistent field", - index: "baz", - m: map[string]interface{}{ - "foo": "bar", - }, - want: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := mapValueByIndexOrEmpty(tt.index, tt.m) - require.Equal(t, tt.want, got) - }) - } -} - -func Test_mapValueByIndex(t *testing.T) { - tests := []struct { - name string - index string - m map[string]interface{} - want interface{} - wantErr bool - }{ - { - name: "nil map", - index: "foo", - m: nil, - wantErr: true, - }, - { - name: "empty index", - index: "", - m: map[string]interface{}{}, - wantErr: true, - }, - { - name: "fetch existent field", - index: "foo", - m: map[string]interface{}{ - "foo": "bar", - }, - want: "bar", - }, - { - name: "fetch nonexistent field", - index: "baz", - m: map[string]interface{}{ - "foo": "bar", - }, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := mapValueByIndex(tt.index, tt.m) - requireErrorIf(t, tt.wantErr, err) - require.Equal(t, tt.want, got) - }) - } -} - -func Test_strJSON(t *testing.T) { - tests := []struct { - name string - val interface{} - want string - wantErr bool - }{ - { - name: "string conversion", - val: "foo", - want: "foo", - }, - { - name: "bool true conversion", - val: true, - want: "true", - }, - { - name: "bool false conversion", - val: false, - want: "false", - }, - { - name: "float64 conversion", - val: 3.141592654, - want: "3.141592654", - }, - { - name: "incorrect data type conversion", - val: []string{}, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := strJSON(tt.val) - requireErrorIf(t, tt.wantErr, err) - require.Equal(t, tt.want, got) - }) - } -} - -func Test_jsonAlphanumify(t *testing.T) { - tests := []struct { - name string - val interface{} - want string - wantErr bool - }{ - { - name: "remove dots", - val: "foo.bar", - want: "foobar", - }, - { - name: "remove dots and slashes", - val: "foo.bar/baz", - want: "foobarbaz", - }, - { - name: "remove all special characters", - val: "foo.bar/baz!@#$%^&*()_+-=[]{}\\|;:'\",<.>/?", - want: "foobarbaz", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := jsonAlphanumify(tt.val) - requireErrorIf(t, tt.wantErr, err) - require.Equal(t, tt.want, got) - }) - } -} - -func Test_jsonAlphanumdash(t *testing.T) { - tests := []struct { - name string - val interface{} - want string - wantErr bool - }{ - { - name: "remove dots", - val: "foo.bar-baz", - want: "foobar-baz", - }, - { - name: "remove dots and slashes", - val: "foo.bar/baz-daz", - want: "foobarbaz-daz", - }, - { - name: "remove all special characters", - val: "foo.bar/baz!@#$%^&*()_+=[]{}\\|;:'\",<.>/?-daz", - want: "foobarbaz-daz", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := jsonAlphanumdash(tt.val) - requireErrorIf(t, tt.wantErr, err) - require.Equal(t, tt.want, got) - }) - } -} - -func Test_jsonDotToDash(t *testing.T) { - tests := []struct { - name string - val interface{} - want string - wantErr bool - }{ - { - name: "single dot", - val: "foo.bar", - want: "foo-bar", - }, - { - name: "multi dot", - val: "foo...bar", - want: "foo---bar", - }, - { - name: "no dot", - val: "foobar", - want: "foobar", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := jsonDotToDash(tt.val) - requireErrorIf(t, tt.wantErr, err) - require.Equal(t, tt.want, got) - }) - } -} - -func Test_jsonDotToUnder(t *testing.T) { - tests := []struct { - name string - val interface{} - want string - wantErr bool - }{ - { - name: "single dot", - val: "foo.bar", - want: "foo_bar", - }, - { - name: "multi dot", - val: "foo...bar", - want: "foo___bar", - }, - { - name: "no dot", - val: "foobar", - want: "foobar", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := jsonDotToUnder(tt.val) - requireErrorIf(t, tt.wantErr, err) - require.Equal(t, tt.want, got) - }) - } -} - -func Test_jsonReplace(t *testing.T) { - type args struct { - search string - replace string - val interface{} - } - - tests := []struct { - name string - args args - want string - wantErr bool - }{ - { - name: "basic replace", - args: args{ - search: "foo", - replace: "bar", - val: "foobar", - }, - want: "barbar", - }, - { - name: "non existent replacement", - args: args{ - search: "foo", - replace: "bar", - val: "barbar", - }, - want: "barbar", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := jsonReplace(tt.args.search, tt.args.replace, tt.args.val) - requireErrorIf(t, tt.wantErr, err) - require.Equal(t, tt.want, got) - }) - } -} - -func Test_env(t *testing.T) { - letters := []rune("abcdefghijklmnopqrstuvwxyz") - - randSeq := func(n int) string { - rnd := rand.New(rand.NewSource(time.Now().UnixNano())) - b := make([]rune, n) - for i := range b { - b[i] = letters[rnd.Intn(len(letters))] - } - return string(b) - } - - type args struct { - key string - env map[string]string - } - - tests := []struct { - name string - args args - want string - }{ - { - name: "generic", - args: args{ - key: "foo", - env: map[string]string{ - "foo": "bar", - }, - }, - want: "bar", - }, - { - name: "non-existent", - args: args{ - key: "fooofooo", - env: map[string]string{ - "foo": "bar", - }, - }, - want: "", - }, - { - name: "case insensitive key", - args: args{ - key: "FOO", - env: map[string]string{ - "foo": "bar", - }, - }, - want: "bar", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - prefix := randSeq(10) + "_" - - for k, v := range tt.args.env { - key := strings.ToUpper(prefix + k) - os.Setenv(key, v) - defer os.Unsetenv(key) - } - - require.Equal(t, tt.want, env(prefix+tt.args.key)) - }) - } -} - -func Test_jsonRequired(t *testing.T) { - tests := []struct { - name string - val interface{} - want interface{} - wantErr bool - }{ - { - name: "no error", - val: true, // any non empty value will do - want: true, - }, - { - name: "empty item", - val: nil, - wantErr: true, - }, - { - name: "unsupported item", - val: struct{ name string }{name: "foo"}, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := jsonRequired(tt.val) - requireErrorIf(t, tt.wantErr, err) - require.Equal(t, tt.want, got) - }) - } -} - -func Test_jsonLowerAndUpper(t *testing.T) { - type args struct { - val interface{} - prefix string - suffix string - } - tests := []struct { - name string - args args - lower string - upper string - title string - trimmed string - noprefix string - nosuffix string - wantErr bool - }{ - { - name: "generic first test ", - args: args{ - val: "foo bar baz ", - prefix: "foo ", - suffix: " baz ", - }, - lower: "foo bar baz ", - upper: "FOO BAR BAZ ", - title: "Foo Bar Baz ", - trimmed: "foo bar baz", - noprefix: "bar baz ", - nosuffix: "foo bar", - }, - { - name: "invalid value type", - args: args{ - val: struct{}{}, - }, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - lowered, err := jsonLower(tt.args.val) - requireErrorIf(t, tt.wantErr, err) - - uppered, err := jsonUpper(tt.args.val) - requireErrorIf(t, tt.wantErr, err) - - titled, err := jsonTitle(tt.args.val) - requireErrorIf(t, tt.wantErr, err) - - trimspaced, err := jsonTrimSpace(tt.args.val) - requireErrorIf(t, tt.wantErr, err) - - prefixed, err := jsonTrimPrefix(tt.args.prefix, tt.args.val) - requireErrorIf(t, tt.wantErr, err) - - suffixed, err := jsonTrimSuffix(tt.args.suffix, tt.args.val) - requireErrorIf(t, tt.wantErr, err) - - require.Equal(t, tt.lower, lowered) - require.Equal(t, tt.upper, uppered) - require.Equal(t, tt.title, titled) - require.Equal(t, tt.trimmed, trimspaced) - require.Equal(t, tt.noprefix, prefixed) - require.Equal(t, tt.nosuffix, suffixed) - }) - } -} - -func Test_fnDefault(t *testing.T) { - type args struct { - defval interface{} - val interface{} - } - - tests := []struct { - name string - args args - want string - wantErr bool - }{ - { - name: "non value use default", - args: args{ - defval: "foo", - val: nil, - }, - want: "foo", - }, - { - name: "existent value skip default", - args: args{ - defval: "foo", - val: "bar", - }, - want: "bar", - }, - { - name: "inconvertible value type use default", - args: args{ - val: []struct{}{}, - }, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := fnDefault(tt.args.defval, tt.args.val) - requireErrorIf(t, tt.wantErr, err) - require.Equal(t, tt.want, got) - }) - } -} - -func Test_altStrJSON(t *testing.T) { - tests := []struct { - name string - val interface{} - want string - wantErr bool - }{ - { - name: "default", - val: "foo", - want: "foo\n", - }, - { - name: "convert to object", - val: map[string]interface{}{ - "foo": "bar", - }, - want: "foo: bar\n", - }, - { - name: "convert to array", - val: []interface{}{ - "foo", - "bar", - }, - want: "- foo\n- bar\n", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := altStrJSON(tt.val) - requireErrorIf(t, tt.wantErr, err) - require.Equal(t, tt.want, got) - }) - } -} - -func Test_sha256sum(t *testing.T) { - tests := []struct { - name string - input interface{} - want string - wantErr bool - }{ - { - name: "generic string", - input: "foo", - want: "b5bb9d8014a0f9b1d61e21e796d78dccdf1352f23cd32812f4850b878ae4944c", - }, - { - name: "generic array", - input: []interface{}{ - "foo", - "bar", - }, - want: "d50869a9dcda5fe0b6413eb366dec11d0eb7226c5569f7de8dad1fcd917e5480", - }, - { - name: "generic object", - input: map[string]interface{}{ - "foo": "bar", - }, - want: "1dabc4e3cbbd6a0818bd460f3a6c9855bfe95d506c74726bc0f2edb0aecb1f4e", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := sha256sum(tt.input) - requireErrorIf(t, tt.wantErr, err) - require.Equal(t, tt.want, got) - }) - } -} - -func Test_sha1sum(t *testing.T) { - tests := []struct { - name string - input interface{} - want string - wantErr bool - }{ - { - name: "generic string", - input: "foo", - want: "f1d2d2f924e986ac86fdf7b36c94bcdf32beec15", - }, - { - name: "generic array", - input: []interface{}{ - "foo", - "bar", - }, - want: "c11e6a294774caece9f882726f0f85c72691bb19", - }, - { - name: "generic object", - input: map[string]interface{}{ - "foo": "bar", - }, - want: "7e109797e472ae8cbd20d7a4d7e231a96324377c", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := sha1sum(tt.input) - requireErrorIf(t, tt.wantErr, err) - require.Equal(t, tt.want, got) - }) - } -} - -func requireErrorIf(t *testing.T, wantErr bool, err error) { - if wantErr { - require.Error(t, err) - } else { - require.NoError(t, err) - } -} diff --git a/slice/template_test.go b/slice/template_test.go index 5e966d7..bff5e0f 100644 --- a/slice/template_test.go +++ b/slice/template_test.go @@ -2,6 +2,9 @@ package slice import ( "testing" + + "github.com/patrickdappollonio/kubectl-slice/pkg/logger" + "github.com/stretchr/testify/require" ) func TestTemplate_compileTemplate(t *testing.T) { @@ -33,10 +36,14 @@ func TestTemplate_compileTemplate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := &Split{opts: tt.opts, log: nolog} + s := &Split{opts: tt.opts, log: logger.NOOPLogger} err := s.compileTemplate() - requireErrorIf(t, tt.wantErr, err) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) }) } } diff --git a/slice/validate.go b/slice/validate.go index 8fdd109..d450a92 100644 --- a/slice/validate.go +++ b/slice/validate.go @@ -5,6 +5,8 @@ import ( "fmt" "path/filepath" "regexp" + + "github.com/patrickdappollonio/kubectl-slice/pkg/files" ) var ( @@ -28,7 +30,7 @@ func (s *Split) init() error { if s.opts.InputFile != "" { s.log.Printf("Loading file %s", s.opts.InputFile) var err error - buf, err = loadfile(s.opts.InputFile) + buf, err = files.LoadFile(s.opts.InputFile) if err != nil { return err } @@ -45,7 +47,7 @@ func (s *Split) init() error { s.log.Printf("Loading folder %q", s.opts.InputFolder) var err error var count int - buf, count, err = loadfolder(exts, s.opts.InputFolder, s.opts.Recurse) + buf, count, err = files.LoadFolder(exts, s.opts.InputFolder, s.opts.Recurse) if err != nil { return err }