Skip to content

Commit 7774025

Browse files
authored
fix(go): allow the use of snake case fields in dotprompt templates (#3744)
1 parent a9a3755 commit 7774025

File tree

2 files changed

+83
-8
lines changed

2 files changed

+83
-8
lines changed

go/ai/prompt.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package ai
1616

1717
import (
1818
"context"
19+
"encoding/json"
1920
"errors"
2021
"fmt"
2122
"log/slog"
@@ -260,7 +261,16 @@ func buildVariables(variables any) (map[string]any, error) {
260261

261262
v := reflect.Indirect(reflect.ValueOf(variables))
262263
if v.Kind() == reflect.Map {
263-
return variables.(map[string]any), nil
264+
// ensure JSON tags are taken in consideration (allowing snake case fields)
265+
jsonData, err := json.Marshal(variables)
266+
if err != nil {
267+
return nil, fmt.Errorf("unable to marshal prompt field values: %w", err)
268+
}
269+
var resultVariables map[string]any
270+
if err := json.Unmarshal(jsonData, &resultVariables); err != nil {
271+
return nil, fmt.Errorf("unable to unmarshal prompt field values: %w", err)
272+
}
273+
return resultVariables, nil
264274
}
265275
if v.Kind() != reflect.Struct {
266276
return nil, errors.New("prompt.buildVariables: fields not a struct or pointer to a struct or a map")

go/ai/prompt_test.go

Lines changed: 72 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -904,7 +904,7 @@ output:
904904
---
905905
Hello, {{name}}!
906906
`
907-
err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0644)
907+
err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0o644)
908908
if err != nil {
909909
t.Fatalf("Failed to create mock prompt file: %v", err)
910910
}
@@ -941,6 +941,71 @@ Hello, {{name}}!
941941
}
942942
}
943943

944+
func TestLoadPromptSnakeCase(t *testing.T) {
945+
tempDir := t.TempDir()
946+
mockPromptFile := filepath.Join(tempDir, "snake.prompt")
947+
mockPromptContent := `---
948+
model: googleai/gemini-2.5-flash
949+
input:
950+
schema:
951+
items(array):
952+
teamColor: string
953+
team_name: string
954+
---
955+
{{#each items as |it|}}
956+
{{ it.teamColor }},{{ it.team_name }}
957+
{{/each}}
958+
`
959+
err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0o644)
960+
if err != nil {
961+
t.Fatalf("Failed to create mock prompt file: %v", err)
962+
}
963+
964+
reg := registry.New()
965+
LoadPrompt(reg, tempDir, "snake.prompt", "snake-namespace")
966+
967+
prompt := LookupPrompt(reg, "snake-namespace/snake")
968+
if prompt == nil {
969+
t.Fatalf("prompt was not registered")
970+
}
971+
972+
type SnakeInput struct {
973+
TeamColor string `json:"teamColor"` // intentionally leaving camel case to test snake + camel support
974+
TeamName string `json:"team_name"`
975+
}
976+
977+
input := map[string]any{"items": []SnakeInput{
978+
{TeamColor: "RED", TeamName: "Firebase"},
979+
{TeamColor: "BLUE", TeamName: "Gophers"},
980+
{TeamColor: "GREEN", TeamName: "Google"},
981+
}}
982+
983+
actionOpts, err := prompt.Render(context.Background(), input)
984+
if err != nil {
985+
t.Fatalf("error rendering prompt: %v", err)
986+
}
987+
if actionOpts.Messages == nil {
988+
t.Fatal("expecting messages to be rendered")
989+
}
990+
renderedPrompt := actionOpts.Messages[0].Text()
991+
for line := range strings.SplitSeq(renderedPrompt, "\n") {
992+
trimmedLine := strings.TrimSpace(line)
993+
if strings.HasPrefix(trimmedLine, "RED") {
994+
if !strings.Contains(trimmedLine, "Firebase") {
995+
t.Fatalf("wrong template render, want: RED,Firebase, got: %s", trimmedLine)
996+
}
997+
} else if strings.HasPrefix(trimmedLine, "BLUE") {
998+
if !strings.Contains(trimmedLine, "Gophers") {
999+
t.Fatalf("wrong template render, want: BLUE,Gophers, got: %s", trimmedLine)
1000+
}
1001+
} else if strings.HasPrefix(trimmedLine, "GREEN") {
1002+
if !strings.Contains(trimmedLine, "Google") {
1003+
t.Fatalf("wrong template render, want: GREEN,Google, got: %s", trimmedLine)
1004+
}
1005+
}
1006+
}
1007+
}
1008+
9441009
func TestLoadPrompt_FileNotFound(t *testing.T) {
9451010
// Initialize a mock registry
9461011
reg := registry.New()
@@ -962,7 +1027,7 @@ func TestLoadPrompt_InvalidPromptFile(t *testing.T) {
9621027
// Create an invalid .prompt file
9631028
invalidPromptFile := filepath.Join(tempDir, "invalid.prompt")
9641029
invalidPromptContent := `invalid json content`
965-
err := os.WriteFile(invalidPromptFile, []byte(invalidPromptContent), 0644)
1030+
err := os.WriteFile(invalidPromptFile, []byte(invalidPromptContent), 0o644)
9661031
if err != nil {
9671032
t.Fatalf("Failed to create invalid prompt file: %v", err)
9681033
}
@@ -993,7 +1058,7 @@ description: A test prompt
9931058
9941059
Hello, {{name}}!
9951060
`
996-
err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0644)
1061+
err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0o644)
9971062
if err != nil {
9981063
t.Fatalf("Failed to create mock prompt file: %v", err)
9991064
}
@@ -1018,7 +1083,7 @@ func TestLoadPromptFolder(t *testing.T) {
10181083
// Create mock prompt and partial files
10191084
mockPromptFile := filepath.Join(tempDir, "example.prompt")
10201085
mockSubDir := filepath.Join(tempDir, "subdir")
1021-
err := os.Mkdir(mockSubDir, 0755)
1086+
err := os.Mkdir(mockSubDir, 0o755)
10221087
if err != nil {
10231088
t.Fatalf("Failed to create subdirectory: %v", err)
10241089
}
@@ -1041,14 +1106,14 @@ output:
10411106
Hello, {{name}}!
10421107
`
10431108

1044-
err = os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0644)
1109+
err = os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0o644)
10451110
if err != nil {
10461111
t.Fatalf("Failed to create mock prompt file: %v", err)
10471112
}
10481113

10491114
// Create a mock prompt file in the subdirectory
10501115
mockSubPromptFile := filepath.Join(mockSubDir, "sub_example.prompt")
1051-
err = os.WriteFile(mockSubPromptFile, []byte(mockPromptContent), 0644)
1116+
err = os.WriteFile(mockSubPromptFile, []byte(mockPromptContent), 0o644)
10521117
if err != nil {
10531118
t.Fatalf("Failed to create mock prompt file in subdirectory: %v", err)
10541119
}
@@ -1131,7 +1196,7 @@ You are a pirate!
11311196
{{ role "user" }}
11321197
Hello!
11331198
`
1134-
err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0644)
1199+
err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0o644)
11351200
if err != nil {
11361201
t.Fatalf("Failed to create mock prompt file: %v", err)
11371202
}

0 commit comments

Comments
 (0)