Skip to content

Commit dfd419d

Browse files
authored
Add support for Tool Override to CLI. (#1841)
This change adds a new `--tools-override` option to `thv run` and similar commands. The new option accepts a path to a file containing a JSON structure like the following ```json { "toolsOverride": { "actual-name": { "name": "new-name", "description": "Overridden description." } } } ``` The override file is read once when the MCP server is started, and is not reloaded on `thv restart`. This behaviour might change in the future. Note that Tool Override feature is mostly orthogonal to Tool Filtering one, with the exception of the name used for filtering. Specifically, overrides just change the name or description of an existing tool, but do not limit access to other tools; filtering does the exact opposite. When both an override and a filter are specified for the same tool, the name to be used when specifying the filter is the overridden name. We decided to implement it this way for two reasons 1. the user is likely to think of tools in the MCP server in terms of the new names rather than the actual ones, since the former are those visible to the Client, and 2. the MCP spec does not specify means for MCP servers to publish the full list of tools without first running it, and determining those available requires starting the server itself and issuing a `tools/list` call, i.e. require acting as a client. As a consequence, it's impossible to reliably determine the full list of tools, and a valid filter in one scenario might not work in another one. Fixes #1511
1 parent b362c70 commit dfd419d

File tree

5 files changed

+302
-21
lines changed

5 files changed

+302
-21
lines changed

cmd/thv/app/run_flags.go

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99

1010
"github.com/stacklok/toolhive/pkg/auth"
1111
"github.com/stacklok/toolhive/pkg/authz"
12+
"github.com/stacklok/toolhive/pkg/cli"
1213
cfg "github.com/stacklok/toolhive/pkg/config"
1314
"github.com/stacklok/toolhive/pkg/container"
1415
"github.com/stacklok/toolhive/pkg/container/runtime"
@@ -89,6 +90,8 @@ type RunFlags struct {
8990

9091
// Tools filter
9192
ToolsFilter []string
93+
// Tools override file
94+
ToolsOverride string
9295

9396
// Configuration import
9497
FromConfig string
@@ -200,6 +203,12 @@ func AddRunFlags(cmd *cobra.Command, config *RunFlags) {
200203
nil,
201204
"Filter MCP server tools (comma-separated list of tool names)",
202205
)
206+
cmd.Flags().StringVar(
207+
&config.ToolsOverride,
208+
"tools-override",
209+
"",
210+
"Path to a JSON file containing overrides for MCP server tools names and descriptions",
211+
)
203212
cmd.Flags().StringVar(&config.FromConfig, "from-config", "", "Load configuration from exported file")
204213

205214
// Environment file processing flags
@@ -420,18 +429,31 @@ func buildRunnerConfig(
420429
}),
421430
}
422431

432+
var toolsOverride map[string]runner.ToolOverride
433+
if runFlags.ToolsOverride != "" {
434+
loadedToolsOverride, err := cli.LoadToolsOverride(runFlags.ToolsOverride)
435+
if err != nil {
436+
return nil, fmt.Errorf("failed to load tools override: %v", err)
437+
}
438+
toolsOverride = *loadedToolsOverride
439+
}
440+
441+
opts = append(opts, runner.WithToolsOverride(toolsOverride))
423442
// Configure middleware from flags
424-
opts = append(opts, runner.WithMiddlewareFromFlags(
425-
oidcConfig,
426-
runFlags.ToolsFilter,
427-
nil,
428-
telemetryConfig,
429-
runFlags.AuthzConfig,
430-
runFlags.EnableAudit,
431-
runFlags.AuditConfig,
432-
runFlags.Name,
433-
runFlags.Transport,
434-
))
443+
opts = append(
444+
opts,
445+
runner.WithMiddlewareFromFlags(
446+
oidcConfig,
447+
runFlags.ToolsFilter,
448+
toolsOverride,
449+
telemetryConfig,
450+
runFlags.AuthzConfig,
451+
runFlags.EnableAudit,
452+
runFlags.AuditConfig,
453+
runFlags.Name,
454+
runFlags.Transport,
455+
),
456+
)
435457

436458
if remoteServerMetadata, ok := serverMetadata.(*registry.RemoteServerMetadata); ok {
437459
if remoteAuthConfig := getRemoteAuthFromRemoteServerMetadata(remoteServerMetadata); remoteAuthConfig != nil {

docs/cli/thv_run.md

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/cli/tools_override.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Package cli provides utility functions specific to the
2+
// CLI that we want to test more thoroughly.
3+
package cli
4+
5+
import (
6+
"encoding/json"
7+
"errors"
8+
"fmt"
9+
"os"
10+
"path/filepath"
11+
12+
"github.com/stacklok/toolhive/pkg/runner"
13+
)
14+
15+
// ToolsOverrideJSON is a struct that represents the tools override JSON file.
16+
type toolsOverrideJSON struct {
17+
ToolsOverride map[string]runner.ToolOverride `json:"toolsOverride"`
18+
}
19+
20+
// LoadToolsOverride loads the tools override JSON file from the given path.
21+
func LoadToolsOverride(path string) (*map[string]runner.ToolOverride, error) {
22+
jsonFile, err := os.Open(filepath.Clean(path))
23+
if err != nil {
24+
return nil, fmt.Errorf("failed to open tools override file: %v", err)
25+
}
26+
defer jsonFile.Close()
27+
28+
var toolsOverride toolsOverrideJSON
29+
decoder := json.NewDecoder(jsonFile)
30+
err = decoder.Decode(&toolsOverride)
31+
if err != nil {
32+
return nil, fmt.Errorf("failed to decode tools override file: %v", err)
33+
}
34+
if toolsOverride.ToolsOverride == nil {
35+
return nil, errors.New("tools override are empty")
36+
}
37+
return &toolsOverride.ToolsOverride, nil
38+
}

pkg/cli/tools_override_test.go

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
package cli
2+
3+
import (
4+
"os"
5+
"path/filepath"
6+
"strings"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
11+
"github.com/stacklok/toolhive/pkg/runner"
12+
)
13+
14+
func TestLoadToolsOverride(t *testing.T) {
15+
t.Parallel()
16+
17+
tests := []struct {
18+
name string
19+
jsonContent string
20+
expectedResult *map[string]runner.ToolOverride
21+
expectError bool
22+
}{
23+
{
24+
name: "valid tools override with name and description",
25+
jsonContent: `{
26+
"toolsOverride": {
27+
"original_tool": {
28+
"name": "renamed_tool",
29+
"description": "A new description for the tool"
30+
}
31+
}
32+
}`,
33+
expectedResult: &map[string]runner.ToolOverride{
34+
"original_tool": {
35+
Name: "renamed_tool",
36+
Description: "A new description for the tool",
37+
},
38+
},
39+
expectError: false,
40+
},
41+
{
42+
name: "valid tools override with only name",
43+
jsonContent: `{
44+
"toolsOverride": {
45+
"tool1": {
46+
"name": "new_tool_name"
47+
}
48+
}
49+
}`,
50+
expectedResult: &map[string]runner.ToolOverride{
51+
"tool1": {
52+
Name: "new_tool_name",
53+
},
54+
},
55+
expectError: false,
56+
},
57+
{
58+
name: "valid tools override with only description",
59+
jsonContent: `{
60+
"toolsOverride": {
61+
"tool2": {
62+
"description": "Updated description only"
63+
}
64+
}
65+
}`,
66+
expectedResult: &map[string]runner.ToolOverride{
67+
"tool2": {
68+
Description: "Updated description only",
69+
},
70+
},
71+
expectError: false,
72+
},
73+
{
74+
name: "valid tools override with multiple tools",
75+
jsonContent: `{
76+
"toolsOverride": {
77+
"tool1": {
78+
"name": "renamed_tool1",
79+
"description": "Description for tool1"
80+
},
81+
"tool2": {
82+
"name": "renamed_tool2"
83+
},
84+
"tool3": {
85+
"description": "Description for tool3"
86+
}
87+
}
88+
}`,
89+
expectedResult: &map[string]runner.ToolOverride{
90+
"tool1": {
91+
Name: "renamed_tool1",
92+
Description: "Description for tool1",
93+
},
94+
"tool2": {
95+
Name: "renamed_tool2",
96+
},
97+
"tool3": {
98+
Description: "Description for tool3",
99+
},
100+
},
101+
expectError: false,
102+
},
103+
{
104+
name: "valid empty tools override",
105+
jsonContent: `{
106+
"toolsOverride": {}
107+
}`,
108+
expectedResult: &map[string]runner.ToolOverride{},
109+
expectError: false,
110+
},
111+
{
112+
name: "invalid JSON syntax",
113+
jsonContent: `{
114+
"toolsOverride": {
115+
"tool1": {
116+
"name": "invalid_json"
117+
}
118+
}
119+
`, // Missing closing brace
120+
expectedResult: nil,
121+
expectError: true,
122+
},
123+
{
124+
name: "missing toolsOverride field",
125+
jsonContent: `{
126+
"otherField": "value"
127+
}`,
128+
expectedResult: nil,
129+
expectError: true,
130+
},
131+
{
132+
name: "null toolsOverride field",
133+
jsonContent: `{
134+
"toolsOverride": null
135+
}`,
136+
expectedResult: nil,
137+
expectError: true,
138+
},
139+
{
140+
name: "empty file",
141+
jsonContent: ``,
142+
expectedResult: nil,
143+
expectError: true,
144+
},
145+
{
146+
name: "non-JSON content",
147+
jsonContent: `This is not JSON content`,
148+
expectedResult: nil,
149+
expectError: true,
150+
},
151+
}
152+
153+
for _, tt := range tests {
154+
t.Run(tt.name, func(t *testing.T) {
155+
t.Parallel()
156+
157+
// Create a temporary file
158+
tmpFile, err := os.CreateTemp("", "tools_override_test_*.json")
159+
if err != nil {
160+
t.Fatalf("Failed to create temp file: %v", err)
161+
}
162+
defer os.Remove(tmpFile.Name())
163+
164+
// Write test content to the file
165+
if tt.jsonContent != "" {
166+
_, err = tmpFile.WriteString(tt.jsonContent)
167+
if err != nil {
168+
t.Fatalf("Failed to write to temp file: %v", err)
169+
}
170+
}
171+
tmpFile.Close()
172+
173+
// Test the LoadToolsOverride function
174+
result, err := LoadToolsOverride(tmpFile.Name())
175+
176+
// Check error expectations
177+
if tt.expectError {
178+
assert.Error(t, err)
179+
assert.Nil(t, result)
180+
} else {
181+
assert.NoError(t, err)
182+
assert.NotNil(t, result)
183+
// Compare the results
184+
assert.Equal(t, tt.expectedResult, result)
185+
}
186+
})
187+
}
188+
}
189+
190+
func TestLoadToolsOverride_FileNotFound(t *testing.T) {
191+
t.Parallel()
192+
193+
// Test with non-existent file
194+
nonExistentFile := filepath.Join(os.TempDir(), "non_existent_file.json")
195+
196+
result, err := LoadToolsOverride(nonExistentFile)
197+
198+
if err == nil {
199+
t.Errorf("Expected error for non-existent file but got none")
200+
}
201+
202+
if result != nil {
203+
t.Errorf("Expected nil result for non-existent file but got: %+v", result)
204+
}
205+
206+
if !strings.Contains(err.Error(), "failed to open tools override file") {
207+
t.Errorf("Expected error to contain 'failed to open tools override file', but got: %v", err)
208+
}
209+
}

pkg/runner/config_builder.go

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ func addToolFilterMiddlewares(
475475
toolsFilter []string,
476476
toolsOverride map[string]ToolOverride,
477477
) []types.MiddlewareConfig {
478-
if len(toolsFilter) == 0 {
478+
if len(toolsFilter) == 0 && len(toolsOverride) == 0 {
479479
return middlewareConfigs
480480
}
481481

@@ -789,15 +789,6 @@ func (b *runConfigBuilder) validateConfig(imageMetadata *registry.ImageMetadata)
789789
}
790790
}
791791

792-
if c.ToolsFilter != nil && imageMetadata != nil && imageMetadata.Tools != nil {
793-
logger.Debugf("Using tools filter: %v", c.ToolsFilter)
794-
for _, tool := range c.ToolsFilter {
795-
if !slices.Contains(imageMetadata.Tools, tool) {
796-
return fmt.Errorf("tool %s not found in registry", tool)
797-
}
798-
}
799-
}
800-
801792
if c.ToolsOverride != nil && imageMetadata != nil && imageMetadata.Tools != nil {
802793
logger.Debugf("Using tools override: %v", c.ToolsOverride)
803794
for toolName := range c.ToolsOverride {
@@ -807,6 +798,26 @@ func (b *runConfigBuilder) validateConfig(imageMetadata *registry.ImageMetadata)
807798
}
808799
}
809800

801+
if c.ToolsFilter != nil && imageMetadata != nil && imageMetadata.Tools != nil {
802+
logger.Debugf("Using tools filter: %v", c.ToolsFilter)
803+
for _, tool := range c.ToolsFilter {
804+
name := tool
805+
806+
if c.ToolsOverride != nil {
807+
for actualName, toolOverride := range c.ToolsOverride {
808+
if toolOverride.Name == tool {
809+
name = actualName
810+
break
811+
}
812+
}
813+
}
814+
815+
if !slices.Contains(imageMetadata.Tools, name) {
816+
return fmt.Errorf("tool %s not found in registry", name)
817+
}
818+
}
819+
}
820+
810821
return nil
811822
}
812823

0 commit comments

Comments
 (0)