From bb33dee6b682413c26a7995b2ae055ba3e1faf1a Mon Sep 17 00:00:00 2001 From: Gunjan Vyas Date: Fri, 6 Dec 2024 10:53:58 +0530 Subject: [PATCH] [WIP] bundle: Parallel download and decompression This commit does the following: - Return a reader from the bundle Download function. - Use the reader to stream the bytes to the Extract function. This commit replaces grab client with the net/http client to ensure that the bytes are streamed in the correct order to the Extract func. Currently, only zst decompression is being used in the UncompressWithReader function as it is the primary compression algorithm being used in crc. --- cmd/crc-embedder/cmd/embed.go | 2 +- pkg/crc/cache/cache.go | 3 +- pkg/crc/image/image.go | 1 + pkg/crc/machine/bundle/metadata.go | 24 ++++-- pkg/crc/machine/bundle/repository.go | 42 ++++++++- pkg/crc/machine/start.go | 6 +- pkg/crc/preflight/preflight_checks_common.go | 6 +- pkg/download/download.go | 91 +++++++------------- pkg/extract/extract.go | 14 +++ test/extended/util/util.go | 2 +- 10 files changed, 112 insertions(+), 79 deletions(-) diff --git a/cmd/crc-embedder/cmd/embed.go b/cmd/crc-embedder/cmd/embed.go index 24c5f7ca7d..714ae07420 100644 --- a/cmd/crc-embedder/cmd/embed.go +++ b/cmd/crc-embedder/cmd/embed.go @@ -163,7 +163,7 @@ func downloadDataFiles(goos string, components []string, destDir string) ([]stri if !shouldDownload(components, componentName) { continue } - filename, err := download.Download(context.TODO(), dl.url, destDir, dl.permissions, nil) + _, filename, err := download.Download(context.TODO(), dl.url, destDir, dl.permissions, nil) if err != nil { return nil, err } diff --git a/pkg/crc/cache/cache.go b/pkg/crc/cache/cache.go index b27da0a5b6..43dc206445 100644 --- a/pkg/crc/cache/cache.go +++ b/pkg/crc/cache/cache.go @@ -154,7 +154,8 @@ func (c *Cache) getExecutable(destDir string) (string, error) { destPath := filepath.Join(destDir, archiveName) err := embed.Extract(archiveName, destPath) if err != nil { - return
download.Download(context.TODO(), c.archiveURL, destDir, 0600, nil) + _, filename, err := download.Download(context.TODO(), c.archiveURL, destDir, 0600, nil) + return filename, err } return destPath, err diff --git a/pkg/crc/image/image.go b/pkg/crc/image/image.go index 8ad08d8bcd..e1df77f6df 100644 --- a/pkg/crc/image/image.go +++ b/pkg/crc/image/image.go @@ -74,6 +74,7 @@ func (img *imageHandler) copyImage(ctx context.Context, destPath string, reportW if ctx == nil { panic("ctx is nil, this should not happen") } + manifestData, err := copy.Image(ctx, policyContext, destRef, srcRef, &copy.Options{ ReportWriter: reportWriter, diff --git a/pkg/crc/machine/bundle/metadata.go b/pkg/crc/machine/bundle/metadata.go index 40f6c4f38f..c86dd8c573 100644 --- a/pkg/crc/machine/bundle/metadata.go +++ b/pkg/crc/machine/bundle/metadata.go @@ -344,43 +344,49 @@ func getVerifiedHash(url string, file string) (string, error) { return "", fmt.Errorf("%s hash is missing or shasums are malformed", file) } -func downloadDefault(ctx context.Context, preset crcPreset.Preset) (string, error) { +func downloadDefault(ctx context.Context, preset crcPreset.Preset) (io.Reader, string, error) { downloadInfo, err := getBundleDownloadInfo(preset) if err != nil { - return "", err + return nil, "", err } return downloadInfo.Download(ctx, constants.GetDefaultBundlePath(preset), 0664) } -func Download(ctx context.Context, preset crcPreset.Preset, bundleURI string, enableBundleQuayFallback bool) (string, error) { +func Download(ctx context.Context, preset crcPreset.Preset, bundleURI string, enableBundleQuayFallback bool) (io.Reader, string, error) { // If we are asked to download // ~/.crc/cache/crc_podman_libvirt_4.1.1.crcbundle, this means we want // are downloading the default bundle for this release. This uses a // different codepath from user-specified URIs as for the default // bundles, their sha256sums are known and can be checked.
+ var reader io.Reader if bundleURI == constants.GetDefaultBundlePath(preset) { switch preset { case crcPreset.OpenShift, crcPreset.Microshift: - downloadedBundlePath, err := downloadDefault(ctx, preset) + var err error + var downloadedBundlePath string + reader, downloadedBundlePath, err = downloadDefault(ctx, preset) if err != nil && enableBundleQuayFallback { logging.Info("Unable to download bundle from mirror, falling back to quay") - return image.PullBundle(ctx, constants.GetDefaultBundleImageRegistry(preset)) + bundle, err := image.PullBundle(ctx, constants.GetDefaultBundleImageRegistry(preset)) + return nil, bundle, err } - return downloadedBundlePath, err + return reader, downloadedBundlePath, err case crcPreset.OKD: fallthrough default: - return image.PullBundle(ctx, constants.GetDefaultBundleImageRegistry(preset)) + bundle, err := image.PullBundle(ctx, constants.GetDefaultBundleImageRegistry(preset)) + return nil, bundle, err } } switch { case strings.HasPrefix(bundleURI, "http://"), strings.HasPrefix(bundleURI, "https://"): return download.Download(ctx, bundleURI, constants.MachineCacheDir, 0644, nil) case strings.HasPrefix(bundleURI, "docker://"): - return image.PullBundle(ctx, bundleURI) + bundle, err := image.PullBundle(ctx, bundleURI) + return nil, bundle, err } // the `bundleURI` parameter turned out to be a local path - return bundleURI, nil + return reader, bundleURI, nil } type Version struct { diff --git a/pkg/crc/machine/bundle/repository.go b/pkg/crc/machine/bundle/repository.go index 46ea2bf424..86c8f0fb08 100644 --- a/pkg/crc/machine/bundle/repository.go +++ b/pkg/crc/machine/bundle/repository.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "io" "os" "path/filepath" "runtime" @@ -124,6 +125,36 @@ func (bundle *CrcBundleInfo) createSymlinkOrCopyPodmanRemote(binDir string) erro return bundle.copyExecutableFromBundle(binDir, PodmanExecutable, constants.PodmanRemoteExecutableName) } +func (repo *Repository) ExtractWithReader(ctx 
context.Context, reader io.Reader, path string) error { + logging.Debugf("Extracting bundle from reader") + bundleName := filepath.Base(path) + + tmpDir := filepath.Join(repo.CacheDir, "tmp-extract") + _ = os.RemoveAll(tmpDir) // clean up before using it + defer func() { + _ = os.RemoveAll(tmpDir) // clean up after using it + }() + + if _, err := extract.UncompressWithReader(ctx, reader, tmpDir); err != nil { + return err + } + + bundleBaseDir := GetBundleNameWithoutExtension(bundleName) + bundleDir := filepath.Join(repo.CacheDir, bundleBaseDir) + _ = os.RemoveAll(bundleDir) + err := crcerrors.Retry(context.Background(), time.Minute, func() error { + if err := os.Rename(filepath.Join(tmpDir, bundleBaseDir), bundleDir); err != nil { + return &crcerrors.RetriableError{Err: err} + } + return nil + }, 5*time.Second) + if err != nil { + return err + } + + return os.Chmod(bundleDir, 0755) +} + func (repo *Repository) Extract(ctx context.Context, path string) error { bundleName := filepath.Base(path) @@ -198,8 +229,15 @@ func Use(bundleName string) (*CrcBundleInfo, error) { return defaultRepo.Use(bundleName) } -func Extract(ctx context.Context, path string) (*CrcBundleInfo, error) { - if err := defaultRepo.Extract(ctx, path); err != nil { +func Extract(ctx context.Context, reader io.Reader, path string) (*CrcBundleInfo, error) { + var err error + if reader == nil { + err = defaultRepo.Extract(ctx, path) + } else { + err = defaultRepo.ExtractWithReader(ctx, reader, path) + } + + if err != nil { return nil, err } return defaultRepo.Get(filepath.Base(path)) diff --git a/pkg/crc/machine/start.go b/pkg/crc/machine/start.go index 17eacb2c33..ae3a8576ef 100644 --- a/pkg/crc/machine/start.go +++ b/pkg/crc/machine/start.go @@ -48,13 +48,15 @@ func getCrcBundleInfo(ctx context.Context, preset crcPreset.Preset, bundleName, return bundleInfo, nil } logging.Debugf("Failed to load bundle %s: %v", bundleName, err) + logging.Infof("Downloading bundle: %s...", bundleName) - bundlePath, 
err = bundle.Download(ctx, preset, bundlePath, enableBundleQuayFallback) + reader, bundlePath, err := bundle.Download(ctx, preset, bundlePath, enableBundleQuayFallback) if err != nil { return nil, err } + logging.Infof("Extracting bundle: %s...", bundleName) - if _, err := bundle.Extract(ctx, bundlePath); err != nil { + if _, err := bundle.Extract(ctx, reader, bundlePath); err != nil { return nil, err } return bundle.Use(bundleName) diff --git a/pkg/crc/preflight/preflight_checks_common.go b/pkg/crc/preflight/preflight_checks_common.go index 3c401ae89c..be764e17bc 100644 --- a/pkg/crc/preflight/preflight_checks_common.go +++ b/pkg/crc/preflight/preflight_checks_common.go @@ -3,6 +3,7 @@ package preflight import ( "context" "fmt" + "io" "os" "path/filepath" @@ -116,13 +117,14 @@ func fixBundleExtracted(bundlePath string, preset crcpreset.Preset, enableBundle return fmt.Errorf("Cannot create directory %s: %v", bundleDir, err) } var err error + var reader io.Reader logging.Infof("Downloading bundle: %s...", bundlePath) - if bundlePath, err = bundle.Download(context.TODO(), preset, bundlePath, enableBundleQuayFallback); err != nil { + if reader, bundlePath, err = bundle.Download(context.TODO(), preset, bundlePath, enableBundleQuayFallback); err != nil { return err } logging.Infof("Uncompressing %s", bundlePath) - if _, err := bundle.Extract(context.TODO(), bundlePath); err != nil { + if _, err := bundle.Extract(context.TODO(), reader, bundlePath); err != nil { if errors.Is(err, os.ErrNotExist) { return errors.Wrap(err, "Use `crc setup -b `") } diff --git a/pkg/download/download.go b/pkg/download/download.go index 7c1eef9ce1..16c97ee750 100644 --- a/pkg/download/download.go +++ b/pkg/download/download.go @@ -2,97 +2,66 @@ package download import ( "context" - "crypto/sha256" "encoding/hex" "fmt" "io" + "mime" "net/http" "net/url" "os" "path/filepath" - "time" + "github.com/cavaliergopher/grab/v3" "github.com/crc-org/crc/v2/pkg/crc/logging" 
"github.com/crc-org/crc/v2/pkg/crc/network/httpproxy" "github.com/crc-org/crc/v2/pkg/crc/version" - "github.com/crc-org/crc/v2/pkg/os/terminal" - - "github.com/cavaliergopher/grab/v3" - "github.com/cheggaaa/pb/v3" "github.com/pkg/errors" ) -func doRequest(client *grab.Client, req *grab.Request) (string, error) { - const minSizeForProgressBar = 100_000_000 - - resp := client.Do(req) - if resp.Size() < minSizeForProgressBar { - <-resp.Done - return resp.Filename, resp.Err() - } - - t := time.NewTicker(500 * time.Millisecond) - defer t.Stop() - var bar *pb.ProgressBar - if terminal.IsShowTerminalOutput() { - bar = pb.Start64(resp.Size()) - bar.Set(pb.Bytes, true) - // This is the same as the 'Default' template https://github.com/cheggaaa/pb/blob/224e0746e1e7b9c5309d6e2637264bfeb746d043/v3/preset.go#L8-L10 - // except that the 'per second' suffix is changed to '/s' (by default it is ' p/s' which is unexpected) - progressBarTemplate := `{{with string . "prefix"}}{{.}} {{end}}{{counters . }} {{bar . }} {{percent . }} {{speed . "%s/s" "??/s"}}{{with string . 
"suffix"}} {{.}}{{end}}` - bar.SetTemplateString(progressBarTemplate) - defer bar.Finish() - } - -loop: - for { - select { - case <-t.C: - if terminal.IsShowTerminalOutput() { - bar.SetCurrent(resp.BytesComplete()) - } - case <-resp.Done: - break loop - } - } - - return resp.Filename, resp.Err() -} - // Download function takes sha256sum as hex decoded byte // something like hex.DecodeString("33daf4c03f86120fdfdc66bddf6bfff4661c7ca11c5d") -func Download(ctx context.Context, uri, destination string, mode os.FileMode, sha256sum []byte) (string, error) { +func Download(ctx context.Context, uri, destination string, mode os.FileMode, _ []byte) (io.Reader, string, error) { logging.Debugf("Downloading %s to %s", uri, destination) - client := grab.NewClient() - client.UserAgent = version.UserAgent() - client.HTTPClient = &http.Client{Transport: httpproxy.HTTPTransport()} - req, err := grab.NewRequest(destination, uri) - if err != nil { - return "", errors.Wrapf(err, "unable to get request from %s", uri) - } - if ctx == nil { panic("ctx is nil, this should not happen") } + req, err := http.NewRequestWithContext(ctx, "GET", uri, nil) + + if err != nil { + return nil, "", errors.Wrapf(err, "unable to get request from %s", uri) + } + client := http.Client{Transport: &http.Transport{}} + req = req.WithContext(ctx) - if sha256sum != nil { - req.SetChecksum(sha256.New(), sha256sum, true) + resp, err := client.Do(req) + if err != nil { + return nil, "", err } - filename, err := doRequest(client, req) + var filename, dir string + if filepath.Ext(destination) == ".crcbundle" { + dir = filepath.Dir(destination) + } else { + dir = destination + } + if disposition, params, _ := mime.ParseMediaType(resp.Header.Get("Content-Disposition")); disposition == "attachment" { + filename = filepath.Join(dir, params["filename"]) + } else { + filename = filepath.Join(dir, filepath.Base(resp.Request.URL.Path)) + } + file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode) if err 
!= nil { - return "", err + return nil, "", err } if err := os.Chmod(filename, mode); err != nil { _ = os.Remove(filename) - return "", err + return nil, "", err } - logging.Debugf("Download saved to %v", filename) - return filename, nil + return io.TeeReader(resp.Body, file), filename, nil } // InMemory takes a URL and returns a ReadCloser object to the downloaded file @@ -138,10 +107,10 @@ func NewRemoteFile(uri, sha256sum string) *RemoteFile { } -func (r *RemoteFile) Download(ctx context.Context, bundlePath string, mode os.FileMode) (string, error) { +func (r *RemoteFile) Download(ctx context.Context, bundlePath string, mode os.FileMode) (io.Reader, string, error) { sha256bytes, err := hex.DecodeString(r.sha256sum) if err != nil { - return "", err + return nil, "", err } return Download(ctx, r.URI, bundlePath, mode, sha256bytes) } diff --git a/pkg/extract/extract.go b/pkg/extract/extract.go index e1ab4c3637..3441bc486a 100644 --- a/pkg/extract/extract.go +++ b/pkg/extract/extract.go @@ -32,6 +32,20 @@ func Uncompress(ctx context.Context, tarball, targetDir string) ([]string, error return uncompress(ctx, tarball, targetDir, nil, terminal.IsShowTerminalOutput()) } +func UncompressWithReader(ctx context.Context, reader io.Reader, targetDir string) ([]string, error) { + return uncompressWithReader(ctx, reader, targetDir, nil, terminal.IsShowTerminalOutput()) +} + +func uncompressWithReader(ctx context.Context, reader io.Reader, targetDir string, fileFilter func(string) bool, showProgress bool) ([]string, error) { + logging.Debugf("Uncompressing from reader to %s", targetDir) + + reader, err := zstd.NewReader(reader) + if err != nil { + return nil, err + } + return untar(ctx, reader, targetDir, fileFilter, showProgress) +} + func uncompress(ctx context.Context, tarball, targetDir string, fileFilter func(string) bool, showProgress bool) ([]string, error) { logging.Debugf("Uncompressing %s to %s", tarball, targetDir) diff --git a/test/extended/util/util.go 
b/test/extended/util/util.go index c24ff4ae30..11f026a30e 100644 --- a/test/extended/util/util.go +++ b/test/extended/util/util.go @@ -125,7 +125,7 @@ func DownloadBundle(bundleLocation string, bundleDestination string, bundleName return bundleDestination, err } - filename, err := download.Download(context.TODO(), bundleLocation, bundleDestination, 0644, nil) + _, filename, err := download.Download(context.TODO(), bundleLocation, bundleDestination, 0644, nil) fmt.Printf("Downloading bundle from %s to %s.\n", bundleLocation, bundleDestination) if err != nil { return "", err