From 709db01947f5ff114678f3b5f3372d590f05c298 Mon Sep 17 00:00:00 2001 From: "yanhu.cheng" Date: Wed, 20 Aug 2025 20:14:03 +0400 Subject: [PATCH] add switch context tool --- pkg/kubernetes/configuration.go | 39 +++++++++++++++++++++++++++++++-- pkg/mcp/common_test.go | 2 +- pkg/mcp/configuration.go | 26 ++++++++++++++++++++-- pkg/mcp/configuration_test.go | 25 +++++++-------------- 4 files changed, 70 insertions(+), 22 deletions(-) diff --git a/pkg/kubernetes/configuration.go b/pkg/kubernetes/configuration.go index df88530f..a160f0db 100644 --- a/pkg/kubernetes/configuration.go +++ b/pkg/kubernetes/configuration.go @@ -106,10 +106,45 @@ func (m *Manager) ConfigurationView(minify bool) (runtime.Object, error) { return nil, err } } + + // Remove certificate-authority-data from clusters to reduce output size + newcfg := cfg.DeepCopy() + for _, cluster := range newcfg.Clusters { + if cluster != nil { + cluster.CertificateAuthorityData = nil + } + } + // Remove AuthInfos + newcfg.AuthInfos = nil + //nolint:staticcheck - if err = clientcmdapi.FlattenConfig(&cfg); err != nil { + if err = clientcmdapi.FlattenConfig(newcfg); err != nil { // ignore error //return "", err } - return latest.Scheme.ConvertToVersion(&cfg, latest.ExternalVersion) + return latest.Scheme.ConvertToVersion(newcfg, latest.ExternalVersion) +} + +// SwitchContext switches the current Kubernetes context to the specified context +func (m *Manager) SwitchContext(contextName string) error { + // Get the raw config + rawConfig, err := m.clientCmdConfig.RawConfig() + if err != nil { + return err + } + + // Check if the context exists + if _, exists := rawConfig.Contexts[contextName]; !exists { + return err + } + + // Update the current context + rawConfig.CurrentContext = contextName + + // Write the updated config back + if err := clientcmd.ModifyConfig(m.clientCmdConfig.ConfigAccess(), rawConfig, true); err != nil { + return err + } + + return nil } diff --git a/pkg/mcp/common_test.go b/pkg/mcp/common_test.go index 8e4e49d3..ebf432ab 100644 --- a/pkg/mcp/common_test.go +++ b/pkg/mcp/common_test.go @@ -213,7 +213,7 @@ func (c *mcpContext) withKubeConfig(rc *rest.Config) *api.Config { fakeConfig.Contexts["fake-context"].AuthInfo = "fake" fakeConfig.Contexts["additional-context"] = api.NewContext() fakeConfig.Contexts["additional-context"].Cluster = "additional-cluster" - fakeConfig.Contexts["additional-context"].AuthInfo = "additional-auth" + fakeConfig.Contexts["additional-context"].AuthInfo = "" fakeConfig.CurrentContext = "fake-context" kubeConfig := filepath.Join(c.tempDir, "config") _ = clientcmd.WriteToFile(*fakeConfig, kubeConfig) diff --git a/pkg/mcp/configuration.go b/pkg/mcp/configuration.go index 79ebaef4..54828d21 100644 --- a/pkg/mcp/configuration.go +++ b/pkg/mcp/configuration.go @@ -4,10 +4,9 @@ import ( "context" "fmt" + "github.com/containers/kubernetes-mcp-server/pkg/output" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" - - "github.com/containers/kubernetes-mcp-server/pkg/output" ) func (s *Server) initConfiguration() []server.ServerTool { @@ -24,6 +23,15 @@ func (s *Server) initConfiguration() []server.ServerTool { mcp.WithDestructiveHintAnnotation(false), mcp.WithOpenWorldHintAnnotation(true), ), Handler: s.configurationView}, + {Tool: mcp.NewTool("configuration_switch_context", + mcp.WithDescription("Switch the current Kubernetes context to a different context"), + mcp.WithString("context", mcp.Description("The name of the context to switch to. Use configuration_view to see available contexts.")), + // Tool annotations + mcp.WithTitleAnnotation("Configuration: Switch Context"), + mcp.WithReadOnlyHintAnnotation(false), + mcp.WithDestructiveHintAnnotation(false), + mcp.WithOpenWorldHintAnnotation(false), + ), Handler: s.configurationSwitchContext}, } return tools } @@ -44,3 +52,17 @@ func (s *Server) configurationView(_ context.Context, ctr mcp.CallToolRequest) ( } return NewTextResult(configurationYaml, err), nil } + +func (s *Server) configurationSwitchContext(_ context.Context, ctr mcp.CallToolRequest) (*mcp.CallToolResult, error) { + contextName, ok := ctr.GetArguments()["context"].(string) + if !ok || contextName == "" { + return NewTextResult("", fmt.Errorf("context parameter is required and must be a string")), nil + } + + err := s.k.SwitchContext(contextName) + if err != nil { + return NewTextResult("", fmt.Errorf("failed to switch context: %v", err)), nil + } + + return NewTextResult(fmt.Sprintf("Successfully switched to context: %s", contextName), nil), nil +} diff --git a/pkg/mcp/configuration_test.go b/pkg/mcp/configuration_test.go index 57fea486..a6771e11 100644 --- a/pkg/mcp/configuration_test.go +++ b/pkg/mcp/configuration_test.go @@ -55,11 +55,8 @@ func TestConfigurationView(t *testing.T) { } }) t.Run("configuration_view returns auth info", func(t *testing.T) { - if len(decoded.AuthInfos) != 1 { - t.Errorf("invalid auth info count, expected 1, got %v", len(decoded.AuthInfos)) - } - if decoded.AuthInfos[0].Name != "fake" { - t.Errorf("fake-auth not found: %v", decoded.AuthInfos) + if len(decoded.AuthInfos) != 0 { + t.Errorf("invalid auth info count, expected 0, got %v", len(decoded.AuthInfos)) } }) toolResult, err = c.callTool("configuration_view", map[string]interface{}{ @@ -86,8 +83,8 @@ func TestConfigurationView(t *testing.T) { if decoded.Contexts[0].Context.Cluster != "additional-cluster" { t.Errorf("additional-cluster not found: %v", decoded.Contexts) } - if decoded.Contexts[0].Context.AuthInfo != "additional-auth" { - t.Errorf("additional-auth not found: %v", decoded.Contexts) + if decoded.Contexts[0].Context.AuthInfo != "" { + t.Errorf("expected empty auth info for additional-context, got: %v", decoded.Contexts[0].Context.AuthInfo) } if decoded.Contexts[1].Name != "fake-context" { t.Errorf("fake-context not found: %v", decoded.Contexts) @@ -102,11 +99,8 @@ func TestConfigurationView(t *testing.T) { } }) t.Run("configuration_view with minified=false returns auth info", func(t *testing.T) { - if len(decoded.AuthInfos) != 2 { - t.Errorf("invalid auth info count, expected 2, got %v", len(decoded.AuthInfos)) - } - if decoded.AuthInfos[0].Name != "additional-auth" { - t.Errorf("additional-auth not found: %v", decoded.AuthInfos) + if len(decoded.AuthInfos) != 0 { + t.Errorf("invalid auth info count, expected 0, got %v", len(decoded.AuthInfos)) } }) }) @@ -167,11 +161,8 @@ func TestConfigurationViewInCluster(t *testing.T) { } }) t.Run("configuration_view returns auth info", func(t *testing.T) { - if len(decoded.AuthInfos) != 1 { - t.Fatalf("invalid auth info count, expected 1, got %v", len(decoded.AuthInfos)) - } - if decoded.AuthInfos[0].Name != "user" { - t.Fatalf("user not found: %v", decoded.AuthInfos) + if len(decoded.AuthInfos) != 0 { + t.Fatalf("invalid auth info count, expected 0, got %v", len(decoded.AuthInfos)) } }) })