diff --git a/internal/http/download_test.go b/internal/http/download_test.go index 4e50adb6df06..ccd259ccd4c2 100644 --- a/internal/http/download_test.go +++ b/internal/http/download_test.go @@ -15,7 +15,6 @@ import ( "strings" "sync/atomic" "testing" - "time" "github.com/k0sproject/k0s/internal/testutil" @@ -79,18 +78,21 @@ func TestDownload_ExcessContentLength(t *testing.T) { func TestDownload_CancelDownload(t *testing.T) { ctx, cancel := context.WithCancelCause(t.Context()) + requestDone := make(chan struct{}) baseURL := startFakeDownloadServer(t, false, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - for { - if _, err := w.Write([]byte(t.Name())); !assert.NoError(t, err) { - return - } - - select { - case <-r.Context().Done(): - return - case <-time.After(time.Microsecond): - } + defer close(requestDone) + if _, err := w.Write([]byte(t.Name())); !assert.NoError(t, err) { + return } + + // Need to flush here, otherwise the internal response buffering will + // prevent the client from receiving the data. + w.(http.Flusher).Flush() + + // Wait for the client to cancel the in-flight request. + ctx := r.Context() + <-ctx.Done() + assert.Same(t, context.Canceled, context.Cause(ctx), "HTTP request context wasn't canceled while writing response") })) err := internalhttp.Download(ctx, baseURL, internalio.WriterFunc(func(p []byte) (int, error) { @@ -100,6 +102,10 @@ func TestDownload_CancelDownload(t *testing.T) { assert.ErrorContains(t, err, "while downloading: ") assert.ErrorIs(t, err, assert.AnError) + + // Make sure the server's HTTP handler has finished + // so all the assertions have been made. + <-requestDone } func TestDownload_RedirectLoop(t *testing.T) {