diff --git a/k8s/helm-charts/seldon-core-v2-setup/templates/_components-deployments.tpl b/k8s/helm-charts/seldon-core-v2-setup/templates/_components-deployments.tpl index e1335c3677..e0f580d5b8 100644 --- a/k8s/helm-charts/seldon-core-v2-setup/templates/_components-deployments.tpl +++ b/k8s/helm-charts/seldon-core-v2-setup/templates/_components-deployments.tpl @@ -526,6 +526,9 @@ spec: - --pprof-port=$(PPROF_PORT) - --pprof-block-rate=$(PPROF_BLOCK_RATE) - --pprof-mutex-rate=$(PPROF_MUTEX_RATE) + - --retry-creating-failed-pipelines-tick=$(RETRY_CREATING_FAILED_PIPELINES_TICK) + - --retry-deleting-failed-pipelines-tick=$(RETRY_DELETING_FAILED_PIPELINES_TICK) + - --max-retry-failed-pipelines=$(MAX_RETRY_FAILED_PIPELINES) command: - /bin/scheduler env: @@ -611,6 +614,12 @@ spec: value: "0" - name: PPROF_MUTEX_RATE value: "0" + - name: RETRY_CREATING_FAILED_PIPELINES_TICK + value: 60s + - name: RETRY_DELETING_FAILED_PIPELINES_TICK + value: 60s + - name: MAX_RETRY_FAILED_PIPELINES + value: "10" image: '{{ .Values.scheduler.image.registry }}/{{ .Values.scheduler.image.repository }}:{{ .Values.scheduler.image.tag }}' imagePullPolicy: '{{ .Values.scheduler.image.pullPolicy }}' diff --git a/k8s/helm-charts/seldon-core-v2-setup/templates/_components-statefulsets.tpl b/k8s/helm-charts/seldon-core-v2-setup/templates/_components-statefulsets.tpl index 528a561121..0daabeda02 100644 --- a/k8s/helm-charts/seldon-core-v2-setup/templates/_components-statefulsets.tpl +++ b/k8s/helm-charts/seldon-core-v2-setup/templates/_components-statefulsets.tpl @@ -526,6 +526,9 @@ spec: - --pprof-port=$(PPROF_PORT) - --pprof-block-rate=$(PPROF_BLOCK_RATE) - --pprof-mutex-rate=$(PPROF_MUTEX_RATE) + - --retry-creating-failed-pipelines-tick=$(RETRY_CREATING_FAILED_PIPELINES_TICK) + - --retry-deleting-failed-pipelines-tick=$(RETRY_DELETING_FAILED_PIPELINES_TICK) + - --max-retry-failed-pipelines=$(MAX_RETRY_FAILED_PIPELINES) command: - /bin/scheduler env: @@ -611,6 +614,12 @@ spec: value: "0" - name: PPROF_MUTEX_RATE value: "0" + - name: RETRY_CREATING_FAILED_PIPELINES_TICK + value: 60s + - name: RETRY_DELETING_FAILED_PIPELINES_TICK + value: 60s + - name: MAX_RETRY_FAILED_PIPELINES + value: "10" image: '{{ .Values.scheduler.image.registry }}/{{ .Values.scheduler.image.repository }}:{{ .Values.scheduler.image.tag }}' imagePullPolicy: '{{ .Values.scheduler.image.pullPolicy }}' diff --git a/k8s/yaml/components.yaml b/k8s/yaml/components.yaml index 4017e76716..a8befe435b 100644 --- a/k8s/yaml/components.yaml +++ b/k8s/yaml/components.yaml @@ -373,6 +373,9 @@ spec: - --pprof-port=$(PPROF_PORT) - --pprof-block-rate=$(PPROF_BLOCK_RATE) - --pprof-mutex-rate=$(PPROF_MUTEX_RATE) + - --retry-creating-failed-pipelines-tick=$(RETRY_CREATING_FAILED_PIPELINES_TICK) + - --retry-deleting-failed-pipelines-tick=$(RETRY_DELETING_FAILED_PIPELINES_TICK) + - --max-retry-failed-pipelines=$(MAX_RETRY_FAILED_PIPELINES) command: - /bin/scheduler env: @@ -454,6 +457,12 @@ spec: value: "0" - name: PPROF_MUTEX_RATE value: "0" + - name: RETRY_CREATING_FAILED_PIPELINES_TICK + value: 60s + - name: RETRY_DELETING_FAILED_PIPELINES_TICK + value: 60s + - name: MAX_RETRY_FAILED_PIPELINES + value: "10" image: 'docker.io/seldonio/seldon-scheduler:latest' imagePullPolicy: 'IfNotPresent' livenessProbe: diff --git a/operator/config/seldonconfigs/default.yaml b/operator/config/seldonconfigs/default.yaml index a2c61012b8..bc6b63889b 100644 --- a/operator/config/seldonconfigs/default.yaml +++ b/operator/config/seldonconfigs/default.yaml @@ -347,6 +347,9 @@ spec: - 
--pprof-port=$(PPROF_PORT) - --pprof-block-rate=$(PPROF_BLOCK_RATE) - --pprof-mutex-rate=$(PPROF_MUTEX_RATE) + - --retry-creating-failed-pipelines-tick=$(RETRY_CREATING_FAILED_PIPELINES_TICK) + - --retry-deleting-failed-pipelines-tick=$(RETRY_DELETING_FAILED_PIPELINES_TICK) + - --max-retry-failed-pipelines=$(MAX_RETRY_FAILED_PIPELINES) command: - /bin/scheduler env: @@ -386,6 +389,12 @@ spec: value: "0" - name: PPROF_MUTEX_RATE value: "0" + - name: RETRY_CREATING_FAILED_PIPELINES_TICK + value: "60s" + - name: RETRY_DELETING_FAILED_PIPELINES_TICK + value: "60s" + - name: MAX_RETRY_FAILED_PIPELINES + value: "10" image: seldonio/seldon-scheduler:latest imagePullPolicy: Always name: scheduler diff --git a/scheduler/cmd/scheduler/main.go b/scheduler/cmd/scheduler/main.go index 5e8610a6eb..f1d1560257 100644 --- a/scheduler/cmd/scheduler/main.go +++ b/scheduler/cmd/scheduler/main.go @@ -51,37 +51,40 @@ import ( ) var ( - envoyPort uint - agentPort uint - agentMtlsPort uint - schedulerPort uint - schedulerMtlsPort uint - chainerPort uint - healthProbePort uint - namespace string - pipelineGatewayHost string - pipelineGatewayHttpPort int - pipelineGatewayGrpcPort int - logLevel string - tracingConfigPath string - dbPath string - nodeID string - allowPlaintxt bool // scheduler server - autoscalingModelEnabled bool - autoscalingServerEnabled bool - kafkaConfigPath string - scalingConfigPath string - schedulerReadyTimeoutSeconds uint - deletedResourceTTLSeconds uint - serverPackingEnabled bool - serverPackingPercentage float64 - accessLogPath string - enableAccessLog bool - includeSuccessfulRequests bool - enablePprof bool - pprofPort int - pprofMutexRate int - pprofBlockRate int + envoyPort uint + agentPort uint + agentMtlsPort uint + schedulerPort uint + schedulerMtlsPort uint + chainerPort uint + healthProbePort uint + namespace string + pipelineGatewayHost string + pipelineGatewayHttpPort int + pipelineGatewayGrpcPort int + logLevel string + tracingConfigPath string + dbPath string + nodeID string + allowPlaintxt bool // scheduler server + autoscalingModelEnabled bool + autoscalingServerEnabled bool + kafkaConfigPath string + scalingConfigPath string + schedulerReadyTimeoutSeconds uint + deletedResourceTTLSeconds uint + serverPackingEnabled bool + serverPackingPercentage float64 + accessLogPath string + enableAccessLog bool + includeSuccessfulRequests bool + enablePprof bool + pprofPort int + pprofMutexRate int + pprofBlockRate int + retryFailedCreatingPipelinesTick time.Duration + retryFailedDeletePipelinesTick time.Duration + maxRetryFailedPipelines uint ) const ( @@ -172,6 +175,11 @@ func init() { flag.IntVar(&pprofPort, "pprof-port", 6060, "pprof HTTP server port") flag.IntVar(&pprofBlockRate, "pprof-block-rate", 0, "pprof block rate") flag.IntVar(&pprofMutexRate, "pprof-mutex-rate", 0, "pprof mutex rate") + + // frequency to retry creating/deleting pipelines which failed to create/delete + flag.DurationVar(&retryFailedCreatingPipelinesTick, "retry-creating-failed-pipelines-tick", time.Minute, "tick interval for re-attempting to create pipelines which failed to create") + flag.DurationVar(&retryFailedDeletePipelinesTick, "retry-deleting-failed-pipelines-tick", time.Minute, "tick interval for re-attempting to delete pipelines which failed to terminate") + flag.UintVar(&maxRetryFailedPipelines, "max-retry-failed-pipelines", 10, "max number of retry attempts to create/terminate pipelines which failed to create/terminate") } func getNamespace() string { @@ -322,8 +330,11 @@ func main() { 
logger.WithError(err).Fatal("Failed to start data engine chainer server") } defer cs.Stop() + + ctx, stopPipelinePollers := context.WithCancel(context.Background()) + defer stopPipelinePollers() go func() { - err := cs.StartGrpcServer(chainerPort) + err := cs.StartGrpcServer(ctx, retryFailedCreatingPipelinesTick, retryFailedDeletePipelinesTick, maxRetryFailedPipelines, chainerPort) if err != nil { log.WithError(err).Fatalf("Chainer server start error") } @@ -382,7 +393,8 @@ func main() { ) defer s.Stop() - err = s.StartGrpcServers(allowPlaintxt, schedulerPort, schedulerMtlsPort) + err = s.StartGrpcServers(ctx, allowPlaintxt, schedulerPort, schedulerMtlsPort, retryFailedCreatingPipelinesTick, + retryFailedDeletePipelinesTick, maxRetryFailedPipelines) if err != nil { logger.WithError(err).Fatal("Failed to start server gRPC servers") } @@ -421,6 +433,7 @@ func main() { s.StopSendServerEvents() s.StopSendExperimentEvents() s.StopSendPipelineEvents() + stopPipelinePollers() s.StopSendControlPlaneEvents() as.StopAgentStreams() diff --git a/scheduler/go.mod b/scheduler/go.mod index fb2dc21118..2b89ecdf40 100644 --- a/scheduler/go.mod +++ b/scheduler/go.mod @@ -148,7 +148,10 @@ require ( sigs.k8s.io/yaml v1.5.0 // indirect ) -tool go.uber.org/mock/mockgen +tool ( + go.uber.org/mock/mockgen + golang.org/x/tools/cmd/stringer +) replace github.com/seldonio/seldon-core/components/tls/v2 => ../components/tls diff --git a/scheduler/pkg/kafka/conflict-resolution/conflict_resolution.go b/scheduler/pkg/kafka/conflict-resolution/conflict_resolution.go index accf41dd1b..23bca16b21 100644 --- a/scheduler/pkg/kafka/conflict-resolution/conflict_resolution.go +++ b/scheduler/pkg/kafka/conflict-resolution/conflict_resolution.go @@ -146,6 +146,11 @@ func GetPipelineStatus( messageStr += fmt.Sprintf("%d/%d failed ", failedCount, len(streams)) } + failedTerminatingCount := cr.GetCountResourceWithStatus(pipelineName, pipeline.PipelineFailedTerminating) + if failedTerminatingCount > 0 { + messageStr += fmt.Sprintf("%d/%d failed terminating", failedTerminatingCount, len(streams)) + } + rebalancingCount := cr.GetCountResourceWithStatus(pipelineName, pipeline.PipelineRebalancing) if rebalancingCount > 0 { messageStr += fmt.Sprintf("%d/%d rebalancing ", rebalancingCount, len(streams)) @@ -170,8 +175,8 @@ func GetPipelineStatus( } if message.Update.Op == chainer.PipelineUpdateMessage_Delete { - if failedCount > 0 { - return pipeline.PipelineFailed, messageStr + if failedTerminatingCount > 0 { + return pipeline.PipelineFailedTerminating, messageStr } if terminatedCount == len(streams) { return pipeline.PipelineTerminated, messageStr diff --git a/scheduler/pkg/kafka/conflict-resolution/conflict_resolution_test.go b/scheduler/pkg/kafka/conflict-resolution/conflict_resolution_test.go index 9637ff519d..4f1b16f4cb 100644 --- a/scheduler/pkg/kafka/conflict-resolution/conflict_resolution_test.go +++ b/scheduler/pkg/kafka/conflict-resolution/conflict_resolution_test.go @@ -150,11 +150,17 @@ func TestIsMessageOutdated(t *testing.T) { func TestGetPipelineStatus(t *testing.T) { g := gomega.NewGomegaWithT(t) + type expect struct { + status pipeline.PipelineStatus + msg string + } + tests := []struct { name string op chainer.PipelineUpdateMessage_PipelineOperation statuses map[string]pipeline.PipelineStatus - expected pipeline.PipelineStatus + expect expect + msg string }{ { name: "create creating", @@ -163,7 +169,7 @@ func TestGetPipelineStatus(t *testing.T) { "a": pipeline.PipelineReady, "b": pipeline.PipelineStatusUnknown, }, - 
expected: pipeline.PipelineCreating, + expect: expect{status: pipeline.PipelineCreating, msg: "1/2 ready "}, }, { name: "create ready (all ready)", @@ -172,7 +178,7 @@ func TestGetPipelineStatus(t *testing.T) { "a": pipeline.PipelineReady, "b": pipeline.PipelineReady, }, - expected: pipeline.PipelineReady, + expect: expect{status: pipeline.PipelineReady, msg: "2/2 ready "}, }, { name: "create creating (some ready)", @@ -181,7 +187,7 @@ func TestGetPipelineStatus(t *testing.T) { "a": pipeline.PipelineReady, "b": pipeline.PipelineFailed, }, - expected: pipeline.PipelineReady, + expect: expect{status: pipeline.PipelineReady, msg: "1/2 ready 1/2 failed "}, }, { name: "create failed", @@ -190,7 +196,7 @@ func TestGetPipelineStatus(t *testing.T) { "a": pipeline.PipelineFailed, "b": pipeline.PipelineFailed, }, - expected: pipeline.PipelineFailed, + expect: expect{status: pipeline.PipelineFailed, msg: "2/2 failed "}, }, { name: "delete terminating", @@ -199,15 +205,15 @@ func TestGetPipelineStatus(t *testing.T) { "a": pipeline.PipelineTerminated, "b": pipeline.PipelineStatusUnknown, }, - expected: pipeline.PipelineTerminating, + expect: expect{status: pipeline.PipelineTerminating, msg: "1/2 terminated "}, }, { name: "delete failed", op: chainer.PipelineUpdateMessage_Delete, statuses: map[string]pipeline.PipelineStatus{ - "a": pipeline.PipelineFailed, + "a": pipeline.PipelineFailedTerminating, }, - expected: pipeline.PipelineFailed, + expect: expect{status: pipeline.PipelineFailedTerminating, msg: "1/1 failed terminating"}, }, { name: "rebalance failed", @@ -216,7 +222,7 @@ func TestGetPipelineStatus(t *testing.T) { "a": pipeline.PipelineFailed, "b": pipeline.PipelineFailed, }, - expected: pipeline.PipelineFailed, + expect: expect{status: pipeline.PipelineFailed, msg: "2/2 failed "}, }, { name: "rebalanced", @@ -225,7 +231,7 @@ func TestGetPipelineStatus(t *testing.T) { "a": pipeline.PipelineReady, "b": pipeline.PipelineReady, }, - expected: pipeline.PipelineReady, + expect: expect{status: pipeline.PipelineReady, msg: "2/2 ready "}, }, { name: "rebalanced (some ready)", @@ -234,7 +240,7 @@ func TestGetPipelineStatus(t *testing.T) { "a": pipeline.PipelineReady, "b": pipeline.PipelineFailed, }, - expected: pipeline.PipelineReady, + expect: expect{status: pipeline.PipelineReady, msg: "1/2 ready 1/2 failed "}, }, { name: "rebalancing all", @@ -243,7 +249,7 @@ func TestGetPipelineStatus(t *testing.T) { "a": pipeline.PipelineRebalancing, "b": pipeline.PipelineRebalancing, }, - expected: pipeline.PipelineRebalancing, + expect: expect{status: pipeline.PipelineRebalancing, msg: "2/2 rebalancing "}, }, { name: "rebalancing some", @@ -252,7 +258,24 @@ func TestGetPipelineStatus(t *testing.T) { "a": pipeline.PipelineReady, "b": pipeline.PipelineRebalancing, }, - expected: pipeline.PipelineRebalancing, + expect: expect{status: pipeline.PipelineRebalancing, msg: "1/2 ready 1/2 rebalancing "}, + }, + { + name: "delete failed", + op: chainer.PipelineUpdateMessage_Delete, + statuses: map[string]pipeline.PipelineStatus{ + "a": pipeline.PipelineFailedTerminating, + }, + expect: expect{status: pipeline.PipelineFailedTerminating, msg: "1/1 failed terminating"}, + }, + { + name: "delete failed and pipeline failed to create", + op: chainer.PipelineUpdateMessage_Delete, + statuses: map[string]pipeline.PipelineStatus{ + "a": pipeline.PipelineFailedTerminating, + "b": pipeline.PipelineFailed, + }, + expect: expect{status: pipeline.PipelineFailedTerminating, msg: "1/2 failed 1/2 failed terminating"}, }, } @@ -275,8 +298,9 @@ 
func TestGetPipelineStatus(t *testing.T) { }, } - status, _ := GetPipelineStatus(cr, "p1", msg) - g.Expect(status).To(gomega.Equal(test.expected)) + status, outputMsg := GetPipelineStatus(cr, "p1", msg) + g.Expect(status).To(gomega.Equal(test.expect.status)) + g.Expect(outputMsg).To(gomega.Equal(test.expect.msg)) }) } } diff --git a/scheduler/pkg/kafka/dataflow/server.go b/scheduler/pkg/kafka/dataflow/server.go index 1b771763f7..f710657a89 100644 --- a/scheduler/pkg/kafka/dataflow/server.go +++ b/scheduler/pkg/kafka/dataflow/server.go @@ -15,6 +15,7 @@ import ( "fmt" "net" "sync" + "time" log "github.com/sirupsen/logrus" "google.golang.org/grpc" @@ -54,6 +55,18 @@ type ChainerServer struct { currentScalingConfig scaling_config.ScalingConfig done chan struct{} grpcServer *grpc.Server + muFailedCreate sync.Mutex + // TODO we should update PipelineHandler to store state for dataflow-engine, as we do for model-gw. That way we + // won't have to use failedCreatePipelines and failedDeletePipelines. + // failedCreatePipelines keyed off pipeline UID + version + failedCreatePipelines map[string]pipeline.PipelineVersion + muFailedDelete sync.Mutex + // failedDeletePipelines keyed off pipeline UID + version + failedDeletePipelines map[string]pipeline.PipelineVersion + muRetriedFailedPipelines sync.Mutex + // retriedFailedPipelines keyed off pipeline UID + version. Tracks how many attempts have been made to create/terminate + // a pipeline + retriedFailedPipelines map[string]uint chainer.UnimplementedChainerServer health.UnimplementedHealthCheckServiceServer } @@ -79,16 +92,19 @@ func NewChainerServer( return nil, err } c := &ChainerServer{ - logger: logger.WithField("source", "dataflow"), - streams: make(map[string]*ChainerSubscription), - eventHub: eventHub, - pipelineHandler: pipelineHandler, - topicNamer: topicNamer, - loadBalancer: loadBalancer, - conflictResolutioner: conflictResolutioner, - chainerMutex: sync.Map{}, - scalingConfigUpdates: make(chan scaling_config.ScalingConfig), - done: make(chan struct{}), + logger: logger.WithField("source", "dataflow"), + streams: make(map[string]*ChainerSubscription), + eventHub: eventHub, + pipelineHandler: pipelineHandler, + topicNamer: topicNamer, + loadBalancer: loadBalancer, + conflictResolutioner: conflictResolutioner, + chainerMutex: sync.Map{}, + scalingConfigUpdates: make(chan scaling_config.ScalingConfig), + done: make(chan struct{}), + failedCreatePipelines: make(map[string]pipeline.PipelineVersion, 0), + failedDeletePipelines: make(map[string]pipeline.PipelineVersion, 0), + retriedFailedPipelines: make(map[string]uint, 0), } eventHub.RegisterPipelineEventHandler( @@ -123,12 +139,15 @@ func (c *ChainerServer) Stop() { c.StopSendPipelineEvents() } -func (c *ChainerServer) StartGrpcServer(agentPort uint) error { +func (c *ChainerServer) StartGrpcServer(ctx context.Context, pollerFailedCreatePipelines, pollerFailedDeletePipelines time.Duration, maxRetry, agentPort uint) error { lis, err := net.Listen("tcp", fmt.Sprintf(":%d", agentPort)) if err != nil { log.Fatalf("failed to create listener: %v", err) } + go c.pollerFailedTerminatingPipelines(ctx, pollerFailedDeletePipelines, maxRetry) + go c.pollerFailedCreatingPipelines(ctx, pollerFailedCreatePipelines, maxRetry) + kaep := util.GetServerKeepAliveEnforcementPolicy() var grpcOptions []grpc.ServerOption @@ -143,25 +162,66 @@ func (c *ChainerServer) StartGrpcServer(agentPort uint) error { return grpcServer.Serve(lis) } -func (c *ChainerServer) PipelineUpdateEvent(ctx context.Context, message 
*chainer.PipelineUpdateStatusMessage) (*chainer.PipelineUpdateStatusResponse, error) { +func (c *ChainerServer) mkPipelineRetryKey(uid string, version uint32) string { + return fmt.Sprintf("%s_%d", uid, version) +} + +func (c *ChainerServer) storeFailedCreate(m *chainer.PipelineUpdateMessage) { + c.muFailedCreate.Lock() + defer c.muFailedCreate.Unlock() + c.failedCreatePipelines[c.mkPipelineRetryKey(m.Uid, m.Version)] = pipeline.PipelineVersion{ + Name: m.Pipeline, + Version: m.Version, + UID: m.Uid, + } +} + +func (c *ChainerServer) storeFailedDelete(m *chainer.PipelineUpdateMessage) { + c.muFailedDelete.Lock() + defer c.muFailedDelete.Unlock() + c.failedDeletePipelines[c.mkPipelineRetryKey(m.Uid, m.Version)] = pipeline.PipelineVersion{ + Name: m.Pipeline, + Version: m.Version, + UID: m.Uid, + } +} + +func (c *ChainerServer) resetPipelineRetryCount(msg *chainer.PipelineUpdateMessage) { + c.muRetriedFailedPipelines.Lock() + defer c.muRetriedFailedPipelines.Unlock() + c.retriedFailedPipelines[c.mkPipelineRetryKey(msg.Uid, msg.Version)] = 0 +} + +func (c *ChainerServer) removePipelineRetryCount(msg *chainer.PipelineUpdateMessage) { + c.muRetriedFailedPipelines.Lock() + defer c.muRetriedFailedPipelines.Unlock() + delete(c.retriedFailedPipelines, c.mkPipelineRetryKey(msg.Uid, msg.Version)) +} + +func (c *ChainerServer) PipelineUpdateEvent(_ context.Context, message *chainer.PipelineUpdateStatusMessage) (*chainer.PipelineUpdateStatusResponse, error) { c.mu.Lock() defer c.mu.Unlock() logger := c.logger.WithField("func", "PipelineUpdateEvent") var statusVal pipeline.PipelineStatus + switch message.Update.Op { // create, delete, rebalance operation from the scheduler case chainer.PipelineUpdateMessage_Create: if message.Success { + c.resetPipelineRetryCount(message.Update) statusVal = pipeline.PipelineReady } else { + c.storeFailedCreate(message.Update) statusVal = pipeline.PipelineFailed } case chainer.PipelineUpdateMessage_Delete: if message.Success { + c.removePipelineRetryCount(message.Update) statusVal = pipeline.PipelineTerminated } else { - statusVal = pipeline.PipelineFailed + c.storeFailedDelete(message.Update) + statusVal = pipeline.PipelineFailedTerminating } // internal rebalancing operation case chainer.PipelineUpdateMessage_Rebalance: @@ -445,7 +505,6 @@ func (c *ChainerServer) sendPipelineMsgToSelectedServers(msg *chainer.PipelineUp for _, serverId := range servers { if subscription, ok := c.streams[serverId]; ok { - select { case <-subscription.stream.Context().Done(): logger.WithError(subscription.stream.Context().Err()).Errorf("Failed to send msg to pipeline %s - stream ctx cancelled", pv.String()) @@ -469,84 +528,238 @@ func contains(slice []string, val string) bool { return false } -func (c *ChainerServer) rebalance() { - c.mu.Lock() - defer c.mu.Unlock() +func (c *ChainerServer) pollerFailedTerminatingPipelines(ctx context.Context, tick time.Duration, maxRetry uint) { + ticker := time.NewTicker(tick) + defer ticker.Stop() + logger := c.logger.WithField("func", "pollerFailedTerminatingPipelines") + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + // check for any pipelines which failed to create and retry + logger.Debug("Checking for pipelines which failed to terminate") + c.muFailedDelete.Lock() + if len(c.failedDeletePipelines) == 0 { + c.muFailedDelete.Unlock() + logger.Debug("No pipelines found that failed to terminate") + continue + } + + c.mu.Lock() + for _, p := range c.failedDeletePipelines { + key := c.mkPipelineRetryKey(p.UID, p.Version) + 
c.muRetriedFailedPipelines.Lock() + c.retriedFailedPipelines[key]++ + + if c.retriedFailedPipelines[key] > maxRetry { + c.muRetriedFailedPipelines.Unlock() + logger.Warnf("Failed to terminate pipeline %s, reached max retries", p.Name) + delete(c.failedDeletePipelines, key) + continue + } + c.muRetriedFailedPipelines.Unlock() + + logger.Debugf("Attempting to terminate pipeline %s which failed to terminate", p.Name) + pv, err := c.pipelineHandler.GetPipelineVersion(p.Name, p.Version, p.UID) + if err != nil { + notFound := &pipeline.PipelineNotFoundErr{} + uidMisMatch := &pipeline.PipelineVersionUidMismatchErr{} + verNotFound := &pipeline.PipelineVersionNotFoundErr{} + + if errors.As(err, &notFound) || errors.As(err, &uidMisMatch) || errors.As(err, &verNotFound) { + delete(c.failedDeletePipelines, key) + logger.Debugf("Pipeline %s not found, removing from poller list", p.Name) + continue + } + + logger.WithError(err).Errorf("Failed to get pipeline %s", p.Name) + continue + } + logger.Debugf("Found pipeline %s, attempting to terminate", p.Name) + + // note we are forcing keeping topics here, so there may be unwanted orphaned topics left in Kafka even + // though the customer deleted the pipeline and set the pipeline config to delete topics. This is because we don't + // know whether the termination request was initiated by the customer (i.e. a deleted Pipeline CR) or by a rebalance + // of pipelines across dataflow-engine replicas (in which case we force keeping topics) + c.terminatePipeline(pv, true) + // remove from list as we've successfully retried termination + delete(c.failedDeletePipelines, key) + } + + c.mu.Unlock() + c.muFailedDelete.Unlock() + } + } +} + +func (c *ChainerServer) pollerFailedCreatingPipelines(ctx context.Context, tick time.Duration, maxRetry uint) { + ticker := time.NewTicker(tick) + defer ticker.Stop() + + logger := c.logger.WithField("func", "pollerFailedCreatingPipelines") + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + // check for any pipelines which failed to create and retry + logger.Debug("Checking for pipelines which failed to create") + c.muFailedCreate.Lock() + if len(c.failedCreatePipelines) == 0 { + c.muFailedCreate.Unlock() + logger.Debug("No pipelines found that failed to create") + continue + } + + c.mu.Lock() + for _, p := range c.failedCreatePipelines { + key := c.mkPipelineRetryKey(p.UID, p.Version) + c.muRetriedFailedPipelines.Lock() + c.retriedFailedPipelines[key]++ + logger.Debugf("Attempting to create failed pipeline %s", p.Name) + + if c.retriedFailedPipelines[key] > maxRetry { + c.muRetriedFailedPipelines.Unlock() + logger.Warnf("Failed to create pipeline %s, reached max retries", p.Name) + delete(c.failedCreatePipelines, key) + continue + } + c.muRetriedFailedPipelines.Unlock() + + // we only want to create this pipeline if it's the latest version; it could have failed to create + // and the customer has since updated the pipeline and created it successfully, in which case we'd end up + // overwriting the new pipeline + if isLatest, err := c.pipelineHandler.IsLatestVersion(p.Name, p.Version, p.UID); err != nil { + logger.WithError(err).Errorf("Failed checking pipeline %s is latest version before creating", p.Name) + delete(c.failedCreatePipelines, key) + continue + } else if !isLatest { + logger.Debugf("Pipeline %s not the latest, ignoring", p.Name) + delete(c.failedCreatePipelines, key) + continue + } + + if err := c.rebalancePipeline(p.Name, p.Version, p.UID); err != nil { + notFound := &pipeline.PipelineNotFoundErr{} + uidMisMatch := 
&pipeline.PipelineVersionUidMismatchErr{} + verNotFound := &pipeline.PipelineVersionNotFoundErr{} + + if errors.As(err, &notFound) || errors.As(err, &uidMisMatch) || errors.As(err, &verNotFound) { + delete(c.failedCreatePipelines, key) + logger.Debugf("Pipeline %s not found, removing from poller list", p.Name) + continue + } + + // don't remove from map as we want to retry on next tick + logger.WithError(err).Errorf("Failed to create pipeline %s", p.Name) + continue + } + + // remove from list as we've successfully retried creating + delete(c.failedCreatePipelines, key) + } + + c.mu.Unlock() + c.muFailedCreate.Unlock() + } + } +} + +func (c *ChainerServer) rebalancePipeline(pipelineName string, pipelineVersion uint32, pipelineUID string) error { logger := c.logger.WithField("func", "rebalance") - // note that we are not retrying PipelineFailed pipelines, consider adding this - evts := c.pipelineHandler.GetAllRunningPipelineVersions() - for _, event := range evts { - pv, err := c.pipelineHandler.GetPipelineVersion(event.PipelineName, event.PipelineVersion, event.UID) - if err != nil { - logger.WithError(err).Errorf("Failed to get pipeline from event %s", event.String()) - continue + pv, err := c.pipelineHandler.GetPipelineVersion(pipelineName, pipelineVersion, pipelineUID) + if err != nil { + logger.WithError(err).Errorf("Failed to get pipeline %s version %d UID %s", pipelineName, pipelineVersion, pipelineUID) + return err + } + + c.logger.Debugf("Rebalancing pipeline %s:%d with state %s", pipelineName, pipelineVersion, pv.State.Status.String()) + if len(c.streams) == 0 { + pipelineState := pipeline.PipelineCreate + // if no dataflow engines available then we think we can terminate pipelines. + if pv.State.Status == pipeline.PipelineTerminating { + pipelineState = pipeline.PipelineTerminated } - c.logger.Debugf("Rebalancing pipeline %s:%d with state %s", event.PipelineName, event.PipelineVersion, pv.State.Status.String()) - if len(c.streams) == 0 { - pipelineState := pipeline.PipelineCreate - // if no dataflow engines available then we think we can terminate pipelines. 
+ + c.logger.Debugf("No dataflow engines available to handle pipeline %s, setting state to %s", pv.String(), pipelineState.String()) + if err := c.pipelineHandler.SetPipelineState( + pv.Name, + pv.Version, + pv.UID, + pipelineState, + "no dataflow engines available to handle pipeline", + util.SourceChainerServer, + ); err != nil { + logger.WithError(err).Errorf("Failed to set pipeline state to creating for %s", pv.String()) + return fmt.Errorf("failed setting pipeline state: %w", err) + } + + return nil + } + + var msg *chainer.PipelineUpdateMessage + servers := c.loadBalancer.GetServersForKey(pv.UID) + cr.CreateNewPipelineIteration(c.conflictResolutioner, pv.Name, servers) + + var errs error + for server, subscription := range c.streams { + if contains(servers, server) { + // we do not need to set pipeline state to creating if it is already in terminating state, and we need to delete it if pv.State.Status == pipeline.PipelineTerminating { - pipelineState = pipeline.PipelineTerminated + msg = c.createPipelineDeletionMessage(pv, false) + } else { + msg = c.createPipelineCreationMessage(pv) + pipelineState := pipeline.PipelineCreating + if err := c.pipelineHandler.SetPipelineState(pv.Name, pv.Version, pv.UID, pipelineState, "Rebalance", util.SourceChainerServer); err != nil { + logger.WithError(err).Errorf("Failed to set pipeline state to creating for %s", pv.String()) + } } - c.logger.Debugf("No dataflow engines available to handle pipeline %s, setting state to %s", pv.String(), pipelineState.String()) - if err := c.pipelineHandler.SetPipelineState( - pv.Name, - pv.Version, - pv.UID, - pipelineState, - "no dataflow engines available to handle pipeline", - util.SourceChainerServer, - ); err != nil { - logger.WithError(err).Errorf("Failed to set pipeline state to creating for %s", pv.String()) + msg.Timestamp = c.conflictResolutioner.GetTimestamp(pv.Name) + + select { + case <-subscription.stream.Context().Done(): + err := subscription.stream.Context().Err() + logger.WithError(err).Errorf("Failed to send create rebalance msg to pipeline %s stream ctx cancelled", pv.String()) + errs = errors.Join(errs, err) + default: + if err := subscription.stream.Send(msg); err != nil { + logger.WithError(err).Errorf("Failed to send create rebalance msg to pipeline %s", pv.String()) + errs = errors.Join(errs, err) + } } - } else { - var msg *chainer.PipelineUpdateMessage - servers := c.loadBalancer.GetServersForKey(pv.UID) - cr.CreateNewPipelineIteration(c.conflictResolutioner, pv.Name, servers) - - for server, subscription := range c.streams { - if contains(servers, server) { - // we do not need to set pipeline state to creating if it is already in terminating state, and we need to delete it - if pv.State.Status == pipeline.PipelineTerminating { - msg = c.createPipelineDeletionMessage(pv, false) - } else { - msg = c.createPipelineCreationMessage(pv) - pipelineState := pipeline.PipelineCreating - if err := c.pipelineHandler.SetPipelineState(pv.Name, pv.Version, pv.UID, pipelineState, "Rebalance", util.SourceChainerServer); err != nil { - logger.WithError(err).Errorf("Failed to set pipeline state to creating for %s", pv.String()) - } - } - msg.Timestamp = c.conflictResolutioner.GetTimestamp(pv.Name) - - select { - case <-subscription.stream.Context().Done(): - err := subscription.stream.Context().Err() - logger.WithError(err).Errorf("Failed to send create rebalance msg to pipeline %s stream ctx cancelled", pv.String()) - default: - if err := subscription.stream.Send(msg); err != nil { - 
logger.WithError(err).Errorf("Failed to send create rebalance msg to pipeline %s", pv.String()) - } - } + continue + } - } else { - msg = c.createPipelineDeletionMessage(pv, true) - msg.Timestamp = c.conflictResolutioner.GetTimestamp(pv.Name) - - select { - case <-subscription.stream.Context().Done(): - err := subscription.stream.Context().Err() - logger.WithError(err).Errorf("Failed to send delete rebalance msg to pipeline %s stream ctx cancelled", pv.String()) - default: - if err := subscription.stream.Send(msg); err != nil { - logger.WithError(err).Errorf("Failed to send delete rebalance msg to pipeline %s", pv.String()) - } - } + msg = c.createPipelineDeletionMessage(pv, true) + msg.Timestamp = c.conflictResolutioner.GetTimestamp(pv.Name) - } + select { + case <-subscription.stream.Context().Done(): + err := subscription.stream.Context().Err() + logger.WithError(err).Errorf("Failed to send delete rebalance msg to pipeline %s stream ctx cancelled", pv.String()) + errs = errors.Join(errs, err) + default: + if err := subscription.stream.Send(msg); err != nil { + logger.WithError(err).Errorf("Failed to send delete rebalance msg to pipeline %s", pv.String()) + errs = errors.Join(errs, err) } } } + + return errs +} + +func (c *ChainerServer) rebalance() { + c.mu.Lock() + defer c.mu.Unlock() + + for _, event := range c.pipelineHandler.GetAllRunningPipelineVersions() { + if err := c.rebalancePipeline(event.PipelineName, event.PipelineVersion, event.UID); err != nil { + c.logger.WithError(err).Errorf("Failed to rebalance pipeline %s", event.PipelineName) + } + } } func (c *ChainerServer) handleScalingConfigChanges() { @@ -638,12 +851,16 @@ func (c *ChainerServer) handlePipelineEvent(event coordinator.PipelineEventMsg) c.sendPipelineMsgToSelectedServers(msg, pv) case pipeline.PipelineTerminate: - err := c.pipelineHandler.SetPipelineState(pv.Name, pv.Version, pv.UID, pipeline.PipelineTerminating, "", util.SourceChainerServer) - if err != nil { - logger.WithError(err).Errorf("Failed to set pipeline state to terminating for %s", pv.String()) - } - msg := c.createPipelineDeletionMessage(pv, event.KeepTopics) // note pv is a copy and does not include the new change to terminating state - c.sendPipelineMsgToSelectedServers(msg, pv) + c.terminatePipeline(pv, event.KeepTopics) } }() } + +func (c *ChainerServer) terminatePipeline(pv *pipeline.PipelineVersion, keepTopics bool) { + err := c.pipelineHandler.SetPipelineState(pv.Name, pv.Version, pv.UID, pipeline.PipelineTerminating, "", util.SourceChainerServer) + if err != nil { + c.logger.WithError(err).Errorf("Failed to set pipeline state to terminating for %s", pv.Name) + } + msg := c.createPipelineDeletionMessage(pv, keepTopics) // note pv is a copy and does not include the new change to terminating state + c.sendPipelineMsgToSelectedServers(msg, pv) +} diff --git a/scheduler/pkg/kafka/dataflow/server_test.go b/scheduler/pkg/kafka/dataflow/server_test.go index c717da61b2..92312a0752 100644 --- a/scheduler/pkg/kafka/dataflow/server_test.go +++ b/scheduler/pkg/kafka/dataflow/server_test.go @@ -11,6 +11,7 @@ package dataflow import ( "context" + "errors" "fmt" "os" "sync" @@ -19,6 +20,7 @@ import ( . 
"github.com/onsi/gomega" log "github.com/sirupsen/logrus" + "go.uber.org/mock/gomock" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -27,13 +29,821 @@ import ( kafka_config "github.com/seldonio/seldon-core/components/kafka/v2/pkg/config" "github.com/seldonio/seldon-core/scheduler/v2/pkg/coordinator" - testing_utils "github.com/seldonio/seldon-core/scheduler/v2/pkg/internal/testing_utils" + "github.com/seldonio/seldon-core/scheduler/v2/pkg/internal/testing_utils" "github.com/seldonio/seldon-core/scheduler/v2/pkg/kafka" + cr "github.com/seldonio/seldon-core/scheduler/v2/pkg/kafka/conflict-resolution" "github.com/seldonio/seldon-core/scheduler/v2/pkg/store" "github.com/seldonio/seldon-core/scheduler/v2/pkg/store/pipeline" + "github.com/seldonio/seldon-core/scheduler/v2/pkg/store/pipeline/mock" "github.com/seldonio/seldon-core/scheduler/v2/pkg/util" + mock2 "github.com/seldonio/seldon-core/scheduler/v2/pkg/util/mock" ) +func TestPollerFailedTerminatingPipelines(t *testing.T) { + tests := []struct { + name string + failedPipelines map[string]pipeline.PipelineVersion + needsLoadBalancer bool + needsConflictResolver bool + setupMocks func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) + contextTimeout time.Duration + tickDuration time.Duration + validateResult func(g *WithT, server *ChainerServer) + expectGomegaWithT bool + maxRetry uint + }{ + { + name: "should return when context is cancelled", + failedPipelines: make(map[string]pipeline.PipelineVersion), + needsLoadBalancer: false, + needsConflictResolver: false, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) { + // No expectations - context cancelled before first tick + }, + contextTimeout: 0, // Cancel immediately + tickDuration: 100 * time.Millisecond, + expectGomegaWithT: false, + }, + { + name: "should skip processing when no failed pipelines exist", + failedPipelines: make(map[string]pipeline.PipelineVersion), + needsLoadBalancer: false, + needsConflictResolver: false, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) { + // No expectations - empty map means no processing + }, + contextTimeout: 150 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + expectGomegaWithT: false, + }, + { + name: "should retry terminating failed pipeline successfully", + maxRetry: 1, + failedPipelines: map[string]pipeline.PipelineVersion{ + "test-uid-123_1": { + Name: "test-pipeline", + Version: 1, + UID: "test-uid-123", + }, + }, + needsLoadBalancer: true, + needsConflictResolver: true, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) { + mockLoadBalancer.EXPECT(). + GetServersForKey("test-uid-123"). + Return([]string{}) + + mockPipelineHandler.EXPECT(). + GetPipelineVersion("test-pipeline", uint32(1), "test-uid-123"). + Return(&pipeline.PipelineVersion{ + Name: "test-pipeline", + Version: 1, + UID: "test-uid-123", + State: &pipeline.PipelineState{ + Status: pipeline.PipelineTerminating, + }, + }, nil) + + mockPipelineHandler.EXPECT(). + SetPipelineState("test-pipeline", uint32(1), "test-uid-123", pipeline.PipelineTerminating, gomock.Any(), gomock.Any()). 
+ Return(nil) + }, + contextTimeout: 150 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + expectGomegaWithT: false, + }, + { + name: "should remove pipeline from failed list when not found", + maxRetry: 1, + failedPipelines: map[string]pipeline.PipelineVersion{ + "test-uid-123_1": { + Name: "test-pipeline", + Version: 1, + UID: "test-uid-123", + }, + }, + needsLoadBalancer: false, + needsConflictResolver: false, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) { + mockPipelineHandler.EXPECT(). + GetPipelineVersion("test-pipeline", uint32(1), "test-uid-123"). + Return(nil, &pipeline.PipelineNotFoundErr{}) + }, + contextTimeout: 150 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + validateResult: func(g *WithT, server *ChainerServer) { + g.Expect(server.failedDeletePipelines).ToNot(HaveKey("test-uid-123_1")) + }, + expectGomegaWithT: true, + }, + { + name: "should remove pipeline from failed list on UID mismatch", + maxRetry: 1, + failedPipelines: map[string]pipeline.PipelineVersion{ + "test-uid-123_1": { + Name: "test-pipeline", + Version: 1, + UID: "test-uid-123", + }, + }, + needsLoadBalancer: false, + needsConflictResolver: false, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) { + mockPipelineHandler.EXPECT(). + GetPipelineVersion("test-pipeline", uint32(1), "test-uid-123"). + Return(nil, &pipeline.PipelineVersionUidMismatchErr{}) + }, + contextTimeout: 150 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + validateResult: func(g *WithT, server *ChainerServer) { + g.Expect(server.failedDeletePipelines).ToNot(HaveKey("test-uid-123_1")) + }, + expectGomegaWithT: true, + }, + { + name: "should remove pipeline from failed list on version not found", + maxRetry: 1, + failedPipelines: map[string]pipeline.PipelineVersion{ + "test-uid-123_1": { + Name: "test-pipeline", + Version: 1, + UID: "test-uid-123", + }, + }, + needsLoadBalancer: false, + needsConflictResolver: false, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) { + mockPipelineHandler.EXPECT(). + GetPipelineVersion("test-pipeline", uint32(1), "test-uid-123"). + Return(nil, &pipeline.PipelineVersionNotFoundErr{}) + }, + contextTimeout: 150 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + validateResult: func(g *WithT, server *ChainerServer) { + g.Expect(server.failedDeletePipelines).ToNot(HaveKey("test-uid-123_1")) + }, + expectGomegaWithT: true, + }, + { + name: "should continue processing on generic error", + failedPipelines: map[string]pipeline.PipelineVersion{ + "test-uid-123_1": { + Name: "test-pipeline", + Version: 1, + UID: "test-uid-123", + }, + }, + needsLoadBalancer: false, + needsConflictResolver: false, + maxRetry: 100, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) { + mockPipelineHandler.EXPECT(). + GetPipelineVersion("test-pipeline", uint32(1), "test-uid-123"). + Return(nil, errors.New("generic error")). 
+ MinTimes(1) + }, + contextTimeout: 150 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + validateResult: func(g *WithT, server *ChainerServer) { + // Pipeline should still be in failed list + g.Expect(server.failedDeletePipelines).To(HaveKey("test-uid-123_1")) + }, + expectGomegaWithT: true, + }, + { + name: "should remove from retry list, max retry limit exceeded", + failedPipelines: map[string]pipeline.PipelineVersion{ + "test-uid-123_1": { + Name: "test-pipeline", + Version: 1, + UID: "test-uid-123", + }, + }, + needsLoadBalancer: false, + needsConflictResolver: false, + maxRetry: 1, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) { + mockPipelineHandler.EXPECT(). + GetPipelineVersion("test-pipeline", uint32(1), "test-uid-123"). + Return(nil, errors.New("generic error")). + MinTimes(1) + }, + contextTimeout: 150 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + validateResult: func(g *WithT, server *ChainerServer) { + // Pipeline should have been removed from the failed list after exceeding max retries + g.Expect(server.failedDeletePipelines).ToNot(HaveKey("test-uid-123_1")) + }, + expectGomegaWithT: true, + }, + { + name: "should process multiple failed pipelines", + maxRetry: 1, + failedPipelines: map[string]pipeline.PipelineVersion{ + "uid-1_1": { + Name: "pipeline-1", + Version: 1, + UID: "uid-1", + }, + "uid-2_1": { + Name: "pipeline-2", + Version: 1, + UID: "uid-2", + }, + }, + needsLoadBalancer: true, + needsConflictResolver: true, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) { + mockPipelineHandler.EXPECT(). + GetPipelineVersion("pipeline-1", uint32(1), "uid-1"). + Return(&pipeline.PipelineVersion{ + Name: "pipeline-1", + Version: 1, + UID: "uid-1", + State: &pipeline.PipelineState{ + Status: pipeline.PipelineTerminating, + }, + }, nil) + + mockPipelineHandler.EXPECT(). + GetPipelineVersion("pipeline-2", uint32(1), "uid-2"). + Return(&pipeline.PipelineVersion{ + Name: "pipeline-2", + Version: 1, + UID: "uid-2", + State: &pipeline.PipelineState{ + Status: pipeline.PipelineTerminating, + }, + }, nil) + + mockPipelineHandler.EXPECT(). + SetPipelineState(gomock.Any(), gomock.Any(), gomock.Any(), pipeline.PipelineTerminating, gomock.Any(), gomock.Any()). + Return(nil). + Times(2) + + mockLoadBalancer.EXPECT().GetServersForKey("uid-1").Return([]string{}) + mockLoadBalancer.EXPECT().GetServersForKey("uid-2").Return([]string{}) + }, + contextTimeout: 150 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + expectGomegaWithT: false, + }, + { + name: "should tick multiple times before context cancellation", + failedPipelines: map[string]pipeline.PipelineVersion{ + "test-uid-123_1": { + Name: "test-pipeline", + Version: 1, + UID: "test-uid-123", + }, + }, + needsLoadBalancer: false, + needsConflictResolver: false, + maxRetry: 100, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) { + // Expect at least 2 calls (multiple ticks) + mockPipelineHandler.EXPECT(). + GetPipelineVersion("test-pipeline", uint32(1), "test-uid-123"). + Return(nil, errors.New("some error")). 
+ MinTimes(2) + }, + contextTimeout: 250 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + expectGomegaWithT: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var g *WithT + if tt.expectGomegaWithT { + g = NewGomegaWithT(t) + } + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPipelineHandler := mock.NewMockPipelineHandler(ctrl) + var mockLoadBalancer *mock2.MockLoadBalancer + if tt.needsLoadBalancer { + mockLoadBalancer = mock2.NewMockLoadBalancer(ctrl) + } + + // Setup mocks for this test case + if tt.setupMocks != nil { + tt.setupMocks(mockPipelineHandler, mockLoadBalancer, tt.failedPipelines) + } + + server := &ChainerServer{ + logger: log.New(), + pipelineHandler: mockPipelineHandler, + failedDeletePipelines: tt.failedPipelines, + streams: make(map[string]*ChainerSubscription), + retriedFailedPipelines: make(map[string]uint), + } + + if tt.needsLoadBalancer { + server.loadBalancer = mockLoadBalancer + } + + if tt.needsConflictResolver { + server.conflictResolutioner = cr.NewConflictResolution[pipeline.PipelineStatus](log.New()) + } + + var ctx context.Context + var cancel context.CancelFunc + + if tt.contextTimeout == 0 { + // Cancel immediately + ctx, cancel = context.WithCancel(context.Background()) + cancel() + } else { + ctx, cancel = context.WithTimeout(context.Background(), tt.contextTimeout) + defer cancel() + } + + done := make(chan bool) + go func() { + server.pollerFailedTerminatingPipelines(ctx, tt.tickDuration, tt.maxRetry) + done <- true + }() + + // Calculate appropriate timeout based on context timeout + testTimeout := tt.contextTimeout + 1*time.Second + if testTimeout < 1*time.Second { + testTimeout = 1 * time.Second + } + if testTimeout > 2*time.Second { + testTimeout = 2 * time.Second + } + + select { + case <-done: + // Test passed - function returned as expected + if tt.validateResult != nil { + tt.validateResult(g, server) + } + case <-time.After(testTimeout): + t.Fatal("pollerFailedTerminatingPipelines did not return in time") + } + }) + } +} + +func TestPollerFailedCreatingPipelines(t *testing.T) { + tests := []struct { + name string + failedPipelines map[string]pipeline.PipelineVersion + setupMocks func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) + contextTimeout time.Duration + tickDuration time.Duration + validateResult func(g *WithT, server *ChainerServer) + expectGomegaWithT bool + maxRetry uint + }{ + { + name: "should return when context is cancelled", + failedPipelines: make(map[string]pipeline.PipelineVersion), + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) { + // No expectations - context cancelled before first tick + }, + contextTimeout: 0, // Cancel immediately + tickDuration: 100 * time.Millisecond, + expectGomegaWithT: false, + }, + { + name: "should skip processing when no failed pipelines exist", + failedPipelines: make(map[string]pipeline.PipelineVersion), + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) { + // No expectations - empty map means no processing + }, + contextTimeout: 150 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + expectGomegaWithT: false, + }, + { + name: "failure - not latest pipeline, remove from list", + failedPipelines: 
map[string]pipeline.PipelineVersion{ + "test-uid-123_1": { + Name: "test-pipeline", + Version: 1, + UID: "test-uid-123", + }, + }, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) { + mockPipelineHandler.EXPECT().IsLatestVersion("test-pipeline", uint32(1), "test-uid-123").Return(false, nil) + }, + contextTimeout: 150 * time.Millisecond, + maxRetry: 1, + tickDuration: 50 * time.Millisecond, + validateResult: func(g *WithT, server *ChainerServer) { + g.Expect(server.failedCreatePipelines).ToNot(HaveKey("test-uid-123_1")) + }, + expectGomegaWithT: true, + }, + { + name: "failure - pipeline not latest, remove from list", + maxRetry: 1, + failedPipelines: map[string]pipeline.PipelineVersion{ + "test-uid-123_1": { + Name: "test-pipeline", + Version: 1, + UID: "test-uid-123", + }, + }, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) { + mockPipelineHandler.EXPECT().IsLatestVersion("test-pipeline", uint32(1), "test-uid-123").Return(false, errors.New("some error")) + }, + contextTimeout: 150 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + validateResult: func(g *WithT, server *ChainerServer) { + g.Expect(server.failedCreatePipelines).ToNot(HaveKey("test-uid-123_1")) + }, + expectGomegaWithT: true, + }, + { + name: "should retry creating failed pipeline and remove from list on success", + maxRetry: 1, + failedPipelines: map[string]pipeline.PipelineVersion{ + "test-uid-123_1": { + Name: "test-pipeline", + Version: 1, + UID: "test-uid-123", + }, + }, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) { + mockPipelineHandler.EXPECT().IsLatestVersion("test-pipeline", uint32(1), "test-uid-123").Return(true, nil) + + mockPipelineHandler.EXPECT(). + GetPipelineVersion("test-pipeline", uint32(1), "test-uid-123"). + Return(&pipeline.PipelineVersion{ + Name: "test-pipeline", + Version: 1, + UID: "test-uid-123", + State: &pipeline.PipelineState{ + Status: pipeline.PipelineCreating, + }, + }, nil) + + mockPipelineHandler.EXPECT(). + SetPipelineState("test-pipeline", uint32(1), "test-uid-123", pipeline.PipelineCreate, gomock.Any(), util.SourceChainerServer). + Return(nil) + }, + contextTimeout: 150 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + validateResult: func(g *WithT, server *ChainerServer) { + g.Expect(server.failedCreatePipelines).ToNot(HaveKey("test-uid-123_1")) + }, + expectGomegaWithT: true, + }, + { + name: "should remove pipeline from failed list when not found", + maxRetry: 1, + failedPipelines: map[string]pipeline.PipelineVersion{ + "test-uid-123_1": { + Name: "test-pipeline", + Version: 1, + UID: "test-uid-123", + }, + }, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) { + mockPipelineHandler.EXPECT().IsLatestVersion("test-pipeline", uint32(1), "test-uid-123").Return(true, nil) + + mockPipelineHandler.EXPECT(). + GetPipelineVersion("test-pipeline", uint32(1), "test-uid-123"). 
+ Return(nil, &pipeline.PipelineNotFoundErr{}) + }, + contextTimeout: 150 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + validateResult: func(g *WithT, server *ChainerServer) { + g.Expect(server.failedCreatePipelines).ToNot(HaveKey("test-uid-123_1")) + }, + expectGomegaWithT: true, + }, + { + name: "should remove pipeline from failed list on UID mismatch", + maxRetry: 1, + failedPipelines: map[string]pipeline.PipelineVersion{ + "test-uid-123_1": { + Name: "test-pipeline", + Version: 1, + UID: "test-uid-123", + }, + }, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) { + mockPipelineHandler.EXPECT().IsLatestVersion("test-pipeline", uint32(1), "test-uid-123").Return(true, nil) + + mockPipelineHandler.EXPECT(). + GetPipelineVersion("test-pipeline", uint32(1), "test-uid-123"). + Return(nil, &pipeline.PipelineVersionUidMismatchErr{}). + Times(1) + }, + contextTimeout: 150 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + validateResult: func(g *WithT, server *ChainerServer) { + g.Expect(server.failedCreatePipelines).ToNot(HaveKey("test-uid-123_1")) + }, + expectGomegaWithT: true, + }, + { + name: "should remove pipeline from failed list on version not found", + maxRetry: 1, + failedPipelines: map[string]pipeline.PipelineVersion{ + "test-uid-123_1": { + Name: "test-pipeline", + Version: 1, + UID: "test-uid-123", + }, + }, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) { + mockPipelineHandler.EXPECT().IsLatestVersion("test-pipeline", uint32(1), "test-uid-123").Return(true, nil) + + mockPipelineHandler.EXPECT(). + GetPipelineVersion("test-pipeline", uint32(1), "test-uid-123"). + Return(nil, &pipeline.PipelineVersionNotFoundErr{}). + Times(1) + }, + contextTimeout: 150 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + validateResult: func(g *WithT, server *ChainerServer) { + g.Expect(server.failedCreatePipelines).ToNot(HaveKey("test-uid-123_1")) + }, + expectGomegaWithT: true, + }, + { + name: "should keep pipeline in failed list on generic error from rebalancePipeline", + failedPipelines: map[string]pipeline.PipelineVersion{ + "test-uid-123_1": { + Name: "test-pipeline", + Version: 1, + UID: "test-uid-123", + }, + }, + maxRetry: 100, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) { + mockPipelineHandler.EXPECT().IsLatestVersion("test-pipeline", uint32(1), "test-uid-123").Return(true, nil).MinTimes(1) + + mockPipelineHandler.EXPECT(). + GetPipelineVersion("test-pipeline", uint32(1), "test-uid-123"). + Return(&pipeline.PipelineVersion{ + Name: "test-pipeline", + Version: 1, + UID: "test-uid-123", + State: &pipeline.PipelineState{ + Status: pipeline.PipelineCreating, + }, + }, nil). + MinTimes(1) + + mockPipelineHandler.EXPECT(). + SetPipelineState("test-pipeline", uint32(1), "test-uid-123", pipeline.PipelineCreate, gomock.Any(), util.SourceChainerServer). + Return(errors.New("failed to set state")). 
+ MinTimes(1) + }, + contextTimeout: 150 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + validateResult: func(g *WithT, server *ChainerServer) { + g.Expect(server.failedCreatePipelines).To(HaveKey("test-uid-123_1")) + }, + expectGomegaWithT: true, + }, + { + name: "max retry reached - remove from retry list", + failedPipelines: map[string]pipeline.PipelineVersion{ + "test-uid-123_1": { + Name: "test-pipeline", + Version: 1, + UID: "test-uid-123", + }, + }, + maxRetry: 1, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) { + mockPipelineHandler.EXPECT().IsLatestVersion("test-pipeline", uint32(1), "test-uid-123").Return(true, nil).MinTimes(1) + + mockPipelineHandler.EXPECT(). + GetPipelineVersion("test-pipeline", uint32(1), "test-uid-123"). + Return(nil, errors.New("database connection failed")). + MinTimes(1) + }, + contextTimeout: 150 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + validateResult: func(g *WithT, server *ChainerServer) { + g.Expect(server.failedCreatePipelines).ToNot(HaveKey("test-uid-123_1")) + }, + expectGomegaWithT: true, + }, + { + name: "should keep pipeline in failed list on GetPipelineVersion generic error", + failedPipelines: map[string]pipeline.PipelineVersion{ + "test-uid-123_1": { + Name: "test-pipeline", + Version: 1, + UID: "test-uid-123", + }, + }, + maxRetry: 100, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) { + mockPipelineHandler.EXPECT().IsLatestVersion("test-pipeline", uint32(1), "test-uid-123").Return(true, nil).MinTimes(1) + + mockPipelineHandler.EXPECT(). + GetPipelineVersion("test-pipeline", uint32(1), "test-uid-123"). + Return(nil, errors.New("database connection failed")). + MinTimes(1) + }, + contextTimeout: 150 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + validateResult: func(g *WithT, server *ChainerServer) { + g.Expect(server.failedCreatePipelines).To(HaveKey("test-uid-123_1")) + }, + expectGomegaWithT: true, + }, + { + name: "should process multiple failed pipelines", + maxRetry: 1, + failedPipelines: map[string]pipeline.PipelineVersion{ + "uid-1_1": { + Name: "pipeline-1", + Version: 1, + UID: "uid-1", + }, + "uid-2_1": { + Name: "pipeline-2", + Version: 1, + UID: "uid-2", + }, + }, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) { + mockPipelineHandler.EXPECT().IsLatestVersion("pipeline-1", uint32(1), "uid-1").Return(true, nil).MinTimes(1) + mockPipelineHandler.EXPECT().IsLatestVersion("pipeline-2", uint32(1), "uid-2").Return(true, nil).MinTimes(1) + + mockPipelineHandler.EXPECT(). + GetPipelineVersion("pipeline-1", uint32(1), "uid-1"). + Return(&pipeline.PipelineVersion{ + Name: "pipeline-1", + Version: 1, + UID: "uid-1", + State: &pipeline.PipelineState{ + Status: pipeline.PipelineCreating, + }, + }, nil) + + mockPipelineHandler.EXPECT(). + GetPipelineVersion("pipeline-2", uint32(1), "uid-2"). + Return(&pipeline.PipelineVersion{ + Name: "pipeline-2", + Version: 1, + UID: "uid-2", + State: &pipeline.PipelineState{ + Status: pipeline.PipelineCreating, + }, + }, nil) + + mockPipelineHandler.EXPECT(). + SetPipelineState(gomock.Any(), gomock.Any(), gomock.Any(), pipeline.PipelineCreate, gomock.Any(), util.SourceChainerServer). + Return(nil). 
+ Times(2) + }, + contextTimeout: 150 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + validateResult: func(g *WithT, server *ChainerServer) { + g.Expect(server.failedCreatePipelines).ToNot(HaveKey("uid-1_1")) + g.Expect(server.failedCreatePipelines).ToNot(HaveKey("uid-2_1")) + }, + expectGomegaWithT: true, + }, + { + name: "should process mixed success and failure scenarios", + failedPipelines: map[string]pipeline.PipelineVersion{ + "uid-success_1": { + Name: "pipeline-success", + Version: 1, + UID: "uid-success", + }, + "uid-fail_1": { + Name: "pipeline-fail", + Version: 1, + UID: "uid-fail", + }, + "uid-notfound_1": { + Name: "pipeline-notfound", + Version: 1, + UID: "uid-notfound", + }, + }, + maxRetry: 100, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, mockLoadBalancer *mock2.MockLoadBalancer, failedPipelines map[string]pipeline.PipelineVersion) { + mockPipelineHandler.EXPECT().IsLatestVersion("pipeline-success", uint32(1), "uid-success").Return(true, nil).MinTimes(1) + mockPipelineHandler.EXPECT().IsLatestVersion("pipeline-fail", uint32(1), "uid-fail").Return(true, nil).MinTimes(1) + mockPipelineHandler.EXPECT().IsLatestVersion("pipeline-notfound", uint32(1), "uid-notfound").Return(true, nil).MinTimes(1) + + // Success case + mockPipelineHandler.EXPECT(). + GetPipelineVersion("pipeline-success", uint32(1), "uid-success"). + Return(&pipeline.PipelineVersion{ + Name: "pipeline-success", + Version: 1, + UID: "uid-success", + State: &pipeline.PipelineState{Status: pipeline.PipelineCreating}, + }, nil). + MinTimes(1) + + mockPipelineHandler.EXPECT(). + SetPipelineState("pipeline-success", uint32(1), "uid-success", pipeline.PipelineCreate, gomock.Any(), util.SourceChainerServer). + Return(nil). + MinTimes(1) + + // Failure case + mockPipelineHandler.EXPECT(). + GetPipelineVersion("pipeline-fail", uint32(1), "uid-fail"). + Return(&pipeline.PipelineVersion{ + Name: "pipeline-fail", + Version: 1, + UID: "uid-fail", + State: &pipeline.PipelineState{Status: pipeline.PipelineCreating}, + }, nil). + MinTimes(1) + + mockPipelineHandler.EXPECT(). + SetPipelineState("pipeline-fail", uint32(1), "uid-fail", pipeline.PipelineCreate, gomock.Any(), util.SourceChainerServer). + Return(errors.New("state update failed")). + MinTimes(1) + + // Not found case + mockPipelineHandler.EXPECT(). + GetPipelineVersion("pipeline-notfound", uint32(1), "uid-notfound"). + Return(nil, &pipeline.PipelineNotFoundErr{}). 
+ MinTimes(1) + }, + contextTimeout: 150 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + validateResult: func(g *WithT, server *ChainerServer) { + // Success and notfound should be removed + g.Expect(server.failedCreatePipelines).ToNot(HaveKey("uid-success_1")) + g.Expect(server.failedCreatePipelines).ToNot(HaveKey("uid-notfound_1")) + // Failure should remain + g.Expect(server.failedCreatePipelines).To(HaveKey("uid-fail_1")) + }, + expectGomegaWithT: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var g *WithT + if tt.expectGomegaWithT { + g = NewGomegaWithT(t) + } + + ctrl := gomock.NewController(t) + + mockPipelineHandler := mock.NewMockPipelineHandler(ctrl) + mockLoadBalancer := mock2.NewMockLoadBalancer(ctrl) + + // Setup mocks for this test case + if tt.setupMocks != nil { + tt.setupMocks(mockPipelineHandler, mockLoadBalancer, tt.failedPipelines) + } + + server := &ChainerServer{ + logger: log.New(), + pipelineHandler: mockPipelineHandler, + loadBalancer: mockLoadBalancer, + failedCreatePipelines: tt.failedPipelines, + streams: make(map[string]*ChainerSubscription), + retriedFailedPipelines: make(map[string]uint), + } + + var ctx context.Context + var cancel context.CancelFunc + + if tt.contextTimeout == 0 { + // Cancel immediately + ctx, cancel = context.WithCancel(context.Background()) + cancel() + } else { + ctx, cancel = context.WithTimeout(context.Background(), tt.contextTimeout) + defer cancel() + } + + done := make(chan bool) + go func() { + server.pollerFailedCreatingPipelines(ctx, tt.tickDuration, tt.maxRetry) + done <- true + }() + + select { + case <-done: + // Test passed - function returned as expected + if tt.validateResult != nil { + tt.validateResult(g, server) + } + case <-time.After(tt.contextTimeout + 1*time.Second): + t.Fatal("pollerFailedCreatingPipelines did not return in time") + } + }) + } +} + func TestCreateTopicSources(t *testing.T) { g := NewGomegaWithT(t) @@ -768,7 +1578,7 @@ func TestPipelineSubscribe(t *testing.T) { t.Fatal(err) } go func() { - _ = s.StartGrpcServer(uint(port)) + _ = s.StartGrpcServer(context.Background(), time.Minute, time.Minute, 1, uint(port)) }() time.Sleep(100 * time.Millisecond) diff --git a/scheduler/pkg/kafka/gateway/infer.go b/scheduler/pkg/kafka/gateway/infer.go index 98ab655e80..3956054c37 100644 --- a/scheduler/pkg/kafka/gateway/infer.go +++ b/scheduler/pkg/kafka/gateway/infer.go @@ -51,8 +51,8 @@ const ( type InferKafkaHandler struct { logger log.FieldLogger mu sync.RWMutex - loadedModels map[string]bool - subscribedTopics map[string]bool + loadedModels map[string]struct{} + subscribedTopics map[string]struct{} workers []*InferWorker consumer *kafka.Consumer producer *kafka.Producer @@ -132,8 +132,8 @@ func NewInferKafkaHandler( done: make(chan bool), tracer: consumerConfig.TraceProvider.GetTraceProvider().Tracer("Worker"), topicNamer: topicNamer, - loadedModels: make(map[string]bool), - subscribedTopics: make(map[string]bool), + loadedModels: make(map[string]struct{}), + subscribedTopics: make(map[string]struct{}), shutdownComplete: make(chan struct{}), consumerConfig: consumerConfig, consumerName: consumerName, @@ -366,21 +366,22 @@ func (kc *InferKafkaHandler) ensureTopicsExist(topicNames []string) error { func (kc *InferKafkaHandler) AddModel(modelName string) error { kc.mu.Lock() defer kc.mu.Unlock() - kc.loadedModels[modelName] = true // create topics inputTopic := kc.topicNamer.GetModelTopicInputs(modelName) outputTopic := 
kc.topicNamer.GetModelTopicOutputs(modelName) if err := kc.createTopics([]string{inputTopic, outputTopic}); err != nil { - return err + return fmt.Errorf("failed to create topics for model %s: %w", modelName, err) } - kc.subscribedTopics[inputTopic] = true + kc.subscribedTopics[inputTopic] = struct{}{} err := kc.subscribeTopics() if err != nil { - kc.logger.WithError(err).Errorf("failed to subscribe to topics") - return nil + kc.logger.WithError(err).Errorf("Failed to subscribe to topics") + return fmt.Errorf("failed to subscribe to topics: %w", err) } + + kc.loadedModels[modelName] = struct{}{} return nil } diff --git a/scheduler/pkg/server/control_plane_test.go b/scheduler/pkg/server/control_plane_test.go index 53502703fa..bd49f2a699 100644 --- a/scheduler/pkg/server/control_plane_test.go +++ b/scheduler/pkg/server/control_plane_test.go @@ -167,7 +167,9 @@ func TestSubscribeControlPlane(t *testing.T) { t.Fatal(err) } - err = server.startServer(uint(port), false) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + err = server.startServer(ctx, uint(port), false, time.Minute, time.Minute, 1) if err != nil { t.Fatal(err) } diff --git a/scheduler/pkg/server/pipeline_status.go b/scheduler/pkg/server/pipeline_status.go index d0dcd14192..bc5b04bf20 100644 --- a/scheduler/pkg/server/pipeline_status.go +++ b/scheduler/pkg/server/pipeline_status.go @@ -11,11 +11,13 @@ package server import ( "context" + "fmt" + "time" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - chainer "github.com/seldonio/seldon-core/apis/go/v2/mlops/chainer" + "github.com/seldonio/seldon-core/apis/go/v2/mlops/chainer" pb "github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler" "github.com/seldonio/seldon-core/scheduler/v2/pkg/coordinator" @@ -28,7 +30,73 @@ const ( addPipelineStreamEventSource = "pipeline.store.addpipelinestream" ) -func (s *SchedulerServer) PipelineStatusEvent(ctx context.Context, message *chainer.PipelineUpdateStatusMessage) (*chainer.PipelineUpdateStatusResponse, error) { +// pollerRetryFailedCreatePipelines will retry creating pipelines on pipeline-gw which failed to load. +func (s *SchedulerServer) pollerRetryFailedCreatePipelines(ctx context.Context, tick time.Duration, maxRetry uint) { + s.pollerRetryFailedPipelines(ctx, tick, "pollerRetryFailedCreatePipelines", pipeline.PipelineFailed, "create", maxRetry) +} + +// pollerRetryFailedDeletePipelines will retry deleting pipelines on pipeline-gw which failed to terminate. 
+func (s *SchedulerServer) pollerRetryFailedDeletePipelines(ctx context.Context, tick time.Duration, maxRetry uint) { + s.pollerRetryFailedPipelines(ctx, tick, "pollerRetryFailedDeletePipelines", pipeline.PipelineFailedTerminating, "delete", maxRetry) +} + +func (s *SchedulerServer) pollerRetryFailedPipelines(ctx context.Context, tick time.Duration, funcName string, targetStatus pipeline.PipelineStatus, operation string, maxRetry uint) { + logger := s.logger.WithField("func", funcName) + ticker := time.NewTicker(tick) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + logger.Debugf("Poller retry failed %s pipelines on pipeline-gw", operation) + s.pipelineEventStream.mu.Lock() + pipelines := s.pipelineHandler.GetPipelinesPipelineGwStatus(targetStatus) + + if len(pipelines) == 0 { + logger.Debug("No failed pipelines found") + s.pipelineEventStream.mu.Unlock() + continue + } + + filteredPipelines := pipelines[:0] + s.muRetriedFailedPipelines.Lock() + for _, p := range pipelines { + key := s.mkPipelineKey(p.UID, p.PipelineVersion) + s.retriedFailedPipelines[key]++ + if s.retriedFailedPipelines[key] > maxRetry { + logger.Debugf("Retry failed %s pipeline %s, reached max retries", operation, p.PipelineName) + continue + } + filteredPipelines = append(filteredPipelines, p) + } + s.muRetriedFailedPipelines.Unlock() + + logger.WithField("pipelines", filteredPipelines).Debug("Found failed pipelines") + s.pipelineGwRebalancePipelines(filteredPipelines) + s.pipelineEventStream.mu.Unlock() + } + } +} + +func (s *SchedulerServer) mkPipelineKey(uid string, version uint32) string { + return fmt.Sprintf("%s_%d", uid, version) +} + +func (s *SchedulerServer) resetPipelineRetryCount(msg *chainer.PipelineUpdateMessage) { + s.muRetriedFailedPipelines.Lock() + defer s.muRetriedFailedPipelines.Unlock() + s.retriedFailedPipelines[s.mkPipelineKey(msg.Uid, msg.Version)] = 0 +} + +func (s *SchedulerServer) removePipelineRetryCount(msg *chainer.PipelineUpdateMessage) { + s.muRetriedFailedPipelines.Lock() + defer s.muRetriedFailedPipelines.Unlock() + delete(s.retriedFailedPipelines, s.mkPipelineKey(msg.Uid, msg.Version)) +} + +func (s *SchedulerServer) PipelineStatusEvent(_ context.Context, message *chainer.PipelineUpdateStatusMessage) (*chainer.PipelineUpdateStatusResponse, error) { s.pipelineEventStream.mu.Lock() defer s.pipelineEventStream.mu.Unlock() @@ -39,15 +107,17 @@ func (s *SchedulerServer) PipelineStatusEvent(ctx context.Context, message *chai switch message.Update.Op { case chainer.PipelineUpdateMessage_Create: if message.Success { + s.resetPipelineRetryCount(message.Update) statusVal = pipeline.PipelineReady } else { statusVal = pipeline.PipelineFailed } case chainer.PipelineUpdateMessage_Delete: if message.Success { + s.removePipelineRetryCount(message.Update) statusVal = pipeline.PipelineTerminated } else { - statusVal = pipeline.PipelineFailed + statusVal = pipeline.PipelineFailedTerminating } } @@ -197,7 +267,10 @@ func (s *SchedulerServer) createPipelineCreationMessage(pv *pipeline.PipelineVer func (s *SchedulerServer) pipelineGwRebalance() { s.pipelineEventStream.mu.Lock() defer s.pipelineEventStream.mu.Unlock() + s.pipelineGwRebalancePipelines(s.pipelineHandler.GetAllPipelineGwRunningPipelineVersions()) +} +func (s *SchedulerServer) pipelineGwRebalancePipelines(pipelines []coordinator.PipelineEventMsg) { // get only the pipeline gateway streams streams := []*PipelineSubscription{} for _, subscription := range s.pipelineEventStream.streams { @@ -206,8 +279,7 
@@ func (s *SchedulerServer) pipelineGwRebalance() { } } - evts := s.pipelineHandler.GetAllPipelineGwRunningPipelineVersions() - for _, event := range evts { + for _, event := range pipelines { pv, err := s.pipelineHandler.GetPipelineVersion(event.PipelineName, event.PipelineVersion, event.UID) if err != nil { s.logger.WithError(err).Errorf("Failed to get pipeline version for %s:%d (%s)", event.PipelineName, event.PipelineVersion, event.UID) @@ -235,7 +307,7 @@ func (s *SchedulerServer) pipelineGWRebalanceNoStreams(pv *pipeline.PipelineVers ) pipelineState := pipeline.PipelineCreate - if pv.State.PipelineGwStatus == pipeline.PipelineTerminating { + if pv.State.PipelineGwStatus == pipeline.PipelineTerminating || pv.State.PipelineGwStatus == pipeline.PipelineFailedTerminating { // since there are no streams, we can directly set the state to terminated pipelineState = pipeline.PipelineTerminated } @@ -301,8 +373,8 @@ func (s *SchedulerServer) pipelineGwRebalanceStreams( s.logger.Debug("pipeline-gateway replica contains pipeline, sending status update for ", server) var msg *pb.PipelineStatusResponse - if pv.State.PipelineGwStatus == pipeline.PipelineTerminating { - s.logger.Debugf("Pipeline %s is terminating, sending deletion message", pv.Name) + if pv.State.PipelineGwStatus == pipeline.PipelineTerminating || pv.State.PipelineGwStatus == pipeline.PipelineFailedTerminating { + s.logger.Debugf("Pipeline %s in state %s, sending deletion message", pv.Name, pv.State.PipelineGwStatus) msg = s.createPipelineDeletionMessage(pv) } else { s.logger.Debugf("Pipeline %s is available or progressing, sending creation message", pv.Name) diff --git a/scheduler/pkg/server/pipeline_status_test.go b/scheduler/pkg/server/pipeline_status_test.go index 89e5f028af..9a8c4ad9f5 100644 --- a/scheduler/pkg/server/pipeline_status_test.go +++ b/scheduler/pkg/server/pipeline_status_test.go @@ -17,14 +17,196 @@ import ( . 
"github.com/onsi/gomega" log "github.com/sirupsen/logrus" + "go.uber.org/mock/gomock" pb "github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler" "github.com/seldonio/seldon-core/scheduler/v2/pkg/coordinator" "github.com/seldonio/seldon-core/scheduler/v2/pkg/store/pipeline" + "github.com/seldonio/seldon-core/scheduler/v2/pkg/store/pipeline/mock" "github.com/seldonio/seldon-core/scheduler/v2/pkg/util" ) +func TestPollerRetryFailedPipelines(t *testing.T) { + tests := []struct { + name string + funcName string + targetStatus pipeline.PipelineStatus + operation string + failedPipelines []coordinator.PipelineEventMsg + setupMocks func(mockPipelineHandler *mock.MockPipelineHandler, failedPipelines []coordinator.PipelineEventMsg) + tickCount int + contextTimeout time.Duration + tickDuration time.Duration + validateBehavior func(g *WithT, mockPipelineHandler *mock.MockPipelineHandler) + maxRetries uint + }{ + { + name: "context cancelled immediately", + funcName: "testFunc", + targetStatus: pipeline.PipelineFailed, + operation: "create", + failedPipelines: []coordinator.PipelineEventMsg{}, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, failedPipelines []coordinator.PipelineEventMsg) { + // No expectations - context cancelled before first tick + }, + contextTimeout: 0, // Cancel immediately + tickDuration: 100 * time.Millisecond, + }, + { + name: "no failed pipelines found", + funcName: "testFunc", + targetStatus: pipeline.PipelineFailed, + operation: "create", + failedPipelines: []coordinator.PipelineEventMsg{}, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, failedPipelines []coordinator.PipelineEventMsg) { + mockPipelineHandler.EXPECT(). + GetPipelinesPipelineGwStatus(pipeline.PipelineFailed). + Return([]coordinator.PipelineEventMsg{}). + MinTimes(1) + }, + contextTimeout: 150 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + }, + { + name: "single failed create pipeline", + funcName: "pollerRetryFailedCreatePipelines", + targetStatus: pipeline.PipelineFailed, + operation: "create", + maxRetries: 1, + failedPipelines: []coordinator.PipelineEventMsg{ + { + PipelineName: "test-pipeline", + PipelineVersion: 1, + UID: "uid-1", + }, + }, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, failedPipelines []coordinator.PipelineEventMsg) { + mockPipelineHandler.EXPECT(). + GetPipelinesPipelineGwStatus(pipeline.PipelineFailed). + Return(failedPipelines) + + mockPipelineHandler.EXPECT(). + GetPipelinesPipelineGwStatus(pipeline.PipelineFailed). 
+ Return([]coordinator.PipelineEventMsg{}) + + mockPipelineHandler.EXPECT().GetPipelineVersion(failedPipelines[0].PipelineName, + failedPipelines[0].PipelineVersion, failedPipelines[0].UID).Return(&pipeline.PipelineVersion{ + Name: failedPipelines[0].PipelineName, + Version: failedPipelines[0].PipelineVersion, + UID: failedPipelines[0].UID, + State: &pipeline.PipelineState{ + PipelineGwStatus: pipeline.PipelineFailed, + }, + }, nil) + + mockPipelineHandler.EXPECT().SetPipelineGwPipelineState( + failedPipelines[0].PipelineName, + failedPipelines[0].PipelineVersion, + failedPipelines[0].UID, pipeline.PipelineCreate, + "No pipeline gateway available to handle pipeline", util.SourcePipelineStatusEvent).Return(nil) + }, + contextTimeout: 100 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + }, + { + name: "single failed delete pipeline", + funcName: "pollerRetryFailedDeletePipelines", + targetStatus: pipeline.PipelineFailedTerminating, + operation: "delete", + maxRetries: 1, + failedPipelines: []coordinator.PipelineEventMsg{ + { + PipelineName: "test-pipeline", + PipelineVersion: 1, + UID: "uid-1", + }, + }, + setupMocks: func(mockPipelineHandler *mock.MockPipelineHandler, failedPipelines []coordinator.PipelineEventMsg) { + mockPipelineHandler.EXPECT(). + GetPipelinesPipelineGwStatus(pipeline.PipelineFailedTerminating). + Return(failedPipelines). + Times(1) + + mockPipelineHandler.EXPECT(). + GetPipelinesPipelineGwStatus(pipeline.PipelineFailedTerminating). + Return([]coordinator.PipelineEventMsg{}) + + mockPipelineHandler.EXPECT().GetPipelineVersion(failedPipelines[0].PipelineName, + failedPipelines[0].PipelineVersion, failedPipelines[0].UID).Return(&pipeline.PipelineVersion{ + Name: failedPipelines[0].PipelineName, + Version: failedPipelines[0].PipelineVersion, + UID: failedPipelines[0].UID, + State: &pipeline.PipelineState{ + PipelineGwStatus: pipeline.PipelineFailedTerminating, + }, + }, nil) + + mockPipelineHandler.EXPECT().SetPipelineGwPipelineState( + failedPipelines[0].PipelineName, + failedPipelines[0].PipelineVersion, + failedPipelines[0].UID, pipeline.PipelineTerminated, + "No pipeline gateway available to handle pipeline", util.SourcePipelineStatusEvent).Return(nil) + }, + contextTimeout: 100 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + ctrl := gomock.NewController(t) + + mockPipelineHandler := mock.NewMockPipelineHandler(ctrl) + + if tt.setupMocks != nil { + tt.setupMocks(mockPipelineHandler, tt.failedPipelines) + } + + eventHub, err := coordinator.NewEventHub(log.New()) + g.Expect(err).Should(BeNil()) + + server := &SchedulerServer{ + logger: log.New(), + pipelineHandler: mockPipelineHandler, + eventHub: eventHub, + retriedFailedPipelines: map[string]uint{}, + } + + var ctx context.Context + var cancel context.CancelFunc + + if tt.contextTimeout == 0 { + // Cancel immediately + ctx, cancel = context.WithCancel(context.Background()) + cancel() + } else { + ctx, cancel = context.WithTimeout(context.Background(), tt.contextTimeout) + defer cancel() + } + + done := make(chan bool) + go func() { + server.pollerRetryFailedPipelines(ctx, tt.tickDuration, tt.funcName, tt.targetStatus, tt.operation, tt.maxRetries) + done <- true + }() + + select { + case <-done: + // Test passed - function returned as expected + case <-time.After(tt.contextTimeout + 1*time.Second): + t.Fatal("pollerRetryFailedPipelines did not return in time") + } + + // Custom validation if 
provided + if tt.validateBehavior != nil { + tt.validateBehavior(g, mockPipelineHandler) + } + }) + } +} + func receiveMessageFromPipelineStream( t *testing.T, stream *stubPipelineStatusServer, ) *pb.PipelineStatusResponse { diff --git a/scheduler/pkg/server/server.go b/scheduler/pkg/server/server.go index 9ebc508927..66aface4f8 100644 --- a/scheduler/pkg/server/server.go +++ b/scheduler/pkg/server/server.go @@ -63,29 +63,37 @@ var ErrAddServerEmptyServerName = status.Errorf(codes.FailedPrecondition, "Empty type SchedulerServer struct { pb.UnimplementedSchedulerServer health.UnimplementedHealthCheckServiceServer - logger log.FieldLogger - modelStore store.ModelStore - experimentServer experiment.ExperimentServer - pipelineHandler pipeline.PipelineHandler - scheduler scheduler2.Scheduler - modelEventStream ModelEventStream - serverEventStream ServerEventStream - experimentEventStream ExperimentEventStream - pipelineEventStream PipelineEventStream - controlPlaneStream ControlPlaneStream - timeout time.Duration - synchroniser synchroniser.Synchroniser - config SchedulerServerConfig - modelGwLoadBalancer *util.RingLoadBalancer - pipelineGWLoadBalancer *util.RingLoadBalancer - scalingConfigUpdates chan scaling_config.ScalingConfig - currentScalingConfig *scaling_config.ScalingConfig - mu sync.Mutex - done chan struct{} - grpcServer *grpc.Server - consumerGroupConfig *ConsumerGroupConfig - eventHub *coordinator.EventHub - tlsOptions seldontls.TLSOptions + logger log.FieldLogger + modelStore store.ModelStore + experimentServer experiment.ExperimentServer + pipelineHandler pipeline.PipelineHandler + scheduler scheduler2.Scheduler + modelEventStream ModelEventStream + serverEventStream ServerEventStream + experimentEventStream ExperimentEventStream + pipelineEventStream PipelineEventStream + controlPlaneStream ControlPlaneStream + timeout time.Duration + synchroniser synchroniser.Synchroniser + config SchedulerServerConfig + modelGwLoadBalancer *util.RingLoadBalancer + pipelineGWLoadBalancer *util.RingLoadBalancer + scalingConfigUpdates chan scaling_config.ScalingConfig + currentScalingConfig *scaling_config.ScalingConfig + mu sync.Mutex + done chan struct{} + grpcServer *grpc.Server + consumerGroupConfig *ConsumerGroupConfig + eventHub *coordinator.EventHub + tlsOptions seldontls.TLSOptions + muRetriedFailedPipelines sync.Mutex + // TODO this would ideally be stored within the pipeline handler, as now we have to + // retrieve the pipeline from the memory store even if it has reached max retries + // retriedFailedPipelines keyed off pipeline UID + version, value is retried count + retriedFailedPipelines map[string]uint + muRetriedFailedModels sync.Mutex + // retriedFailedModels keyed off model name, value is retried count + retriedFailedModels map[string]uint } type SchedulerServerConfig struct { @@ -183,7 +191,7 @@ func NewConsumerGroupConfig(namespace, consumerGroupIdPrefix string, modelGatewa } } -func (s *SchedulerServer) startServer(port uint, secure bool) error { +func (s *SchedulerServer) startServer(ctx context.Context, port uint, secure bool, pollerTickCreate, pollerTickDelete time.Duration, maxRetry uint) error { logger := s.logger.WithField("func", "startServer") lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) if err != nil { @@ -209,17 +217,26 @@ func (s *SchedulerServer) startServer(port uint, secure bool) error { err := grpcServer.Serve(lis) logger.WithError(err).Fatalf("Scheduler mTLS server failed on port %d mtls:%v", port, secure) }() + + s.startPollers(ctx,
pollerTickCreate, pollerTickDelete, maxRetry) return nil } -func (s *SchedulerServer) StartGrpcServers(allowPlainTxt bool, schedulerPort uint, schedulerTlsPort uint) error { +func (s *SchedulerServer) startPollers(ctx context.Context, pollerTickCreate, pollerTickDelete time.Duration, maxRetry uint) { + go s.pollerRetryFailedCreateModels(ctx, pollerTickCreate, maxRetry) + go s.pollerRetryFailedDeleteModels(ctx, pollerTickDelete, maxRetry) + go s.pollerRetryFailedCreatePipelines(ctx, pollerTickCreate, maxRetry) + go s.pollerRetryFailedDeletePipelines(ctx, pollerTickDelete, maxRetry) +} + +func (s *SchedulerServer) StartGrpcServers(ctx context.Context, allowPlainTxt bool, schedulerPort uint, schedulerTlsPort uint, pollerTickCreate, pollerTickDelete time.Duration, maxRetry uint) error { logger := s.logger.WithField("func", "StartGrpcServers") if !allowPlainTxt && s.tlsOptions.Cert == nil { return fmt.Errorf("one of plain txt or mTLS needs to be defined. But have plain text [%v] and no TLS", allowPlainTxt) } if allowPlainTxt { - err := s.startServer(schedulerPort, false) + err := s.startServer(ctx, schedulerPort, false, pollerTickCreate, pollerTickDelete, maxRetry) if err != nil { return err } @@ -227,7 +244,7 @@ func (s *SchedulerServer) StartGrpcServers(allowPlainTxt bool, schedulerPort uin logger.Info("Not starting scheduler plain text server") } if s.tlsOptions.Cert != nil { - err := s.startServer(schedulerTlsPort, true) + err := s.startServer(ctx, schedulerTlsPort, true, pollerTickCreate, pollerTickDelete, maxRetry) if err != nil { return err } @@ -314,6 +331,8 @@ func NewSchedulerServer( consumerGroupConfig: consumerGroupConfig, eventHub: eventHub, tlsOptions: tlsOptions, + retriedFailedModels: make(map[string]uint), + retriedFailedPipelines: make(map[string]uint), } eventHub.RegisterModelEventHandler( diff --git a/scheduler/pkg/server/server_status.go b/scheduler/pkg/server/server_status.go index 8c3819ae52..a335cfed5a 100644 --- a/scheduler/pkg/server/server_status.go +++ b/scheduler/pkg/server/server_status.go @@ -11,8 +11,11 @@ package server import ( "context" + "fmt" "time" + "github.com/sirupsen/logrus" + pb "github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler" "github.com/seldonio/seldon-core/scheduler/v2/pkg/coordinator" @@ -25,7 +28,94 @@ const ( modelStatusEventSource = "model-status-server" ) -func (s *SchedulerServer) ModelStatusEvent(ctx context.Context, message *pb.ModelUpdateStatusMessage) (*pb.ModelUpdateStatusResponse, error) { +// pollerRetryFailedCreateModels will retry creating models on model-gw which failed to load. Most likely +// due to connectivity issues with kafka. +func (s *SchedulerServer) pollerRetryFailedCreateModels(ctx context.Context, tick time.Duration, maxRetry uint) { + s.pollerRetryFailedModels(ctx, tick, "pollerRetryFailedCreateModels", store.ModelFailed, "create", maxRetry) +} + +// pollerRetryFailedDeleteModels will retry deleting models on model-gw which failed to terminate. Most likely +// due to connectivity issues with kafka. 
+func (s *SchedulerServer) pollerRetryFailedDeleteModels(ctx context.Context, tick time.Duration, maxRetry uint) { + s.pollerRetryFailedModels(ctx, tick, "pollerRetryFailedDeleteModels", store.ModelTerminateFailed, "delete", maxRetry) +} + +func (s *SchedulerServer) pollerRetryFailedModels(ctx context.Context, tick time.Duration, funcName string, targetState store.ModelState, operation string, maxRetry uint) { + logger := s.logger.WithField("func", funcName) + ticker := time.NewTicker(tick) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + models := s.getModelsInGwRetryState(logger, targetState, operation, maxRetry) + if len(models) > 0 { + s.modelGwRebalanceForModels(models) + } + } + } +} + +func (s *SchedulerServer) mkModelRetryKey(modelName string, version uint32) string { + return fmt.Sprintf("%s_%d", modelName, version) +} + +func (s *SchedulerServer) getModelsInGwRetryState(logger *logrus.Entry, targetState store.ModelState, operation string, maxRetry uint) []*store.ModelSnapshot { + modelNames := s.modelStore.GetAllModels() + logger.WithField("models", modelNames).Debugf("Poller retry to %s failed models on model-gw", operation) + + models := make([]*store.ModelSnapshot, 0) + + for _, modelName := range modelNames { + model, err := s.modelStore.GetModel(modelName) + if err != nil { + logger.WithError(err).Errorf("Failed to get model %s", modelName) + continue + } + + if model.GetLatest() == nil { + logger.Warnf("Model %s has no versions, skipping", modelName) + continue + } + + modelGwState := model.GetLatest().ModelState().ModelGwState + if modelGwState != targetState { + logger.Debugf("Model-gw model %s state %s != %s, skipping", modelName, modelGwState, targetState) + continue + } + + key := s.mkModelRetryKey(model.Name, model.GetLatest().GetVersion()) + s.muRetriedFailedModels.Lock() + s.retriedFailedModels[key]++ + if s.retriedFailedModels[key] > maxRetry { + s.muRetriedFailedModels.Unlock() + logger.Debugf("Model-gw model %s retry failed, max retries reached", modelName) + continue + } + s.muRetriedFailedModels.Unlock() + + logger.Infof("Model-gw model %s in %s state, retrying %s on model-gw", modelName, targetState, operation) + models = append(models, model) + } + + return models +} + +func (s *SchedulerServer) resetModelRetryCount(msg *pb.ModelUpdateMessage) { + s.muRetriedFailedModels.Lock() + defer s.muRetriedFailedModels.Unlock() + s.retriedFailedModels[s.mkModelRetryKey(msg.Model, msg.Version)] = 0 +} + +func (s *SchedulerServer) removeModelRetryCount(msg *pb.ModelUpdateMessage) { + s.muRetriedFailedModels.Lock() + defer s.muRetriedFailedModels.Unlock() + delete(s.retriedFailedModels, s.mkModelRetryKey(msg.Model, msg.Version)) +} + +func (s *SchedulerServer) ModelStatusEvent(_ context.Context, message *pb.ModelUpdateStatusMessage) (*pb.ModelUpdateStatusResponse, error) { s.modelEventStream.mu.Lock() defer s.modelEventStream.mu.Unlock() @@ -35,12 +125,14 @@ func (s *SchedulerServer) ModelStatusEvent(ctx context.Context, message *pb.Mode switch message.Update.Op { case pb.ModelUpdateMessage_Create: if message.Success { + s.resetModelRetryCount(message.Update) statusVal = store.ModelAvailable } else { statusVal = store.ModelFailed } case pb.ModelUpdateMessage_Delete: if message.Success { + s.removeModelRetryCount(message.Update) statusVal = store.ModelTerminated } else { statusVal = store.ModelTerminateFailed @@ -175,10 +267,19 @@ func contains(slice []string, val string) bool { return false } -func (s *SchedulerServer) 
GetAllRunningModels() []*store.ModelSnapshot { - var runningModels []*store.ModelSnapshot +func (s *SchedulerServer) allPermittedModels() []*store.ModelSnapshot { + var permittedModels []*store.ModelSnapshot modelNames := s.modelStore.GetAllModels() + allowedModelGwStates := map[store.ModelState]struct{}{ + store.ModelCreate: {}, + store.ModelProgressing: {}, + store.ModelAvailable: {}, + store.ModelTerminating: {}, + // we want to retry models which failed to create on model-gw i.e. likely kafka connectivity issues + store.ModelFailed: {}, + } + for _, modelName := range modelNames { model, err := s.modelStore.GetModel(modelName) if err != nil { @@ -190,19 +291,12 @@ func (s *SchedulerServer) GetAllRunningModels() []*store.ModelSnapshot { continue } - modelState := model.GetLatest().ModelState() - runningStates := map[store.ModelState]struct{}{ - store.ModelCreate: {}, - store.ModelProgressing: {}, - store.ModelAvailable: {}, - store.ModelTerminating: {}, - } - - if _, ok := runningStates[modelState.ModelGwState]; ok { - runningModels = append(runningModels, model) + if _, ok := allowedModelGwStates[model.GetLatest().ModelState().ModelGwState]; ok { + permittedModels = append(permittedModels, model) } } - return runningModels + + return permittedModels } func (s *SchedulerServer) createModelDeletionMessage(model *store.ModelSnapshot, keepTopics bool) (*pb.ModelStatusResponse, error) { @@ -225,12 +319,15 @@ func (s *SchedulerServer) createModelCreationMessage(model *store.ModelSnapshot) } func (s *SchedulerServer) modelGwRebalance() { + runningModels := s.allPermittedModels() + s.logger.Debugf("Rebalancing model gateways for running models: %v", runningModels) + s.modelGwRebalanceForModels(runningModels) +} + +func (s *SchedulerServer) modelGwRebalanceForModels(models []*store.ModelSnapshot) { s.modelEventStream.mu.Lock() defer s.modelEventStream.mu.Unlock() - runningModels := s.GetAllRunningModels() - s.logger.Debugf("Rebalancing model gateways for running models: %v", runningModels) - // get only the model gateway streams streams := []*ModelSubscription{} for _, modelSubscription := range s.modelEventStream.streams { @@ -239,7 +336,7 @@ func (s *SchedulerServer) modelGwRebalance() { } } - for _, model := range runningModels { + for _, model := range models { switch len(streams) { case 0: s.modelGwRebalanceNoStream(model) @@ -251,7 +348,8 @@ func (s *SchedulerServer) modelGwRebalance() { func (s *SchedulerServer) modelGwRebalanceNoStream(model *store.ModelSnapshot) { modelState := store.ModelCreate - if model.GetLatest().ModelState().ModelGwState == store.ModelTerminating { + if model.GetLatest().ModelState().ModelGwState == store.ModelTerminating || + model.GetLatest().ModelState().ModelGwState == store.ModelTerminateFailed { modelState = store.ModelTerminated } @@ -298,8 +396,8 @@ func (s *SchedulerServer) modelGwReblanceStreams(model *store.ModelSnapshot) { var msg *pb.ModelStatusResponse var err error - if state == store.ModelTerminating { - s.logger.Debugf("Model %s is terminating, sending deletion message", model.Name) + if state == store.ModelTerminating || state == store.ModelTerminateFailed { + s.logger.Debugf("Model %s in state %s, sending deletion message", model.Name, state) msg, err = s.createModelDeletionMessage(model, false) } else { s.logger.Debugf("Model %s is available or progressing, sending creation message", model.Name) diff --git a/scheduler/pkg/server/server_status_test.go b/scheduler/pkg/server/server_status_test.go index 8451cb998f..53abef0c62 100644 --- 
a/scheduler/pkg/server/server_status_test.go +++ b/scheduler/pkg/server/server_status_test.go @@ -17,6 +17,7 @@ import ( . "github.com/onsi/gomega" log "github.com/sirupsen/logrus" + "go.uber.org/mock/gomock" "google.golang.org/protobuf/proto" pba "github.com/seldonio/seldon-core/apis/go/v2/mlops/agent" @@ -28,11 +29,194 @@ import ( "github.com/seldonio/seldon-core/scheduler/v2/pkg/scheduler/cleaner" "github.com/seldonio/seldon-core/scheduler/v2/pkg/store" "github.com/seldonio/seldon-core/scheduler/v2/pkg/store/experiment" + "github.com/seldonio/seldon-core/scheduler/v2/pkg/store/mock" "github.com/seldonio/seldon-core/scheduler/v2/pkg/store/pipeline" "github.com/seldonio/seldon-core/scheduler/v2/pkg/synchroniser" "github.com/seldonio/seldon-core/scheduler/v2/pkg/util" ) +func TestPollerRetryFailedModels(t *testing.T) { + tests := []struct { + name string + funcName string + targetState store.ModelState + operation string + modelNames []string + setupMocks func(mockModelStore *mock.MockModelStore, modelNames []string, targetState store.ModelState) + contextTimeout time.Duration + tickDuration time.Duration + validateMocks func(g *WithT, mockModelStore *mock.MockModelStore) + maxRetries uint + }{ + { + name: "context cancelled immediately", + funcName: "testFunc", + targetState: store.ModelFailed, + operation: "create", + modelNames: []string{}, + setupMocks: func(mockModelStore *mock.MockModelStore, modelNames []string, targetState store.ModelState) { + // No expectations - context cancelled before first tick + }, + contextTimeout: 0, // Cancel immediately + tickDuration: 100 * time.Millisecond, + }, + { + name: "no models exist", + funcName: "testFunc", + targetState: store.ModelFailed, + operation: "create", + modelNames: []string{}, + setupMocks: func(mockModelStore *mock.MockModelStore, modelNames []string, targetState store.ModelState) { + mockModelStore.EXPECT(). + GetAllModels(). + Return([]string{}). + MinTimes(1) + }, + contextTimeout: 150 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + }, + { + name: "single model not in target state", + funcName: "testFunc", + targetState: store.ModelFailed, + operation: "create", + modelNames: []string{"model-1"}, + setupMocks: func(mockModelStore *mock.MockModelStore, modelNames []string, targetState store.ModelState) { + mockModelStore.EXPECT(). + GetAllModels(). + Return(modelNames). + MinTimes(1) + + model := &store.ModelSnapshot{} + model.Name = "model-1" + modelVersion := store.NewModelVersion(&pb.Model{}, 1, "server-1", nil, false, 0) + modelVersion.SetModelState(store.ModelStatus{ + ModelGwState: store.ScheduleFailed, + }) + model.Versions = []*store.ModelVersion{modelVersion} + + mockModelStore.EXPECT(). + GetModel("model-1"). + Return(model, nil). + MinTimes(1) + }, + contextTimeout: 150 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + }, + { + name: "single model in failed state", + funcName: "pollerRetryFailedCreateModels", + targetState: store.ModelFailed, + operation: "create", + modelNames: []string{"failed-model"}, + setupMocks: func(mockModelStore *mock.MockModelStore, modelNames []string, targetState store.ModelState) { + mockModelStore.EXPECT(). + GetAllModels(). 
+ Return(modelNames).MinTimes(1) + + model := &store.ModelSnapshot{} + model.Name = "failed-model" + modelVersion := store.NewModelVersion(&pb.Model{}, 1, "server-1", nil, false, 0) + modelVersion.SetModelState(store.ModelStatus{ + ModelGwState: store.ModelFailed, + }) + model.Versions = []*store.ModelVersion{modelVersion} + + mockModelStore.EXPECT(). + GetModel("failed-model"). + Return(model, nil). + MinTimes(1) + + mockModelStore.EXPECT().SetModelGwModelState( + "failed-model", + uint32(1), + store.ModelCreate, "No model gateway available to handle model", modelStatusEventSource).MinTimes(1) + }, + contextTimeout: 100 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + maxRetries: 3, + }, + { + name: "max retries exceeded, do not retry", + funcName: "pollerRetryFailedCreateModels", + targetState: store.ModelFailed, + operation: "create", + modelNames: []string{"failed-model"}, + setupMocks: func(mockModelStore *mock.MockModelStore, modelNames []string, targetState store.ModelState) { + mockModelStore.EXPECT(). + GetAllModels(). + Return(modelNames).MinTimes(1) + + model := &store.ModelSnapshot{} + model.Name = "failed-model" + modelVersion := store.NewModelVersion(&pb.Model{}, 1, "server-1", nil, false, 0) + modelVersion.SetModelState(store.ModelStatus{ + ModelGwState: store.ModelFailed, + }) + model.Versions = []*store.ModelVersion{modelVersion} + + mockModelStore.EXPECT(). + GetModel("failed-model"). + Return(model, nil). + MinTimes(1) + }, + contextTimeout: 100 * time.Millisecond, + tickDuration: 50 * time.Millisecond, + maxRetries: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + ctrl := gomock.NewController(t) + + mockModelStore := mock.NewMockModelStore(ctrl) + + if tt.setupMocks != nil { + tt.setupMocks(mockModelStore, tt.modelNames, tt.targetState) + } + + server := &SchedulerServer{ + logger: log.New(), + modelStore: mockModelStore, + retriedFailedModels: make(map[string]uint), + } + + var ctx context.Context + var cancel context.CancelFunc + + if tt.contextTimeout == 0 { + // Cancel immediately + ctx, cancel = context.WithCancel(context.Background()) + cancel() + } else { + ctx, cancel = context.WithTimeout(context.Background(), tt.contextTimeout) + defer cancel() + } + + done := make(chan bool) + go func() { + server.pollerRetryFailedModels(ctx, tt.tickDuration, tt.funcName, tt.targetState, tt.operation, tt.maxRetries) + done <- true + }() + + select { + case <-done: + // Test passed - function returned as expected + case <-time.After(tt.contextTimeout + 1*time.Second): + t.Fatal("pollerRetryFailedModels did not return in time") + } + + // Custom validation if provided + if tt.validateMocks != nil { + tt.validateMocks(g, mockModelStore) + } + }) + } + +} + func receiveMessageFromModelStream(stream *stubModelStatusServer) *pb.ModelStatusResponse { time.Sleep(500 * time.Millisecond) diff --git a/scheduler/pkg/store/experiment/state_test.go b/scheduler/pkg/store/experiment/state_test.go index 6231758251..e8cda540f1 100644 --- a/scheduler/pkg/store/experiment/state_test.go +++ b/scheduler/pkg/store/experiment/state_test.go @@ -708,6 +708,15 @@ type fakePipelineStore struct { pipelineGwStatus map[string]pipeline.PipelineStatus } +func (f fakePipelineStore) IsLatestVersion(pipelineName string, version uint32, uid string) (bool, error) { + //TODO implement me + panic("implement me") +} + +func (f fakePipelineStore) GetPipelinesPipelineGwStatus(_ pipeline.PipelineStatus) []coordinator.PipelineEventMsg { + 
panic("implement me") +} + func (f fakePipelineStore) AddPipeline(pipeline *scheduler.Pipeline) error { panic("implement me") } diff --git a/scheduler/pkg/store/mesh.go b/scheduler/pkg/store/mesh.go index 31ca15fbb4..47025fe992 100644 --- a/scheduler/pkg/store/mesh.go +++ b/scheduler/pkg/store/mesh.go @@ -239,14 +239,15 @@ func NewServerReplicaFromConfig(server *Server, replicaIdx int, loadedModels map func cleanCapabilities(capabilities []string) []string { var cleaned []string - for _, cap := range capabilities { - cleaned = append(cleaned, strings.TrimSpace(cap)) + for _, capability := range capabilities { + cleaned = append(cleaned, strings.TrimSpace(capability)) } return cleaned } type ModelState uint32 +//go:generate go tool stringer -type=ModelState const ( ModelStateUnknown ModelState = iota ModelProgressing @@ -261,24 +262,9 @@ const ( ModelTerminate ) -func (m ModelState) String() string { - return [...]string{ - "ModelStateUnknown", - "ModelProgressing", - "ModelAvailable", - "ModelFailed", - "ModelTerminating", - "ModelTerminated", - "ModelTerminateFailed", - "ScheduleFailed", - "ModelScaledDown", - "ModelCreate", - "ModelTerminate", - }[m] -} - type ModelReplicaState uint32 +//go:generate go tool stringer -type=ModelReplicaState const ( ModelReplicaStateUnknown ModelReplicaState = iota LoadRequested @@ -332,10 +318,6 @@ func (m ModelReplicaState) IsLoadingOrLoaded() bool { return (m == Loaded || m == LoadRequested || m == Loading || m == Available || m == LoadedUnavailable) } -func (me ModelReplicaState) String() string { - return [...]string{"ModelReplicaStateUnknown", "LoadRequested", "Loading", "Loaded", "LoadFailed", "UnloadEnvoyRequested", "UnloadRequested", "Unloading", "Unloaded", "UnloadFailed", "Available", "LoadedUnavailable", "Draining"}[me] -} - func (m *Model) HasLatest() bool { return len(m.versions) > 0 } diff --git a/scheduler/pkg/store/modelreplicastate_string.go b/scheduler/pkg/store/modelreplicastate_string.go new file mode 100644 index 0000000000..5d5f876b71 --- /dev/null +++ b/scheduler/pkg/store/modelreplicastate_string.go @@ -0,0 +1,44 @@ +/* +Copyright (c) 2024 Seldon Technologies Ltd. + +Use of this software is governed BY +(1) the license included in the LICENSE file or +(2) if the license included in the LICENSE file is the Business Source License 1.1, +the Change License after the Change Date as each is defined in accordance with the LICENSE file. +*/ + +// Code generated by "stringer -type=ModelReplicaState"; DO NOT EDIT. + +package store + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. 
+ var x [1]struct{} + _ = x[ModelReplicaStateUnknown-0] + _ = x[LoadRequested-1] + _ = x[Loading-2] + _ = x[Loaded-3] + _ = x[LoadFailed-4] + _ = x[UnloadEnvoyRequested-5] + _ = x[UnloadRequested-6] + _ = x[Unloading-7] + _ = x[Unloaded-8] + _ = x[UnloadFailed-9] + _ = x[Available-10] + _ = x[LoadedUnavailable-11] + _ = x[Draining-12] +} + +const _ModelReplicaState_name = "ModelReplicaStateUnknownLoadRequestedLoadingLoadedLoadFailedUnloadEnvoyRequestedUnloadRequestedUnloadingUnloadedUnloadFailedAvailableLoadedUnavailableDraining" + +var _ModelReplicaState_index = [...]uint8{0, 24, 37, 44, 50, 60, 80, 95, 104, 112, 124, 133, 150, 158} + +func (i ModelReplicaState) String() string { + if i >= ModelReplicaState(len(_ModelReplicaState_index)-1) { + return "ModelReplicaState(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _ModelReplicaState_name[_ModelReplicaState_index[i]:_ModelReplicaState_index[i+1]] +} diff --git a/scheduler/pkg/store/modelstate_string.go b/scheduler/pkg/store/modelstate_string.go new file mode 100644 index 0000000000..7ed0bce200 --- /dev/null +++ b/scheduler/pkg/store/modelstate_string.go @@ -0,0 +1,42 @@ +/* +Copyright (c) 2024 Seldon Technologies Ltd. + +Use of this software is governed BY +(1) the license included in the LICENSE file or +(2) if the license included in the LICENSE file is the Business Source License 1.1, +the Change License after the Change Date as each is defined in accordance with the LICENSE file. +*/ + +// Code generated by "stringer -type=ModelState"; DO NOT EDIT. + +package store + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[ModelStateUnknown-0] + _ = x[ModelProgressing-1] + _ = x[ModelAvailable-2] + _ = x[ModelFailed-3] + _ = x[ModelTerminating-4] + _ = x[ModelTerminated-5] + _ = x[ModelTerminateFailed-6] + _ = x[ScheduleFailed-7] + _ = x[ModelScaledDown-8] + _ = x[ModelCreate-9] + _ = x[ModelTerminate-10] +} + +const _ModelState_name = "ModelStateUnknownModelProgressingModelAvailableModelFailedModelTerminatingModelTerminatedModelTerminateFailedScheduleFailedModelScaledDownModelCreateModelTerminate" + +var _ModelState_index = [...]uint8{0, 17, 33, 47, 58, 74, 89, 109, 123, 138, 149, 163} + +func (i ModelState) String() string { + if i >= ModelState(len(_ModelState_index)-1) { + return "ModelState(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _ModelState_name[_ModelState_index[i]:_ModelState_index[i+1]] +} diff --git a/scheduler/pkg/store/pipeline/mock/store.go b/scheduler/pkg/store/pipeline/mock/store.go new file mode 100644 index 0000000000..84853fbc82 --- /dev/null +++ b/scheduler/pkg/store/pipeline/mock/store.go @@ -0,0 +1,209 @@ +/* +Copyright (c) 2024 Seldon Technologies Ltd. + +Use of this software is governed BY +(1) the license included in the LICENSE file or +(2) if the license included in the LICENSE file is the Business Source License 1.1, +the Change License after the Change Date as each is defined in accordance with the LICENSE file. +*/ + +// Code generated by MockGen. DO NOT EDIT. +// Source: ./store.go +// +// Generated by this command: +// +// mockgen -source=./store.go -destination=./mock/store.go -package=mock PipelineHandler +// + +// Package mock is a generated GoMock package. 
+package mock + +import ( + reflect "reflect" + + scheduler "github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler" + coordinator "github.com/seldonio/seldon-core/scheduler/v2/pkg/coordinator" + pipeline "github.com/seldonio/seldon-core/scheduler/v2/pkg/store/pipeline" + gomock "go.uber.org/mock/gomock" +) + +// MockPipelineHandler is a mock of PipelineHandler interface. +type MockPipelineHandler struct { + ctrl *gomock.Controller + recorder *MockPipelineHandlerMockRecorder +} + +// MockPipelineHandlerMockRecorder is the mock recorder for MockPipelineHandler. +type MockPipelineHandlerMockRecorder struct { + mock *MockPipelineHandler +} + +// NewMockPipelineHandler creates a new mock instance. +func NewMockPipelineHandler(ctrl *gomock.Controller) *MockPipelineHandler { + mock := &MockPipelineHandler{ctrl: ctrl} + mock.recorder = &MockPipelineHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPipelineHandler) EXPECT() *MockPipelineHandlerMockRecorder { + return m.recorder +} + +// AddPipeline mocks base method. +func (m *MockPipelineHandler) AddPipeline(pipeline *scheduler.Pipeline) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddPipeline", pipeline) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddPipeline indicates an expected call of AddPipeline. +func (mr *MockPipelineHandlerMockRecorder) AddPipeline(pipeline any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddPipeline", reflect.TypeOf((*MockPipelineHandler)(nil).AddPipeline), pipeline) +} + +// GetAllPipelineGwRunningPipelineVersions mocks base method. +func (m *MockPipelineHandler) GetAllPipelineGwRunningPipelineVersions() []coordinator.PipelineEventMsg { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllPipelineGwRunningPipelineVersions") + ret0, _ := ret[0].([]coordinator.PipelineEventMsg) + return ret0 +} + +// GetAllPipelineGwRunningPipelineVersions indicates an expected call of GetAllPipelineGwRunningPipelineVersions. +func (mr *MockPipelineHandlerMockRecorder) GetAllPipelineGwRunningPipelineVersions() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllPipelineGwRunningPipelineVersions", reflect.TypeOf((*MockPipelineHandler)(nil).GetAllPipelineGwRunningPipelineVersions)) +} + +// GetAllRunningPipelineVersions mocks base method. +func (m *MockPipelineHandler) GetAllRunningPipelineVersions() []coordinator.PipelineEventMsg { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllRunningPipelineVersions") + ret0, _ := ret[0].([]coordinator.PipelineEventMsg) + return ret0 +} + +// GetAllRunningPipelineVersions indicates an expected call of GetAllRunningPipelineVersions. +func (mr *MockPipelineHandlerMockRecorder) GetAllRunningPipelineVersions() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllRunningPipelineVersions", reflect.TypeOf((*MockPipelineHandler)(nil).GetAllRunningPipelineVersions)) +} + +// GetPipeline mocks base method. +func (m *MockPipelineHandler) GetPipeline(name string) (*pipeline.Pipeline, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPipeline", name) + ret0, _ := ret[0].(*pipeline.Pipeline) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPipeline indicates an expected call of GetPipeline. 
+func (mr *MockPipelineHandlerMockRecorder) GetPipeline(name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPipeline", reflect.TypeOf((*MockPipelineHandler)(nil).GetPipeline), name) +} + +// GetPipelineVersion mocks base method. +func (m *MockPipelineHandler) GetPipelineVersion(name string, version uint32, uid string) (*pipeline.PipelineVersion, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPipelineVersion", name, version, uid) + ret0, _ := ret[0].(*pipeline.PipelineVersion) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPipelineVersion indicates an expected call of GetPipelineVersion. +func (mr *MockPipelineHandlerMockRecorder) GetPipelineVersion(name, version, uid any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPipelineVersion", reflect.TypeOf((*MockPipelineHandler)(nil).GetPipelineVersion), name, version, uid) +} + +// GetPipelines mocks base method. +func (m *MockPipelineHandler) GetPipelines() ([]*pipeline.Pipeline, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPipelines") + ret0, _ := ret[0].([]*pipeline.Pipeline) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPipelines indicates an expected call of GetPipelines. +func (mr *MockPipelineHandlerMockRecorder) GetPipelines() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPipelines", reflect.TypeOf((*MockPipelineHandler)(nil).GetPipelines)) +} + +// GetPipelinesPipelineGwStatus mocks base method. +func (m *MockPipelineHandler) GetPipelinesPipelineGwStatus(pipelineGwStatus pipeline.PipelineStatus) []coordinator.PipelineEventMsg { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPipelinesPipelineGwStatus", pipelineGwStatus) + ret0, _ := ret[0].([]coordinator.PipelineEventMsg) + return ret0 +} + +// GetPipelinesPipelineGwStatus indicates an expected call of GetPipelinesPipelineGwStatus. +func (mr *MockPipelineHandlerMockRecorder) GetPipelinesPipelineGwStatus(pipelineGwStatus any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPipelinesPipelineGwStatus", reflect.TypeOf((*MockPipelineHandler)(nil).GetPipelinesPipelineGwStatus), pipelineGwStatus) +} + +// IsLatestVersion mocks base method. +func (m *MockPipelineHandler) IsLatestVersion(pipelineName string, version uint32, uid string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsLatestVersion", pipelineName, version, uid) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsLatestVersion indicates an expected call of IsLatestVersion. +func (mr *MockPipelineHandlerMockRecorder) IsLatestVersion(pipelineName, version, uid any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsLatestVersion", reflect.TypeOf((*MockPipelineHandler)(nil).IsLatestVersion), pipelineName, version, uid) +} + +// RemovePipeline mocks base method. +func (m *MockPipelineHandler) RemovePipeline(name string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemovePipeline", name) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemovePipeline indicates an expected call of RemovePipeline. 
+func (mr *MockPipelineHandlerMockRecorder) RemovePipeline(name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePipeline", reflect.TypeOf((*MockPipelineHandler)(nil).RemovePipeline), name) +} + +// SetPipelineGwPipelineState mocks base method. +func (m *MockPipelineHandler) SetPipelineGwPipelineState(name string, version uint32, uid string, state pipeline.PipelineStatus, reason, source string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetPipelineGwPipelineState", name, version, uid, state, reason, source) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetPipelineGwPipelineState indicates an expected call of SetPipelineGwPipelineState. +func (mr *MockPipelineHandlerMockRecorder) SetPipelineGwPipelineState(name, version, uid, state, reason, source any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetPipelineGwPipelineState", reflect.TypeOf((*MockPipelineHandler)(nil).SetPipelineGwPipelineState), name, version, uid, state, reason, source) +} + +// SetPipelineState mocks base method. +func (m *MockPipelineHandler) SetPipelineState(name string, version uint32, uid string, state pipeline.PipelineStatus, reason, source string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetPipelineState", name, version, uid, state, reason, source) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetPipelineState indicates an expected call of SetPipelineState. +func (mr *MockPipelineHandlerMockRecorder) SetPipelineState(name, version, uid, state, reason, source any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetPipelineState", reflect.TypeOf((*MockPipelineHandler)(nil).SetPipelineState), name, version, uid, state, reason, source) +} diff --git a/scheduler/pkg/store/pipeline/pipeline.go b/scheduler/pkg/store/pipeline/pipeline.go index a8339bb8a2..5357335048 100644 --- a/scheduler/pkg/store/pipeline/pipeline.go +++ b/scheduler/pkg/store/pipeline/pipeline.go @@ -70,16 +70,18 @@ type KubernetesMeta struct { type PipelineStatus uint32 +//go:generate go tool stringer -type=PipelineStatus const ( - PipelineStatusUnknown PipelineStatus = iota - PipelineCreate // Received signal to create pipeline. - PipelineCreating // In the process of creating pipeline. - PipelineReady // Pipeline is ready to be used. - PipelineFailed // Pipeline creation/deletion failed. - PipelineTerminate // Received signal that pipeline should be terminated. - PipelineTerminating // In the process of doing cleanup/housekeeping for pipeline termination. - PipelineTerminated // Pipeline has been terminated. - PipelineRebalancing // Pipeline is rebalancing + PipelineStatusUnknown PipelineStatus = iota + PipelineCreate // Received signal to create pipeline. + PipelineCreating // In the process of creating pipeline. + PipelineReady // Pipeline is ready to be used. + PipelineFailed // Pipeline creation failed. + PipelineTerminate // Received signal that pipeline should be terminated. + PipelineTerminating // In the process of doing cleanup/housekeeping for pipeline termination. + PipelineTerminated // Pipeline has been terminated. + PipelineRebalancing // Pipeline is rebalancing + PipelineFailedTerminating // Pipeline has failed to terminate. 
) type PipelineState struct { @@ -91,10 +93,6 @@ type PipelineState struct { Timestamp time.Time } -func (ps PipelineStatus) String() string { - return [...]string{"PipelineStatusUnknown", "PipelineCreate", "PipelineCreating", "PipelineReady", "PipelineFailed", "PipelineTerminate", "PipelineTerminating", "PipelineTerminated", "PipelineRebalancing"}[ps] -} - func (ps *PipelineState) setState(status PipelineStatus, reason string) { ps.Status = status ps.Reason = reason diff --git a/scheduler/pkg/store/pipeline/pipelinestatus_string.go b/scheduler/pkg/store/pipeline/pipelinestatus_string.go new file mode 100644 index 0000000000..f504f9d422 --- /dev/null +++ b/scheduler/pkg/store/pipeline/pipelinestatus_string.go @@ -0,0 +1,41 @@ +/* +Copyright (c) 2024 Seldon Technologies Ltd. + +Use of this software is governed BY +(1) the license included in the LICENSE file or +(2) if the license included in the LICENSE file is the Business Source License 1.1, +the Change License after the Change Date as each is defined in accordance with the LICENSE file. +*/ + +// Code generated by "stringer -type=PipelineStatus"; DO NOT EDIT. + +package pipeline + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[PipelineStatusUnknown-0] + _ = x[PipelineCreate-1] + _ = x[PipelineCreating-2] + _ = x[PipelineReady-3] + _ = x[PipelineFailed-4] + _ = x[PipelineTerminate-5] + _ = x[PipelineTerminating-6] + _ = x[PipelineTerminated-7] + _ = x[PipelineRebalancing-8] + _ = x[PipelineFailedTerminating-9] +} + +const _PipelineStatus_name = "PipelineStatusUnknownPipelineCreatePipelineCreatingPipelineReadyPipelineFailedPipelineTerminatePipelineTerminatingPipelineTerminatedPipelineRebalancingPipelineFailedTerminating" + +var _PipelineStatus_index = [...]uint8{0, 21, 35, 51, 64, 78, 95, 114, 132, 151, 176} + +func (i PipelineStatus) String() string { + if i >= PipelineStatus(len(_PipelineStatus_index)-1) { + return "PipelineStatus(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _PipelineStatus_name[_PipelineStatus_index[i]:_PipelineStatus_index[i+1]] +} diff --git a/scheduler/pkg/store/pipeline/status.go b/scheduler/pkg/store/pipeline/status.go index e76883bfd6..5684a09c08 100644 --- a/scheduler/pkg/store/pipeline/status.go +++ b/scheduler/pkg/store/pipeline/status.go @@ -92,6 +92,7 @@ func updatePipelinesFromModelAvailability(references map[string]void, modelName logger := loggerIn.WithField("func", "updatePipelinesFromModelAvailability") logger.Debugf("Updating pipeline state from model %s available:%v", modelName, modelAvailable) var changedPipelines []*coordinator.PipelineEventMsg + for pipelineName := range references { if pipeline, ok := pipelines[pipelineName]; ok { latestPipeline := pipeline.GetLatestPipelineVersion() diff --git a/scheduler/pkg/store/pipeline/store.go b/scheduler/pkg/store/pipeline/store.go index fd94ffbf91..408445cb7d 100644 --- a/scheduler/pkg/store/pipeline/store.go +++ b/scheduler/pkg/store/pipeline/store.go @@ -10,6 +10,7 @@ the Change License after the Change Date as each is defined in accordance with t package pipeline import ( + "fmt" "os" "path/filepath" "sync" @@ -36,6 +37,7 @@ const ( modelEventHandlerName = "pipeline.store.models" ) +//go:generate go tool mockgen -source=./store.go -destination=./mock/store.go -package=mock PipelineHandler type PipelineHandler interface { AddPipeline(pipeline *scheduler.Pipeline) error 
 	RemovePipeline(name string) error
@@ -46,6 +48,8 @@ type PipelineHandler interface {
 	SetPipelineGwPipelineState(name string, version uint32, uid string, state PipelineStatus, reason string, source string) error
 	GetAllRunningPipelineVersions() []coordinator.PipelineEventMsg
 	GetAllPipelineGwRunningPipelineVersions() []coordinator.PipelineEventMsg
+	GetPipelinesPipelineGwStatus(pipelineGwStatus PipelineStatus) []coordinator.PipelineEventMsg
+	IsLatestVersion(pipelineName string, version uint32, uid string) (bool, error)
 }
 
 type PipelineStore struct {
@@ -84,6 +88,48 @@ func getPipelineDbFolder(basePath string) string {
 	return filepath.Join(basePath, pipelineDbFolder)
 }
 
+func (ps *PipelineStore) IsLatestVersion(pipelineName string, version uint32, uid string) (bool, error) {
+	ps.mu.RLock()
+	defer ps.mu.RUnlock()
+
+	pipeline, ok := ps.pipelines[pipelineName]
+	if !ok {
+		return false, fmt.Errorf("pipeline %s not found", pipelineName)
+	}
+
+	latestVersion := pipeline.GetLatestPipelineVersion()
+	if latestVersion == nil {
+		return false, fmt.Errorf("pipeline %s has no latest version", pipelineName)
+	}
+
+	return latestVersion.Version == version && latestVersion.UID == uid, nil
+}
+
+func (ps *PipelineStore) GetPipelinesPipelineGwStatus(status PipelineStatus) []coordinator.PipelineEventMsg {
+	ps.mu.RLock()
+	defer ps.mu.RUnlock()
+
+	var events []coordinator.PipelineEventMsg
+	for _, p := range ps.pipelines {
+		pv := p.GetLatestPipelineVersion()
+		if pv == nil {
+			ps.logger.Warnf("Pipeline %s versions empty", p.Name)
+			continue
+		}
+
+		if pv.State.PipelineGwStatus != status {
+			continue
+		}
+
+		events = append(events, coordinator.PipelineEventMsg{
+			PipelineName:    pv.Name,
+			PipelineVersion: pv.Version,
+			UID:             pv.UID,
+		})
+	}
+	return events
+}
+
 func (ps *PipelineStore) InitialiseOrRestoreDB(path string, deletedResourceTTL uint) error {
 	logger := ps.logger.WithField("func", "initialiseDB")
 	pipelineDbPath := getPipelineDbFolder(path)
@@ -297,6 +343,7 @@ func (ps *PipelineStore) removePipelineImpl(name string) (*coordinator.PipelineE
 func (ps *PipelineStore) GetPipelineVersion(name string, versionNumber uint32, uid string) (*PipelineVersion, error) {
 	ps.mu.RLock()
 	defer ps.mu.RUnlock()
+
 	if pipeline, ok := ps.pipelines[name]; ok {
 		if pipelineVersion := pipeline.GetPipelineVersion(versionNumber); pipelineVersion != nil {
 			if pipelineVersion.UID == uid {
@@ -339,14 +386,23 @@ func (ps *PipelineStore) getAllRunningPipelineVersions(
 	var events []coordinator.PipelineEventMsg
 	for _, p := range ps.pipelines {
 		pv := p.GetLatestPipelineVersion()
-		switch statusSelector(pv) {
+		if pv == nil {
+			ps.logger.Warnf("Pipeline %s versions empty", p.Name)
+			continue
+		}
+
+		status := statusSelector(pv)
+		switch status {
 		// we consider PipelineTerminating as running as it is still active
-		case PipelineCreate, PipelineCreating, PipelineReady, PipelineRebalancing, PipelineTerminating:
+		// we also attempt to re-create failed pipelines, as they could have failed due to a temporary error such as Kafka being unavailable
+		case PipelineCreate, PipelineCreating, PipelineReady, PipelineRebalancing, PipelineTerminating, PipelineFailed:
 			events = append(events, coordinator.PipelineEventMsg{
 				PipelineName:    pv.Name,
 				PipelineVersion: pv.Version,
 				UID:             pv.UID,
 			})
+		default:
+			ps.logger.Debugf("Pipeline %s state %s not considered running", pv.Name, status)
 		}
 	}
 	return events
@@ -544,6 +600,7 @@ func (ps *PipelineStore) handleModelEvents(event coordinator.ModelEventMsg) {
 	go func() {
 		ps.modelStatusHandler.mu.RLock()
 		defer
ps.modelStatusHandler.mu.RUnlock() + refs := ps.modelStatusHandler.modelReferences[event.ModelName] if len(refs) > 0 { model, err := ps.modelStatusHandler.store.GetModel(event.ModelName) diff --git a/scheduler/pkg/store/pipeline/store_test.go b/scheduler/pkg/store/pipeline/store_test.go index 4d68d752ff..95e7b1d728 100644 --- a/scheduler/pkg/store/pipeline/store_test.go +++ b/scheduler/pkg/store/pipeline/store_test.go @@ -19,9 +19,324 @@ import ( "github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler" + "github.com/seldonio/seldon-core/scheduler/v2/pkg/coordinator" "github.com/seldonio/seldon-core/scheduler/v2/pkg/store" ) +func TestGetPipelinesPipelineGwStatus(t *testing.T) { + tests := []struct { + name string + pipelines map[string]*Pipeline + queryStatus PipelineStatus + expectedCount int + expectedNames []string + expectedVersions []uint32 + expectedUIDs []string + validate func(g *WithT, events []coordinator.PipelineEventMsg) + }{ + { + name: "empty pipelines map", + pipelines: make(map[string]*Pipeline), + queryStatus: PipelineReady, + expectedCount: 0, + }, + { + name: "no matching status", + pipelines: map[string]*Pipeline{ + "test-pipeline": { + Name: "test-pipeline", + LastVersion: 1, + Versions: []*PipelineVersion{ + { + Name: "test-pipeline", + Version: 1, + UID: "uid-1", + State: &PipelineState{ + Status: PipelineReady, + PipelineGwStatus: PipelineCreating, + }, + }, + }, + }, + }, + queryStatus: PipelineReady, + expectedCount: 0, + }, + { + name: "single matching pipeline", + pipelines: map[string]*Pipeline{ + "test-pipeline": { + Name: "test-pipeline", + LastVersion: 1, + Versions: []*PipelineVersion{ + { + Name: "test-pipeline", + Version: 1, + UID: "uid-1", + State: &PipelineState{ + Status: PipelineReady, + PipelineGwStatus: PipelineReady, + }, + }, + }, + }, + }, + queryStatus: PipelineReady, + expectedCount: 1, + expectedNames: []string{"test-pipeline"}, + expectedVersions: []uint32{1}, + expectedUIDs: []string{"uid-1"}, + }, + { + name: "multiple matching pipelines", + pipelines: map[string]*Pipeline{ + "pipeline-1": { + Name: "pipeline-1", + LastVersion: 1, + Versions: []*PipelineVersion{ + { + Name: "pipeline-1", + Version: 1, + UID: "uid-1", + State: &PipelineState{ + Status: PipelineReady, + PipelineGwStatus: PipelineCreating, + }, + }, + }, + }, + "pipeline-2": { + Name: "pipeline-2", + LastVersion: 1, + Versions: []*PipelineVersion{ + { + Name: "pipeline-2", + Version: 1, + UID: "uid-2", + State: &PipelineState{ + Status: PipelineReady, + PipelineGwStatus: PipelineCreating, + }, + }, + }, + }, + }, + queryStatus: PipelineCreating, + expectedCount: 2, + validate: func(g *WithT, events []coordinator.PipelineEventMsg) { + pipelineNames := []string{events[0].PipelineName, events[1].PipelineName} + g.Expect(pipelineNames).To(ContainElement("pipeline-1")) + g.Expect(pipelineNames).To(ContainElement("pipeline-2")) + }, + }, + { + name: "mixed statuses - only return matching", + pipelines: map[string]*Pipeline{ + "ready-pipeline": { + Name: "ready-pipeline", + LastVersion: 1, + Versions: []*PipelineVersion{ + { + Name: "ready-pipeline", + Version: 1, + UID: "uid-ready", + State: &PipelineState{ + Status: PipelineReady, + PipelineGwStatus: PipelineReady, + }, + }, + }, + }, + "creating-pipeline": { + Name: "creating-pipeline", + LastVersion: 1, + Versions: []*PipelineVersion{ + { + Name: "creating-pipeline", + Version: 1, + UID: "uid-creating", + State: &PipelineState{ + Status: PipelineCreating, + PipelineGwStatus: PipelineCreating, + }, + }, + }, + }, + 
"terminating-pipeline": { + Name: "terminating-pipeline", + LastVersion: 1, + Versions: []*PipelineVersion{ + { + Name: "terminating-pipeline", + Version: 1, + UID: "uid-terminating", + State: &PipelineState{ + Status: PipelineTerminating, + PipelineGwStatus: PipelineTerminating, + }, + }, + }, + }, + }, + queryStatus: PipelineReady, + expectedCount: 1, + expectedNames: []string{"ready-pipeline"}, + expectedVersions: []uint32{1}, + expectedUIDs: []string{"uid-ready"}, + }, + { + name: "multiple versions - return latest", + pipelines: map[string]*Pipeline{ + "test-pipeline": { + Name: "test-pipeline", + LastVersion: 3, + Versions: []*PipelineVersion{ + { + Name: "test-pipeline", + Version: 1, + UID: "uid-1", + State: &PipelineState{ + Status: PipelineTerminated, + PipelineGwStatus: PipelineTerminated, + }, + }, + { + Name: "test-pipeline", + Version: 2, + UID: "uid-2", + State: &PipelineState{ + Status: PipelineTerminated, + PipelineGwStatus: PipelineTerminated, + }, + }, + { + Name: "test-pipeline", + Version: 3, + UID: "uid-3", + State: &PipelineState{ + Status: PipelineReady, + PipelineGwStatus: PipelineReady, + }, + }, + }, + }, + }, + queryStatus: PipelineReady, + expectedCount: 1, + expectedNames: []string{"test-pipeline"}, + expectedVersions: []uint32{3}, + expectedUIDs: []string{"uid-3"}, + }, + { + name: "check PipelineGwStatus not Status", + pipelines: map[string]*Pipeline{ + "test-pipeline": { + Name: "test-pipeline", + LastVersion: 1, + Versions: []*PipelineVersion{ + { + Name: "test-pipeline", + Version: 1, + UID: "uid-1", + State: &PipelineState{ + Status: PipelineReady, + PipelineGwStatus: PipelineCreating, + }, + }, + }, + }, + }, + queryStatus: PipelineCreating, + expectedCount: 1, + expectedNames: []string{"test-pipeline"}, + expectedVersions: []uint32{1}, + expectedUIDs: []string{"uid-1"}, + }, + { + name: "pipeline with no versions", + pipelines: map[string]*Pipeline{ + "test-pipeline": { + Name: "test-pipeline", + LastVersion: 0, + Versions: []*PipelineVersion{}, + }, + }, + queryStatus: PipelineReady, + expectedCount: 0, + }, + { + name: "status matches but PipelineGwStatus doesn't", + pipelines: map[string]*Pipeline{ + "pipeline-match-both": { + Name: "pipeline-match-both", + LastVersion: 1, + Versions: []*PipelineVersion{ + { + Name: "pipeline-match-both", + Version: 1, + UID: "uid-match", + State: &PipelineState{ + Status: PipelineReady, + PipelineGwStatus: PipelineReady, + }, + }, + }, + }, + "pipeline-match-status-only": { + Name: "pipeline-match-status-only", + LastVersion: 1, + Versions: []*PipelineVersion{ + { + Name: "pipeline-match-status-only", + Version: 1, + UID: "uid-no-match", + State: &PipelineState{ + Status: PipelineReady, + PipelineGwStatus: PipelineCreating, + }, + }, + }, + }, + }, + queryStatus: PipelineReady, + expectedCount: 1, + expectedNames: []string{"pipeline-match-both"}, + expectedVersions: []uint32{1}, + expectedUIDs: []string{"uid-match"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewGomegaWithT(t) + + s := &PipelineStore{ + logger: logrus.New(), + pipelines: tt.pipelines, + } + + events := s.GetPipelinesPipelineGwStatus(tt.queryStatus) + + g.Expect(events).To(HaveLen(tt.expectedCount)) + + if tt.validate != nil { + tt.validate(g, events) + } else if tt.expectedCount > 0 { + // Default validation for single expected result + for i := 0; i < tt.expectedCount && i < len(tt.expectedNames); i++ { + g.Expect(events[i].PipelineName).To(Equal(tt.expectedNames[i])) + if len(tt.expectedVersions) > i { + 
g.Expect(events[i].PipelineVersion).To(Equal(tt.expectedVersions[i])) + } + if len(tt.expectedUIDs) > i { + g.Expect(events[i].UID).To(Equal(tt.expectedUIDs[i])) + } + } + } + }) + } +} + func TestAddPipeline(t *testing.T) { g := NewGomegaWithT(t) type test struct { diff --git a/scheduler/pkg/util/loadbalancer.go b/scheduler/pkg/util/loadbalancer.go index 6a7a39bcef..e1a0a19325 100644 --- a/scheduler/pkg/util/loadbalancer.go +++ b/scheduler/pkg/util/loadbalancer.go @@ -15,6 +15,7 @@ import ( "github.com/serialx/hashring" ) +//go:generate go tool mockgen -source=./loadbalancer.go -destination=./mock/loadbalancer.go -package=mock LoadBalancer type LoadBalancer interface { AddServer(serverName string) RemoveServer(serverName string) diff --git a/scheduler/pkg/util/mock/loadbalancer.go b/scheduler/pkg/util/mock/loadbalancer.go new file mode 100644 index 0000000000..c5e5e310a2 --- /dev/null +++ b/scheduler/pkg/util/mock/loadbalancer.go @@ -0,0 +1,98 @@ +/* +Copyright (c) 2024 Seldon Technologies Ltd. + +Use of this software is governed BY +(1) the license included in the LICENSE file or +(2) if the license included in the LICENSE file is the Business Source License 1.1, +the Change License after the Change Date as each is defined in accordance with the LICENSE file. +*/ + +// Code generated by MockGen. DO NOT EDIT. +// Source: ./loadbalancer.go +// +// Generated by this command: +// +// mockgen -source=./loadbalancer.go -destination=./mock/loadbalancer.go -package=mock LoadBalancer +// + +// Package mock is a generated GoMock package. +package mock + +import ( + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockLoadBalancer is a mock of LoadBalancer interface. +type MockLoadBalancer struct { + ctrl *gomock.Controller + recorder *MockLoadBalancerMockRecorder +} + +// MockLoadBalancerMockRecorder is the mock recorder for MockLoadBalancer. +type MockLoadBalancerMockRecorder struct { + mock *MockLoadBalancer +} + +// NewMockLoadBalancer creates a new mock instance. +func NewMockLoadBalancer(ctrl *gomock.Controller) *MockLoadBalancer { + mock := &MockLoadBalancer{ctrl: ctrl} + mock.recorder = &MockLoadBalancerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLoadBalancer) EXPECT() *MockLoadBalancerMockRecorder { + return m.recorder +} + +// AddServer mocks base method. +func (m *MockLoadBalancer) AddServer(serverName string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddServer", serverName) +} + +// AddServer indicates an expected call of AddServer. +func (mr *MockLoadBalancerMockRecorder) AddServer(serverName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddServer", reflect.TypeOf((*MockLoadBalancer)(nil).AddServer), serverName) +} + +// GetServersForKey mocks base method. +func (m *MockLoadBalancer) GetServersForKey(key string) []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServersForKey", key) + ret0, _ := ret[0].([]string) + return ret0 +} + +// GetServersForKey indicates an expected call of GetServersForKey. +func (mr *MockLoadBalancerMockRecorder) GetServersForKey(key any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServersForKey", reflect.TypeOf((*MockLoadBalancer)(nil).GetServersForKey), key) +} + +// RemoveServer mocks base method. 
+func (m *MockLoadBalancer) RemoveServer(serverName string) {
+	m.ctrl.T.Helper()
+	m.ctrl.Call(m, "RemoveServer", serverName)
+}
+
+// RemoveServer indicates an expected call of RemoveServer.
+func (mr *MockLoadBalancerMockRecorder) RemoveServer(serverName any) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveServer", reflect.TypeOf((*MockLoadBalancer)(nil).RemoveServer), serverName)
+}
+
+// UpdatePartitions mocks base method.
+func (m *MockLoadBalancer) UpdatePartitions(numPartitions int) {
+	m.ctrl.T.Helper()
+	m.ctrl.Call(m, "UpdatePartitions", numPartitions)
+}
+
+// UpdatePartitions indicates an expected call of UpdatePartitions.
+func (mr *MockLoadBalancerMockRecorder) UpdatePartitions(numPartitions any) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePartitions", reflect.TypeOf((*MockLoadBalancer)(nil).UpdatePartitions), numPartitions)
+}
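
Reviewer notes (not part of the diff). With this change, getAllRunningPipelineVersions also treats PipelineFailed as running, so failed pipelines keep being re-advertised rather than dropped, and the new GetPipelinesPipelineGwStatus and IsLatestVersion queries let the scheduler find failed latest versions and re-drive them on a timer. The scheduler wiring that consumes the new retry tick and max-retry settings is outside this section, so the following is only a minimal sketch of how such a loop could look; the package, names and the simplified pipelineEvent/pipelineStatus types are invented for the example and stand in for coordinator.PipelineEventMsg and pipeline.PipelineStatus.

// Package retrysketch is a reviewer illustration only; it is not part of this PR.
package retrysketch

import (
	"log"
	"time"
)

type pipelineStatus uint32

// pipelineFailed mirrors the PipelineFailed constant (iota value 4) from the diff.
const pipelineFailed pipelineStatus = 4

// pipelineEvent is a simplified stand-in for coordinator.PipelineEventMsg.
type pipelineEvent struct {
	PipelineName    string
	PipelineVersion uint32
	UID             string
}

// failedPipelineSource is the slice of the PipelineHandler interface this loop needs.
type failedPipelineSource interface {
	GetPipelinesPipelineGwStatus(status pipelineStatus) []pipelineEvent
	IsLatestVersion(name string, version uint32, uid string) (bool, error)
}

// retryFailedPipelines re-attempts creation of failed pipelines on every tick and
// gives up on a pipeline once maxRetries attempts have been made for it.
func retryFailedPipelines(
	store failedPipelineSource,
	tick time.Duration, // e.g. the value of the retry-creating tick flag
	maxRetries uint, // e.g. the value of the max-retry flag
	retry func(pipelineEvent) error,
	stop <-chan struct{},
) {
	attempts := map[string]uint{}
	ticker := time.NewTicker(tick)
	defer ticker.Stop()

	for {
		select {
		case <-stop:
			return
		case <-ticker.C:
			for _, ev := range store.GetPipelinesPipelineGwStatus(pipelineFailed) {
				// Only the latest version of a pipeline is retried; stale versions are skipped.
				if latest, err := store.IsLatestVersion(ev.PipelineName, ev.PipelineVersion, ev.UID); err != nil || !latest {
					continue
				}
				if attempts[ev.PipelineName] >= maxRetries {
					continue
				}
				attempts[ev.PipelineName]++
				if err := retry(ev); err != nil {
					log.Printf("retry of pipeline %s (attempt %d) failed: %v", ev.PipelineName, attempts[ev.PipelineName], err)
				}
			}
		}
	}
}

A separate loop with the deleting tick would do the same for PipelineFailedTerminating, calling the termination path instead of retry.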
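On the stringer change: the removed hand-written String() indexed a fixed nine-element array, so the new PipelineFailedTerminating value (iota 9) would have panicked with an index-out-of-range at runtime. The generated pipelinestatus_string.go stays in sync via go:generate and returns a bounded fallback for unknown values. A small self-contained illustration follows; the constants are copied from the diff and the output comments note which behaviour depends on the generated file being present.

package main

import "fmt"

type PipelineStatus uint32

//go:generate go tool stringer -type=PipelineStatus
const (
	PipelineStatusUnknown PipelineStatus = iota
	PipelineCreate
	PipelineCreating
	PipelineReady
	PipelineFailed
	PipelineTerminate
	PipelineTerminating
	PipelineTerminated
	PipelineRebalancing
	PipelineFailedTerminating
)

func main() {
	// With the generated pipelinestatus_string.go in the package, fmt uses the
	// Stringer implementation:
	fmt.Println(PipelineFailedTerminating) // "PipelineFailedTerminating"
	fmt.Println(PipelineStatus(42))        // "PipelineStatus(42)"
	// Without it (as this standalone snippet is written), the numeric values
	// 9 and 42 are printed instead.
}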
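The go:generate mockgen directives on PipelineHandler and LoadBalancer produce the gomock doubles shown above. A hedged usage sketch for the LoadBalancer mock follows; the test name, expected key, return value and the import path are assumptions based on the module layout seen elsewhere in this diff, not code from the PR. The MockPipelineHandler generated alongside it is used the same way, for example to expect SetPipelineGwPipelineState calls.

package example_test

import (
	"testing"

	"go.uber.org/mock/gomock"

	lbmock "github.com/seldonio/seldon-core/scheduler/v2/pkg/util/mock"
)

func TestUsesLoadBalancerMock(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	lb := lbmock.NewMockLoadBalancer(ctrl)
	// Expect the code under test to repartition and then look up servers for a pipeline key.
	lb.EXPECT().UpdatePartitions(2)
	lb.EXPECT().GetServersForKey("pipeline-1").Return([]string{"seldon-pipelinegateway-0"})

	// Exercised directly here; a real test would pass lb to the component under test.
	lb.UpdatePartitions(2)
	if servers := lb.GetServersForKey("pipeline-1"); len(servers) != 1 {
		t.Fatalf("expected one server, got %v", servers)
	}
}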