From 1eb159becfb65f44c9dab41e68a798547d37115a Mon Sep 17 00:00:00 2001 From: sviat stoliarenko Date: Wed, 16 Jul 2025 00:24:21 +0300 Subject: [PATCH] Introduce generator customisation via GeneratorOptions For https://github.com/modelcontextprotocol/go-sdk/issues/136 --- jsonschema/infer.go | 44 +++++++++++++++++---- jsonschema/infer_test.go | 85 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 7 deletions(-) diff --git a/jsonschema/infer.go b/jsonschema/infer.go index 654e6197..7efc0182 100644 --- a/jsonschema/infer.go +++ b/jsonschema/infer.go @@ -14,6 +14,23 @@ import ( "github.com/modelcontextprotocol/go-sdk/internal/util" ) +// GeneratorOptions contains options for the schema generator. +// It allows defining custom AdditionalProperties for a specific type. +// Also, SchemaRegistry can be used to provide pre-defined schemas for specific types (e.g., struct, interfaces) +type GeneratorOptions struct { + AdditionalProperties func(reflect.Type) *Schema // input is type name + SchemaRegistry map[reflect.Type]*Schema +} + +// defaultGeneratorOptions is the default set of options for the schema generator. +// Used by [For] function, +var defaultGeneratorOptions = GeneratorOptions{ + AdditionalProperties: func(t reflect.Type) *Schema { + return falseSchema() + }, + SchemaRegistry: make(map[reflect.Type]*Schema), +} + // For constructs a JSON schema object for the given type argument. // // It translates Go types into compatible JSON schema types, as follows: @@ -45,9 +62,20 @@ import ( // For future compatibility, descriptions must not start with "WORD=", where WORD is a // sequence of non-whitespace characters. func For[T any]() (*Schema, error) { + return CustomizedFor[T](defaultGeneratorOptions) +} + +// See [For] description for details. +// +// Main difference is that it allows customizing things like: +// - AdditionalProperties for a specific type +// - Pre-defined schemas for specific types (e.g., struct, interfaces) +// +// For more details, see [GeneratorOptions] documentation. +func CustomizedFor[T any](options GeneratorOptions) (*Schema, error) { // TODO: consider skipping incompatible fields, instead of failing. seen := make(map[reflect.Type]bool) - s, err := forType(reflect.TypeFor[T](), seen) + s, err := forType(reflect.TypeFor[T](), seen, options) if err != nil { var z T return nil, fmt.Errorf("For[%T](): %w", z, err) @@ -55,7 +83,7 @@ func For[T any]() (*Schema, error) { return s, nil } -func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) { +func forType(t reflect.Type, seen map[reflect.Type]bool, options GeneratorOptions) (*Schema, error) { // Follow pointers: the schema for *T is almost the same as for T, except that // an explicit JSON "null" is allowed for the pointer. allowNull := false @@ -92,6 +120,9 @@ func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) { s.Type = "number" case reflect.Interface: + if schema, ok := options.SchemaRegistry[t]; ok { + s = schema + } // Unrestricted case reflect.Map: @@ -99,14 +130,14 @@ func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) { return nil, fmt.Errorf("unsupported map key type %v", t.Key().Kind()) } s.Type = "object" - s.AdditionalProperties, err = forType(t.Elem(), seen) + s.AdditionalProperties, err = forType(t.Elem(), seen, options) if err != nil { return nil, fmt.Errorf("computing map value schema: %v", err) } case reflect.Slice, reflect.Array: s.Type = "array" - s.Items, err = forType(t.Elem(), seen) + s.Items, err = forType(t.Elem(), seen, options) if err != nil { return nil, fmt.Errorf("computing element schema: %v", err) } @@ -120,8 +151,7 @@ func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) { case reflect.Struct: s.Type = "object" - // no additional properties are allowed - s.AdditionalProperties = falseSchema() + s.AdditionalProperties = options.AdditionalProperties(t) for i := range t.NumField() { field := t.Field(i) @@ -132,7 +162,7 @@ func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) { if s.Properties == nil { s.Properties = make(map[string]*Schema) } - fs, err := forType(field.Type, seen) + fs, err := forType(field.Type, seen, options) if err != nil { return nil, err } diff --git a/jsonschema/infer_test.go b/jsonschema/infer_test.go index 106e5375..f307435f 100644 --- a/jsonschema/infer_test.go +++ b/jsonschema/infer_test.go @@ -5,6 +5,7 @@ package jsonschema_test import ( + "reflect" "strings" "testing" @@ -128,6 +129,90 @@ func TestFor(t *testing.T) { } } +func customizedForType[T any](options jsonschema.GeneratorOptions) *jsonschema.Schema { + s, err := jsonschema.CustomizedFor[T](options) + if err != nil { + panic(err) + } + return s +} + +func TestCustomizedFor(t *testing.T) { + type schema = jsonschema.Schema + + type S struct { + B int `jsonschema:"bdesc"` + } + sType := reflect.TypeOf((*S)(nil)).Elem() + + type CustomS interface { + X() string + } + customSSchema := schema{Type: "object", Properties: map[string]*schema{ + "X": {Type: "string", Description: "custom interface property"}, + }} + customSType := reflect.TypeOf((*CustomS)(nil)).Elem() + + genOptions := jsonschema.GeneratorOptions{ + AdditionalProperties: func(t reflect.Type) *jsonschema.Schema { + if t == sType { + return &schema{AnyOf: []*schema{ + {Type: "integer"}, + {Type: "string"}, + }} + } + return &schema{} + }, + SchemaRegistry: map[reflect.Type]*jsonschema.Schema{ + customSType: &customSSchema, + }, + } + + tests := []struct { + name string + got *jsonschema.Schema + want *jsonschema.Schema + }{ + { + "interface", + customizedForType[CustomS](genOptions), + &schema{ + Type: "object", + Properties: map[string]*schema{ + "X": {Type: "string", Description: "custom interface property"}, + }, + }, + }, + { + "customized struct", + customizedForType[S](genOptions), + &schema{ + Type: "object", + Properties: map[string]*schema{ + "B": {Type: "integer", Description: "bdesc"}, + }, + Required: []string{"B"}, + AdditionalProperties: &schema{AnyOf: []*schema{ + {Type: "integer"}, + {Type: "string"}, + }}, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if diff := cmp.Diff(test.want, test.got, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Fatalf("ForType mismatch (-want +got):\n%s", diff) + } + // These schemas should all resolve. + if _, err := test.got.Resolve(nil); err != nil { + t.Fatalf("Resolving: %v", err) + } + }) + } +} + func forErr[T any]() error { _, err := jsonschema.For[T]() return err