diff --git a/pkg-new/upgrade/upgrade.go b/pkg-new/upgrade/upgrade.go
index 44bc99358a..ef05dca0b1 100644
--- a/pkg-new/upgrade/upgrade.go
+++ b/pkg-new/upgrade/upgrade.go
@@ -2,6 +2,7 @@ package upgrade
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"reflect"
 	"time"
@@ -23,7 +24,7 @@ import (
 	"github.com/replicatedhq/embedded-cluster/pkg/runtimeconfig"
 	"github.com/replicatedhq/embedded-cluster/pkg/support"
 	"github.com/sirupsen/logrus"
-	"k8s.io/apimachinery/pkg/api/errors"
+	k8serrors "k8s.io/apimachinery/pkg/api/errors"
 	"sigs.k8s.io/controller-runtime/pkg/client"
 )
 
@@ -347,9 +348,9 @@ func upgradeExtensions(ctx context.Context, cli client.Client, hcli helm.Client,
 func createAutopilotPlan(ctx context.Context, cli client.Client, rc runtimeconfig.RuntimeConfig, desiredVersion string, in *ecv1beta1.Installation, meta *ectypes.ReleaseMetadata, logger logrus.FieldLogger) error {
 	var plan apv1b2.Plan
 	okey := client.ObjectKey{Name: "autopilot"}
-	if err := cli.Get(ctx, okey, &plan); err != nil && !errors.IsNotFound(err) {
+	if err := cli.Get(ctx, okey, &plan); err != nil && !k8serrors.IsNotFound(err) {
 		return fmt.Errorf("get upgrade plan: %w", err)
-	} else if errors.IsNotFound(err) {
+	} else if k8serrors.IsNotFound(err) {
 		// if the kubernetes version has changed we create an upgrade command
 		logger.WithField("version", desiredVersion).Info("Starting k0s autopilot upgrade plan")
 
@@ -364,15 +365,60 @@ func createAutopilotPlan(ctx context.Context, cli client.Client, rc runtimeconfi
 }
 
+// waitForAutopilotPlan polls the "autopilot" plan until
+// autopilot.HasThePlanEnded reports it finished, then returns it. Fetch errors
+// are retried rather than returned. The delay between polls starts at one
+// second and doubles up to a 20-second ceiling. The wait ends with an error
+// after 25 minutes or as soon as ctx is done; the most recent fetch error, if
+// any, is attached to that error.
 func waitForAutopilotPlan(ctx context.Context, cli client.Client, logger logrus.FieldLogger) (apv1b2.Plan, error) {
-	for {
-		var plan apv1b2.Plan
-		if err := cli.Get(ctx, client.ObjectKey{Name: "autopilot"}, &plan); err != nil {
-			return plan, fmt.Errorf("get upgrade plan: %w", err)
-		}
-		if autopilot.HasThePlanEnded(plan) {
-			return plan, nil
-		}
-		logger.WithField("plan_id", plan.Spec.ID).Info("An autopilot upgrade is in progress")
-		time.Sleep(5 * time.Second)
-	}
+	const (
+		initialInterval = time.Second
+		maxInterval     = 20 * time.Second
+		maxWait         = 25 * time.Minute
+	)
+
+	// Bound the total wait with a child context so that running out of time
+	// can be told apart from the caller's own cancellation or deadline.
+	waitCtx, cancel := context.WithTimeout(ctx, maxWait)
+	defer cancel()
+
+	var lastErr error
+	interval := initialInterval
+	for {
+		if waitCtx.Err() != nil {
+			if err := ctx.Err(); err != nil {
+				// The caller's context ended; report that first.
+				if lastErr != nil {
+					return apv1b2.Plan{}, errors.Join(err, lastErr)
+				}
+				return apv1b2.Plan{}, err
+			}
+			if lastErr != nil {
+				return apv1b2.Plan{}, fmt.Errorf("timed out waiting for autopilot plan: %w", lastErr)
+			}
+			return apv1b2.Plan{}, errors.New("timed out waiting for autopilot plan")
+		}
+
+		var plan apv1b2.Plan
+		err := cli.Get(waitCtx, client.ObjectKey{Name: "autopilot"}, &plan)
+		switch {
+		case err != nil:
+			// Fetch errors are treated as transient: keep the latest and retry.
+			lastErr = fmt.Errorf("get autopilot plan: %w", err)
+		case autopilot.HasThePlanEnded(plan):
+			return plan, nil
+		default:
+			// A successful read makes any earlier fetch error stale.
+			lastErr = nil
+			logger.WithField("plan_id", plan.Spec.ID).Info("An autopilot upgrade is in progress")
+		}
+
+		// Sleep for the current interval; the check at the top of the loop
+		// handles the case where the context ended first.
+		select {
+		case <-waitCtx.Done():
+		case <-time.After(interval):
+		}
+		interval = min(interval*2, maxInterval)
+	}
 }
diff --git a/pkg-new/upgrade/upgrade_test.go b/pkg-new/upgrade/upgrade_test.go
index 29c4e7eb54..7481eaede4 100644
--- a/pkg-new/upgrade/upgrade_test.go
+++ b/pkg-new/upgrade/upgrade_test.go
@@ -3,9 +3,13 @@ package upgrade
 import (
 	"context"
 	"encoding/json"
+	"fmt"
+	"sync/atomic"
 	"testing"
 
+	apv1b2 "github.com/k0sproject/k0s/pkg/apis/autopilot/v1beta2"
 	k0sv1beta1 "github.com/k0sproject/k0s/pkg/apis/k0s/v1beta1"
+	"github.com/k0sproject/k0s/pkg/autopilot/controller/plans/core"
 	ecv1beta1 "github.com/replicatedhq/embedded-cluster/kinds/apis/v1beta1"
 	"github.com/sirupsen/logrus"
 	"github.com/stretchr/testify/assert"
@@ -369,3 +373,149 @@ config:
 		})
 	}
 }
+
+func TestWaitForAutopilotPlan_Success(t *testing.T) {
+	logger := logrus.New()
+	logger.SetLevel(logrus.ErrorLevel)
+
+	scheme := runtime.NewScheme()
+	require.NoError(t, apv1b2.Install(scheme))
+
+	plan := &apv1b2.Plan{
+		ObjectMeta: metav1.ObjectMeta{
+			Name: "autopilot",
+		},
+		Status: apv1b2.PlanStatus{
+			State: core.PlanCompleted,
+		},
+	}
+
+	cli := fake.NewClientBuilder().
+		WithScheme(scheme).
+		WithObjects(plan).
+		Build()
+
+	result, err := waitForAutopilotPlan(t.Context(), cli, logger)
+	require.NoError(t, err)
+	assert.Equal(t, "autopilot", result.Name)
+}
+
+func TestWaitForAutopilotPlan_RetriesOnTransientErrors(t *testing.T) {
+	logger := logrus.New()
+	logger.SetLevel(logrus.ErrorLevel)
+
+	scheme := runtime.NewScheme()
+	require.NoError(t, apv1b2.Install(scheme))
+
+	// Plan that starts completed
+	plan := &apv1b2.Plan{
+		ObjectMeta: metav1.ObjectMeta{
+			Name: "autopilot",
+		},
+		Status: apv1b2.PlanStatus{
+			State: core.PlanCompleted,
+		},
+	}
+
+	// Mock client that fails first 3 times, then succeeds
+	var callCount atomic.Int32
+	cli := &mockClientWithRetries{
+		Client:       fake.NewClientBuilder().WithScheme(scheme).WithObjects(plan).Build(),
+		failCount:    3,
+		currentCount: &callCount,
+	}
+
+	result, err := waitForAutopilotPlan(t.Context(), cli, logger)
+	require.NoError(t, err)
+	assert.Equal(t, "autopilot", result.Name)
+	assert.Equal(t, int32(4), callCount.Load(), "Should have retried 3 times before succeeding")
+}
+
+func TestWaitForAutopilotPlan_ContextCanceled(t *testing.T) {
+	logger := logrus.New()
+	logger.SetLevel(logrus.ErrorLevel)
+
+	scheme := runtime.NewScheme()
+	require.NoError(t, apv1b2.Install(scheme))
+
+	ctx, cancel := context.WithCancel(t.Context())
+	cancel() // Cancel immediately
+
+	cli := fake.NewClientBuilder().WithScheme(scheme).Build()
+
+	_, err := waitForAutopilotPlan(ctx, cli, logger)
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "context canceled")
+}
+
+func TestWaitForAutopilotPlan_WaitsForCompletion(t *testing.T) {
+	logger := logrus.New()
+	logger.SetLevel(logrus.ErrorLevel)
+
+	scheme := runtime.NewScheme()
+	require.NoError(t, apv1b2.Install(scheme))
+
+	// Plan that starts in progress, then completes after some time
+	plan := &apv1b2.Plan{
+		ObjectMeta: metav1.ObjectMeta{
+			Name: "autopilot",
+		},
+		Spec: apv1b2.PlanSpec{
+			ID: "test-plan",
+		},
+		Status: apv1b2.PlanStatus{
+			State: core.PlanSchedulable,
+		},
+	}
+
+	cli := &mockClientWithStateChange{
+		Client:     fake.NewClientBuilder().WithScheme(scheme).WithObjects(plan).Build(),
+		plan:       plan,
+		callsUntil: 3, // Will complete after 3 calls
+	}
+
+	result, err := waitForAutopilotPlan(t.Context(), cli, logger)
+	require.NoError(t, err)
+	assert.Equal(t, "autopilot", result.Name)
+	assert.Equal(t, core.PlanCompleted, result.Status.State)
+}
+
+// Mock client that fails N times before succeeding
+type mockClientWithRetries struct {
+	client.Client
+	failCount    int
+	currentCount *atomic.Int32
+}
+
+func (m *mockClientWithRetries) Get(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error {
+	count := m.currentCount.Add(1)
+	if count <= int32(m.failCount) {
+		return fmt.Errorf("connection refused")
+	}
+	return m.Client.Get(ctx, key, obj, opts...)
+}
+
+// Mock client that changes plan state after N calls
+type mockClientWithStateChange struct {
+	client.Client
+	plan       *apv1b2.Plan
+	callCount  int
+	callsUntil int
+}
+
+func (m *mockClientWithStateChange) Get(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error {
+	m.callCount++
+	err := m.Client.Get(ctx, key, obj, opts...)
+	if err != nil {
+		return err
+	}
+
+	// After N calls, mark the plan as completed
+	if m.callCount >= m.callsUntil {
+		if plan, ok := obj.(*apv1b2.Plan); ok {
+			plan.Status.State = core.PlanCompleted
+		}
+	}
+
+	return nil
+}