From 4b2f037a9a263c71a594855da53cc9d428f1098c Mon Sep 17 00:00:00 2001 From: ha-m1-top-buddi Date: Sat, 5 Jul 2025 13:46:07 -0700 Subject: [PATCH 1/5] feat: adding streaming and tests --- client.go | 352 +++++++++++++++++++++++++++++++++- client_test.go | 501 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 848 insertions(+), 5 deletions(-) create mode 100644 client_test.go diff --git a/client.go b/client.go index 6550502..711d021 100644 --- a/client.go +++ b/client.go @@ -1,20 +1,25 @@ package bellman import ( + "bufio" "bytes" "context" "encoding/json" "errors" "fmt" - "github.com/modfin/bellman/models/embed" - "github.com/modfin/bellman/models/gen" - "github.com/modfin/bellman/prompt" - "github.com/modfin/bellman/tools" "io" "log/slog" + "net" "net/http" "net/url" "sync/atomic" + "time" + + "github.com/modfin/bellman/models" + "github.com/modfin/bellman/models/embed" + "github.com/modfin/bellman/models/gen" + "github.com/modfin/bellman/prompt" + "github.com/modfin/bellman/tools" ) const Provider = "Bellman" @@ -198,7 +203,344 @@ func (g *generator) SetRequest(request gen.Request) { } func (g *generator) Stream(conversation ...prompt.Prompt) (<-chan *gen.StreamResponse, error) { - return nil, errors.New("not implemented") + var reqc = atomic.AddInt64(&bellmanRequestNo, 1) + + // Check if streaming is supported + if !g.isStreamingSupported() { + return nil, fmt.Errorf("streaming is not supported for the current configuration") + } + + // Build streaming request with proper formatting + request, toolBelt, err := g.buildStreamingRequest(conversation) + if err != nil { + return nil, fmt.Errorf("could not build streaming request; %w", err) + } + + u, err := g.getStreamingEndpoint() + if err != nil { + return nil, fmt.Errorf("could not get streaming endpoint: %w", err) + } + + g.bellman.log("[gen] stream request", + "request", reqc, + "model", g.request.Model.FQN(), + "tools", len(g.request.Tools) > 0, + "tool_choice", g.request.ToolConfig != nil, + "output_schema", g.request.OutputSchema != nil, + "system_prompt", g.request.SystemPrompt != "", + "temperature", g.request.Temperature, + "top_p", g.request.TopP, + "max_tokens", g.request.MaxTokens, + "stop_sequences", g.request.StopSequences, + "stream", true, + ) + + body, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("could not marshal bellman request; %w", err) + } + + ctx := g.request.Context + if ctx == nil { + ctx = context.Background() + } + + req, err := http.NewRequestWithContext(ctx, "POST", u, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("could not create bellman request; %w", err) + } + + // Set streaming-specific headers + g.setStreamingHeaders(req) + + // Use optimized HTTP client for streaming + client := g.createStreamingHTTPClient() + res, err := client.Do(req) + if err != nil { + return nil, g.handleStreamingError(fmt.Errorf("could not post bellman request to %s; %w", u, err), reqc) + } + + if res.StatusCode != http.StatusOK { + b, readErr := io.ReadAll(res.Body) + res.Body.Close() + if readErr != nil { + return nil, g.handleStreamingError(fmt.Errorf("unexpected status code, %d, and failed to read response body: %w", res.StatusCode, readErr), reqc) + } + return nil, g.handleStreamingError(fmt.Errorf("unexpected status code, %d, err: {%s}", res.StatusCode, string(b)), reqc) + } + + reader := bufio.NewReader(res.Body) + stream := make(chan *gen.StreamResponse, 100) // Buffered channel for better performance + + go func() { + defer res.Body.Close() + defer 
close(stream) + + defer func() { + stream <- &gen.StreamResponse{ + Type: gen.TYPE_EOF, + } + }() + + // Handle context cancellation + ctx := g.request.Context + if ctx == nil { + ctx = context.Background() + } + + for { + // Check for context cancellation + select { + case <-ctx.Done(): + g.bellman.log("[gen] stream cancelled by context", "request", reqc, "error", ctx.Err()) + stream <- &gen.StreamResponse{ + Type: gen.TYPE_ERROR, + Content: fmt.Sprintf("stream cancelled: %v", ctx.Err()), + } + return + default: + // Continue processing + } + + line, _, err := reader.ReadLine() + if err != nil { + // If there's an error, check if it's EOF (end of stream) + if errors.Is(err, http.ErrBodyReadAfterClose) { + g.bellman.log("[gen] stream closed by server (Read after close)", "request", reqc) + break + } + if errors.Is(err, io.EOF) { + g.bellman.log("[gen] stream ended (EOF)", "request", reqc) + break + } + g.bellman.log("[gen] error reading from stream", "request", reqc, "error", err) + stream <- &gen.StreamResponse{ + Type: gen.TYPE_ERROR, + Content: fmt.Sprintf("error reading stream: %v", err), + } + break // Exit the loop on any other error + } + + if len(line) == 0 { + continue + } + if !bytes.HasPrefix(line, []byte("data: ")) { + stream <- &gen.StreamResponse{ + Type: gen.TYPE_ERROR, + Content: "expected 'data' header from sse", + } + break + } + line = line[6:] // removing header + + if bytes.Equal(line, []byte("[DONE]")) { + g.bellman.log("[gen] stream completed", "request", reqc) + break // Exit the loop on end of stream + } + + var streamResp gen.StreamResponse + err = json.Unmarshal(line, &streamResp) + if err != nil { + g.bellman.log("[gen] could not unmarshal stream chunk", "request", reqc, "error", err, "line", string(line)) + stream <- &gen.StreamResponse{ + Type: gen.TYPE_ERROR, + Content: fmt.Sprintf("could not unmarshal stream chunk: %v", err), + } + break + } + + // Process the streaming response + g.processStreamingResponse(&streamResp, toolBelt, reqc) + + // Send the response to the stream + select { + case stream <- &streamResp: + // Successfully sent + case <-ctx.Done(): + // Context was cancelled while trying to send + g.bellman.log("[gen] stream cancelled while sending response", "request", reqc, "error", ctx.Err()) + return + } + } + }() + + return stream, nil +} + +// buildStreamingRequest creates a properly formatted streaming request +func (g *generator) buildStreamingRequest(conversation []prompt.Prompt) (gen.FullRequest, map[string]*tools.Tool, error) { + request := gen.FullRequest{ + Request: g.request, + Prompts: conversation, + } + + // Ensure streaming is enabled + request.Stream = true + + // Validate request parameters for streaming + if err := g.validateStreamingRequest(&request); err != nil { + return request, nil, err + } + + // Build tool belt for tool call references + toolBelt := map[string]*tools.Tool{} + for _, tool := range g.request.Tools { + toolBelt[tool.Name] = &tool + } + + return request, toolBelt, nil +} + +// validateStreamingRequest validates request parameters for streaming +func (g *generator) validateStreamingRequest(request *gen.FullRequest) error { + if request.Model.Name == "" { + return fmt.Errorf("model is required for streaming request") + } + + // Validate that we have prompts + if len(request.Prompts) == 0 { + return fmt.Errorf("at least one prompt is required for streaming request") + } + + // Validate tool configuration if tools are present + if len(request.Tools) > 0 && request.ToolConfig != nil { + // Check if the specified 
tool exists + toolExists := false + for _, tool := range request.Tools { + if tool.Name == request.ToolConfig.Name { + toolExists = true + break + } + } + if !toolExists { + return fmt.Errorf("specified tool '%s' not found in available tools", request.ToolConfig.Name) + } + } + + return nil +} + +// setStreamingHeaders sets the appropriate headers for streaming requests +func (g *generator) setStreamingHeaders(req *http.Request) { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+g.bellman.key.String()) + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Connection", "keep-alive") + req.Header.Set("X-Requested-With", "XMLHttpRequest") // Helps with some proxy configurations +} + +// createStreamingHTTPClient creates an HTTP client optimized for streaming +func (g *generator) createStreamingHTTPClient() *http.Client { + // Use a longer timeout for streaming requests + transport := &http.Transport{ + DisableCompression: true, // Disable compression for streaming + DisableKeepAlives: false, // Keep connections alive for streaming + } + + return &http.Client{ + Transport: transport, + // No timeout for streaming - let context handle cancellation + } +} + +// isRetryableError checks if an error is retryable for streaming requests +func (g *generator) isRetryableError(err error) bool { + if err == nil { + return false + } + + // Check for network-related errors that might be retryable + if errors.Is(err, context.DeadlineExceeded) { + return true + } + + // Check for temporary network errors + var netErr *net.OpError + if errors.As(err, &netErr) { + return netErr.Temporary() + } + + return false +} + +// handleStreamingError handles streaming-specific errors +func (g *generator) handleStreamingError(err error, reqc int64) error { + if g.isRetryableError(err) { + g.bellman.log("[gen] retryable streaming error", "request", reqc, "error", err) + return fmt.Errorf("retryable streaming error: %w", err) + } + + g.bellman.log("[gen] streaming error", "request", reqc, "error", err) + return fmt.Errorf("streaming error: %w", err) +} + +// logStreamingMetrics logs streaming-specific metrics +func (g *generator) logStreamingMetrics(reqc int64, metadata *models.Metadata) { + if metadata != nil { + g.bellman.log("[gen] stream metrics", + "request", reqc, + "model", g.request.Model.FQN(), + "token-input", metadata.InputTokens, + "token-output", metadata.OutputTokens, + "token-total", metadata.TotalTokens, + ) + } +} + +// isStreamingSupported checks if streaming is supported for the current configuration +func (g *generator) isStreamingSupported() bool { + // Basic checks for streaming support + if g.request.Model.Name == "" { + return false + } + + // Check if the model supports streaming (this could be enhanced with model-specific checks) + // For now, assume all models support streaming + return true +} + +// getStreamingEndpoint returns the appropriate streaming endpoint URL +func (g *generator) getStreamingEndpoint() (string, error) { + return url.JoinPath(g.bellman.url, "gen", "stream") +} + +// processStreamingResponse processes a streaming response and adds necessary references +func (g *generator) processStreamingResponse(streamResp *gen.StreamResponse, toolBelt map[string]*tools.Tool, reqc int64) { + // Add tool references for tool calls + if streamResp.ToolCall != nil && streamResp.ToolCall.Ref == nil { + if tool, exists := toolBelt[streamResp.ToolCall.Name]; exists { + streamResp.ToolCall.Ref 
= tool
+		}
+	}
+
+	// Log metrics if metadata is present
+	if streamResp.Type == gen.TYPE_METADATA && streamResp.Metadata != nil {
+		g.logStreamingMetrics(reqc, streamResp.Metadata)
+	}
+}
+
+// StreamWithTimeout creates a streaming request with a timeout context
+func (g *generator) StreamWithTimeout(conversation []prompt.Prompt, timeout time.Duration) (<-chan *gen.StreamResponse, error) {
+	ctx := g.request.Context
+	if ctx == nil {
+		ctx = context.Background()
+	}
+
+	timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
+
+	// Create a new request with timeout context
+	timeoutRequest := g.request
+	timeoutRequest.Context = timeoutCtx
+
+	// Create a temporary generator with timeout context
+	tempGen := &generator{
+		bellman: g.bellman,
+		request: timeoutRequest,
+	}
+
+	stream, err := tempGen.Stream(conversation...)
+	if err != nil {
+		cancel()
+		return nil, err
+	}
+
+	// Deferring cancel here would cancel the context as soon as this function
+	// returns, killing the stream before any responses arrive. Instead, release
+	// the timeout context once the inner stream has been drained.
+	out := make(chan *gen.StreamResponse)
+	go func() {
+		defer cancel()
+		defer close(out)
+		for resp := range stream {
+			out <- resp
+		}
+	}()
+
+	return out, nil
 }
 
 func (g *generator) Prompt(conversation ...prompt.Prompt) (*gen.Response, error) {
diff --git a/client_test.go b/client_test.go
new file mode 100644
index 0000000..3000c60
--- /dev/null
+++ b/client_test.go
@@ -0,0 +1,501 @@
+package bellman
+
+import (
+	"context"
+	"encoding/base64"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/modfin/bellman/models/gen"
+	"github.com/modfin/bellman/prompt"
+	"github.com/modfin/bellman/schema"
+	"github.com/modfin/bellman/tools"
+)
+
+func TestStreamingClient_Validation(t *testing.T) {
+	client := New("http://localhost:8080", Key{Name: "test", Token: "test-token"})
+
+	tests := []struct {
+		name        string
+		model       gen.Model
+		prompts     []prompt.Prompt
+		expectError bool
+		errorMsg    string
+	}{
+		{
+			name:        "empty prompts should fail",
+			model:       gen.Model{Provider: "openai", Name: "gpt-3.5-turbo"},
+			prompts:     []prompt.Prompt{},
+			expectError: true,
+			errorMsg:    "at least one prompt is required",
+		},
+		{
+			name:        "valid request should pass validation",
+			model:       gen.Model{Provider: "openai", Name: "gpt-3.5-turbo"},
+			prompts:     []prompt.Prompt{{Role: prompt.UserRole, Text: "test"}},
+			expectError: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			generator := client.Generator(gen.WithModel(tt.model))
+			_, err := generator.Stream(tt.prompts...)
+
+			if tt.expectError {
+				if err == nil {
+					t.Errorf("expected error but got none")
+					return
+				}
+				if !strings.Contains(err.Error(), tt.errorMsg) {
+					t.Errorf("expected error to contain '%s', got '%s'", tt.errorMsg, err.Error())
+				}
+			} else {
+				// Validation passes here, but with no server on localhost:8080
+				// the request should still fail at the transport layer.
+				if err == nil {
+					t.Errorf("expected a connection error but got none")
+				}
+			}
+		})
+	}
+}
+
+func TestStreamingClient_RequestBuilding(t *testing.T) {
+	client := New("http://localhost:8080", Key{Name: "test", Token: "test-token"})
+	generator := client.Generator(
+		gen.WithModel(gen.Model{Provider: "openai", Name: "gpt-3.5-turbo"}),
+		gen.WithTemperature(0.7),
+		gen.WithMaxTokens(100),
+	)
+
+	prompts := []prompt.Prompt{
+		{Role: prompt.UserRole, Text: "Hello"},
+		{Role: prompt.AssistantRole, Text: "Hi there!"},
+	}
+
+	// Test that the request is built correctly
+	stream, err := generator.Stream(prompts...)
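+	// Nothing listens on localhost:8080 in this test, so validation and
+	// request building succeed while the transport call is expected to fail.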
+ if err == nil { + t.Errorf("expected error (server not running) but got none") + return + } + + // The error should indicate the server is not reachable + if !strings.Contains(err.Error(), "connection refused") && + !strings.Contains(err.Error(), "no such host") && + !strings.Contains(err.Error(), "unexpected status code") { + t.Errorf("expected connection error, got: %v", err) + } + + if stream != nil { + t.Errorf("expected nil stream when server is not available") + } +} + +func TestStreamingClient_SSEHandling(t *testing.T) { + // Create a test server that returns SSE data + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check that the request is for streaming + if !strings.Contains(r.URL.Path, "/gen/stream") { + t.Errorf("expected request to /gen/stream, got %s", r.URL.Path) + } + + // Check SSE headers + if r.Header.Get("Accept") != "text/event-stream" { + t.Errorf("expected Accept: text/event-stream, got %s", r.Header.Get("Accept")) + } + + // Set SSE headers + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + // Send SSE data + responses := []string{ + `{"type": "delta", "role": "assistant", "content": "Hello"}`, + `{"type": "delta", "role": "assistant", "content": " world"}`, + `{"type": "metadata", "metadata": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}}`, + `[DONE]`, + } + + for _, response := range responses { + fmt.Fprintf(w, "data: %s\n\n", response) + w.(http.Flusher).Flush() + } + })) + defer server.Close() + + // Create client pointing to test server + client := New(server.URL, Key{Name: "test", Token: "test-token"}) + generator := client.Generator(gen.WithModel(gen.Model{Provider: "openai", Name: "gpt-3.5-turbo"})) + + prompts := []prompt.Prompt{{Role: prompt.UserRole, Text: "Hello"}} + + stream, err := generator.Stream(prompts...) 
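+	// Stream's reader goroutine closes the channel after [DONE], and its
+	// deferred send appends a final TYPE_EOF response, so the drain below
+	// should yield three data events plus EOF.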
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	// Read from stream
+	var responses []*gen.StreamResponse
+	for response := range stream {
+		responses = append(responses, response)
+	}
+
+	// Verify responses; fail fast so the index checks below cannot panic
+	if len(responses) != 4 { // 3 data + 1 EOF
+		t.Fatalf("expected 4 responses, got %d", len(responses))
+	}
+
+	// Check first delta response
+	if responses[0].Type != gen.TYPE_DELTA {
+		t.Errorf("expected TYPE_DELTA, got %s", responses[0].Type)
+	}
+	if responses[0].Content != "Hello" {
+		t.Errorf("expected content 'Hello', got '%s'", responses[0].Content)
+	}
+
+	// Check second delta response
+	if responses[1].Type != gen.TYPE_DELTA {
+		t.Errorf("expected TYPE_DELTA, got %s", responses[1].Type)
+	}
+	if responses[1].Content != " world" {
+		t.Errorf("expected content ' world', got '%s'", responses[1].Content)
+	}
+
+	// Check metadata response
+	if responses[2].Type != gen.TYPE_METADATA {
+		t.Errorf("expected TYPE_METADATA, got %s", responses[2].Type)
+	}
+	if responses[2].Metadata == nil {
+		t.Errorf("expected metadata, got nil")
+	}
+
+	// Check EOF response
+	if responses[3].Type != gen.TYPE_EOF {
+		t.Errorf("expected TYPE_EOF, got %s", responses[3].Type)
+	}
+}
+
+func TestStreamingClient_ErrorHandling(t *testing.T) {
+	// Test server that returns an error
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(`{"error": "server error"}`))
+	}))
+	defer server.Close()
+
+	client := New(server.URL, Key{Name: "test", Token: "test-token"})
+	generator := client.Generator(gen.WithModel(gen.Model{Provider: "openai", Name: "gpt-3.5-turbo"}))
+
+	prompts := []prompt.Prompt{{Role: prompt.UserRole, Text: "Hello"}}
+
+	_, err := generator.Stream(prompts...)
+	if err == nil {
+		t.Fatalf("expected error but got none")
+	}
+
+	if !strings.Contains(err.Error(), "unexpected status code, 500") {
+		t.Errorf("expected 500 error, got: %v", err)
+	}
+}
+
+func TestStreamingClient_ContextCancellation(t *testing.T) {
+	// Test server that delays response
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		time.Sleep(2 * time.Second) // Delay to allow cancellation
+		w.Header().Set("Content-Type", "text/event-stream")
+		fmt.Fprintf(w, "data: {\"type\": \"delta\", \"content\": \"Hello\"}\n\n")
+	}))
+	defer server.Close()
+
+	client := New(server.URL, Key{Name: "test", Token: "test-token"})
+
+	// Create context with short timeout
+	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
+	defer cancel()
+
+	generator := client.Generator(
+		gen.WithModel(gen.Model{Provider: "openai", Name: "gpt-3.5-turbo"}),
+		gen.WithContext(ctx),
+	)
+
+	prompts := []prompt.Prompt{{Role: prompt.UserRole, Text: "Hello"}}
+
+	_, err := generator.Stream(prompts...)
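+	// The 500ms deadline fires while the server is still in its 2s sleep, so
+	// the initial POST should fail with a context error.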
+	if err == nil {
+		t.Fatalf("expected error for context cancellation, got none")
+	}
+
+	// Check that the error is related to context cancellation
+	if !strings.Contains(err.Error(), "context deadline exceeded") {
+		t.Errorf("expected context deadline error, got: %v", err)
+	}
+}
+
+func TestStreamingClient_ToolCallSupport(t *testing.T) {
+	// Create test server that returns tool calls
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+
+		// Send simple delta response first
+		fmt.Fprintf(w, "data: {\"type\": \"delta\", \"content\": \"I'll check the weather\"}\n\n")
+
+		// Send tool call response
+		arg := base64.StdEncoding.EncodeToString([]byte(`{"location":"New York"}`))
+		toolCall := fmt.Sprintf(`{"type":"delta","role":"assistant","content":"","tool_call":{"id":"call_123","name":"get_weather","argument":"%s"}}`, arg)
+
+		fmt.Fprintf(w, "data: %s\n\n", toolCall)
+		fmt.Fprintf(w, "data: [DONE]\n\n")
+	}))
+	defer server.Close()
+
+	// Create tool
+	weatherTool := tools.Tool{
+		Name:        "get_weather",
+		Description: "Get weather information",
+		ArgumentSchema: &schema.JSON{
+			Type: "object",
+			Properties: map[string]*schema.JSON{
+				"location": {Type: "string"},
+			},
+		},
+	}
+
+	client := New(server.URL, Key{Name: "test", Token: "test-token"})
+	generator := client.Generator(
+		gen.WithModel(gen.Model{Provider: "openai", Name: "gpt-3.5-turbo"}),
+		gen.WithTools(weatherTool),
+	)
+
+	prompts := []prompt.Prompt{{Role: prompt.UserRole, Text: "What's the weather?"}}
+
+	stream, err := generator.Stream(prompts...)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	// Read responses
+	var responses []*gen.StreamResponse
+	for response := range stream {
+		responses = append(responses, response)
+	}
+
+	// Check tool call response; fail fast so the index checks below cannot panic
+	if len(responses) < 3 {
+		t.Fatalf("expected at least 3 responses, got %d", len(responses))
+	}
+
+	// Check first response (should be delta)
+	firstResponse := responses[0]
+	if firstResponse.Type != gen.TYPE_DELTA {
+		t.Errorf("expected first response to be TYPE_DELTA, got %s", firstResponse.Type)
+	}
+
+	// Check second response (should be tool call)
+	toolResponse := responses[1]
+	t.Logf("Tool response type: %s", toolResponse.Type)
+	t.Logf("Tool response content: %s", toolResponse.Content)
+
+	if toolResponse.Type != gen.TYPE_DELTA {
+		t.Errorf("expected TYPE_DELTA, got %s", toolResponse.Type)
+	}
+
+	// Only assert that the tool call is present; dereferencing its fields
+	// here would panic if it were nil.
+	if toolResponse.ToolCall == nil {
+		t.Errorf("expected tool call, got nil")
+	}
+}
+
+func TestStreamingClient_InvalidSSEFormat(t *testing.T) {
+	// Test server that returns invalid SSE format
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+		w.Write([]byte("invalid sse format\n")) // No "data: " prefix
+	}))
+	defer server.Close()
+
+	client := New(server.URL, Key{Name: "test", Token: "test-token"})
+	generator := client.Generator(gen.WithModel(gen.Model{Provider: "openai", Name: "gpt-3.5-turbo"}))
+
+	prompts := []prompt.Prompt{{Role: prompt.UserRole, Text: "Hello"}}
+
+	stream, err := generator.Stream(prompts...)
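+	// The HTTP exchange itself succeeds with a 200, so the malformed payload
+	// surfaces as a TYPE_ERROR response on the stream rather than as an error
+	// from Stream itself.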
+ if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Read from stream - should get error response + response := <-stream + if response.Type != gen.TYPE_ERROR { + t.Errorf("expected TYPE_ERROR for invalid SSE, got %s", response.Type) + } + + if !strings.Contains(response.Content, "expected 'data' header from sse") { + t.Errorf("expected SSE format error, got: %s", response.Content) + } +} + +func TestStreamingClient_JSONParsingError(t *testing.T) { + // Test server that returns invalid JSON + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + fmt.Fprintf(w, "data: {invalid json}\n\n") // Invalid JSON + })) + defer server.Close() + + client := New(server.URL, Key{Name: "test", Token: "test-token"}) + generator := client.Generator(gen.WithModel(gen.Model{Provider: "openai", Name: "gpt-3.5-turbo"})) + + prompts := []prompt.Prompt{{Role: prompt.UserRole, Text: "Hello"}} + + stream, err := generator.Stream(prompts...) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Read from stream - should get error response + response := <-stream + if response.Type != gen.TYPE_ERROR { + t.Errorf("expected TYPE_ERROR for invalid JSON, got %s", response.Type) + } + + if !strings.Contains(response.Content, "could not unmarshal stream chunk") { + t.Errorf("expected JSON parsing error, got: %s", response.Content) + } +} + +func TestStreamingClient_RequestHeaders(t *testing.T) { + var receivedHeaders http.Header + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header + w.Header().Set("Content-Type", "text/event-stream") + fmt.Fprintf(w, "data: [DONE]\n\n") + })) + defer server.Close() + + client := New(server.URL, Key{Name: "test", Token: "test-token"}) + generator := client.Generator(gen.WithModel(gen.Model{Provider: "openai", Name: "gpt-3.5-turbo"})) + + prompts := []prompt.Prompt{{Role: prompt.UserRole, Text: "Hello"}} + + stream, err := generator.Stream(prompts...) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Read to completion + for range stream { + } + + // Check headers + expectedHeaders := map[string]string{ + "Content-Type": "application/json", + "Authorization": "Bearer test_test-token", + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + } + + for key, expectedValue := range expectedHeaders { + if receivedHeaders.Get(key) != expectedValue { + t.Errorf("expected header %s: %s, got: %s", key, expectedValue, receivedHeaders.Get(key)) + } + } +} + +func TestStreamingClient_RequestBody(t *testing.T) { + var requestBody []byte + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Read the request body + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("failed to read request body: %v", err) + } + requestBody = bodyBytes + + w.Header().Set("Content-Type", "text/event-stream") + fmt.Fprintf(w, "data: [DONE]\n\n") + })) + defer server.Close() + + client := New(server.URL, Key{Name: "test", Token: "test-token"}) + generator := client.Generator( + gen.WithModel(gen.Model{Provider: "openai", Name: "gpt-3.5-turbo"}), + gen.WithTemperature(0.7), + gen.WithMaxTokens(100), + ) + + prompts := []prompt.Prompt{{Role: prompt.UserRole, Text: "Hello"}} + + stream, err := generator.Stream(prompts...) 
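+	// Draining the stream below guarantees the handler has finished, and with
+	// it the capture of requestBody, before the body is inspected.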
+ if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Read to completion + for range stream { + } + + // Parse request body + var request gen.FullRequest + err = json.Unmarshal(requestBody, &request) + if err != nil { + t.Fatalf("failed to parse request body: %v", err) + } + + // Check that streaming is enabled + if !request.Stream { + t.Errorf("expected Stream to be true") + } + + // Check model + if request.Model.Name != "gpt-3.5-turbo" { + t.Errorf("expected model name 'gpt-3.5-turbo', got '%s'", request.Model.Name) + } + + // Check prompts + if len(request.Prompts) != 1 { + t.Errorf("expected 1 prompt, got %d", len(request.Prompts)) + } + + if request.Prompts[0].Text != "Hello" { + t.Errorf("expected prompt text 'Hello', got '%s'", request.Prompts[0].Text) + } +} + +func BenchmarkStreamingClient(b *testing.B) { + // Create a simple test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + for i := 0; i < 10; i++ { + fmt.Fprintf(w, "data: {\"type\": \"delta\", \"content\": \"chunk %d\"}\n\n", i) + } + fmt.Fprintf(w, "data: [DONE]\n\n") + })) + defer server.Close() + + client := New(server.URL, Key{Name: "test", Token: "test-token"}) + generator := client.Generator(gen.WithModel(gen.Model{Provider: "openai", Name: "gpt-3.5-turbo"})) + + prompts := []prompt.Prompt{{Role: prompt.UserRole, Text: "Hello"}} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + stream, err := generator.Stream(prompts...) + if err != nil { + b.Fatalf("unexpected error: %v", err) + } + + // Read all responses + for range stream { + } + } +} From ec704da6442d305308b84ba89febe7f468a6b0b8 Mon Sep 17 00:00:00 2001 From: ha-m1-top-buddi Date: Sat, 5 Jul 2025 13:47:38 -0700 Subject: [PATCH 2/5] chore: add tests --- .github/workflows/go.yml | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 .github/workflows/go.yml diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml new file mode 100644 index 0000000..838278a --- /dev/null +++ b/.github/workflows/go.yml @@ -0,0 +1,35 @@ +name: Go CI + +on: + push: + branches: [master] + pull_request: + branches: [master] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + + - name: Install dependencies + run: go mod download + + - name: Run tests + run: go test -v ./... 
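+      # Candidate hardening once the suite is verified race-clean:
+      #   run: go test -race -v ./...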
From cc4b432a2f33861fb590792cf36e1fca759fe1b2 Mon Sep 17 00:00:00 2001 From: ha-m1-top-buddi Date: Sat, 5 Jul 2025 13:49:44 -0700 Subject: [PATCH 3/5] chore: testing --- services/openai/llm.go | 11 ++++++----- services/vertexai/llm.go | 13 +++++++------ 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/services/openai/llm.go b/services/openai/llm.go index dbd77f0..0ad28cf 100644 --- a/services/openai/llm.go +++ b/services/openai/llm.go @@ -7,14 +7,15 @@ import ( "encoding/json" "errors" "fmt" - "github.com/modfin/bellman/models" - "github.com/modfin/bellman/models/gen" - "github.com/modfin/bellman/prompt" - "github.com/modfin/bellman/tools" "io" "log" "net/http" "sync/atomic" + + "github.com/modfin/bellman/models" + "github.com/modfin/bellman/models/gen" + "github.com/modfin/bellman/prompt" + "github.com/modfin/bellman/tools" ) var requestNo int64 @@ -111,7 +112,7 @@ func (g *generator) Stream(conversation ...prompt.Prompt) (<-chan *gen.StreamRes var ss openaiStreamResponse err = json.Unmarshal(line, &ss) if err != nil { - log.Printf("could not unmarshal chunk, %w", err) + log.Printf("could not unmarshal chunk, %v", err) break } diff --git a/services/vertexai/llm.go b/services/vertexai/llm.go index a7675dc..4b3b9f8 100644 --- a/services/vertexai/llm.go +++ b/services/vertexai/llm.go @@ -7,16 +7,17 @@ import ( "encoding/json" "errors" "fmt" - "github.com/modfin/bellman/models" - "github.com/modfin/bellman/models/gen" - "github.com/modfin/bellman/prompt" - "github.com/modfin/bellman/schema" - "github.com/modfin/bellman/tools" "io" "log" "net/http" "sync/atomic" "time" + + "github.com/modfin/bellman/models" + "github.com/modfin/bellman/models/gen" + "github.com/modfin/bellman/prompt" + "github.com/modfin/bellman/schema" + "github.com/modfin/bellman/tools" ) var requestNo int64 @@ -99,7 +100,7 @@ func (g *generator) Stream(prompts ...prompt.Prompt) (<-chan *gen.StreamResponse var ss geminiStreamingResponse err = json.Unmarshal(line, &ss) if err != nil { - log.Printf("could not unmarshal chunk, %w", err) + log.Printf("could not unmarshal chunk, %v", err) break } From d6259d0f3c0ad27587b7b401ee18f19d26e355e0 Mon Sep 17 00:00:00 2001 From: ha-m1-top-buddi Date: Sat, 5 Jul 2025 13:59:30 -0700 Subject: [PATCH 4/5] feat: adding stream --- bellmand/bellamnd.go | 163 +++++++++++++++++-- bellmand/streaming_test.go | 317 +++++++++++++++++++++++++++++++++++++ 2 files changed, 467 insertions(+), 13 deletions(-) create mode 100644 bellmand/streaming_test.go diff --git a/bellmand/bellamnd.go b/bellmand/bellamnd.go index 77cbd5d..2da492a 100644 --- a/bellmand/bellamnd.go +++ b/bellmand/bellamnd.go @@ -5,6 +5,19 @@ import ( "encoding/base64" "encoding/json" "fmt" + "io" + "log" + "log/slog" + "math/rand" + "net/http" + "net/url" + "os" + "os/signal" + "slices" + "strings" + "syscall" + "time" + "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/google/uuid" @@ -23,18 +36,6 @@ import ( "github.com/prometheus/client_golang/prometheus/push" slogchi "github.com/samber/slog-chi" "github.com/urfave/cli/v2" - "io" - "log" - "log/slog" - "math/rand" - "net/http" - "net/url" - "os" - "os/signal" - "slices" - "strings" - "syscall" - "time" ) var logger *slog.Logger @@ -454,7 +455,25 @@ func Gen(proxy *bellman.Proxy, cfg Config) func(r chi.Router) { }, []string{"model", "key", "type"}, ) - prometheus.MustRegister(reqCounter, tokensCounter) + + var streamReqCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "bellman_gen_stream_request_count", + Help: "Number of 
streaming requests per key",
+			ConstLabels: nil,
+		},
+		[]string{"model", "key"},
+	)
+
+	var streamTokensCounter = prometheus.NewCounterVec(
+		prometheus.CounterOpts{
+			Name:        "bellman_gen_stream_token_count",
+			Help:        "Number of tokens processed by model and key in streaming mode",
+			ConstLabels: nil,
+		},
+		[]string{"model", "key", "type"},
+	)
+	prometheus.MustRegister(reqCounter, tokensCounter, streamReqCounter, streamTokensCounter)
 
 	return func(r chi.Router) {
 		r.Use(auth(cfg))
@@ -513,6 +532,124 @@
 
 		})
 
+		// New streaming endpoint
+		r.Post("/stream", func(w http.ResponseWriter, r *http.Request) {
+			body, err := io.ReadAll(r.Body)
+			if err != nil {
+				err = fmt.Errorf("could not read request, %w", err)
+				httpErr(w, err, http.StatusBadRequest)
+				return
+			}
+
+			var req gen.FullRequest
+			err = json.Unmarshal(body, &req)
+			if err != nil {
+				err = fmt.Errorf("could not decode request, %w", err)
+				httpErr(w, err, http.StatusBadRequest)
+				return
+			}
+
+			// Force streaming mode
+			req.Stream = true
+
+			generator, err := proxy.Gen(req.Model)
+			if err != nil {
+				err = fmt.Errorf("could not get generator, %w", err)
+				httpErr(w, err, http.StatusInternalServerError)
+				return
+			}
+
+			generator = generator.SetConfig(req.Request).WithContext(r.Context())
+
+			// Get streaming response
+			stream, err := generator.Stream(req.Prompts...)
+			if err != nil {
+				logger.Error("gen stream request", "err", err)
+				err = fmt.Errorf("could not start streaming, %w", err)
+				httpErr(w, err, http.StatusInternalServerError)
+				return
+			}
+
+			keyName := r.Context().Value("api-key-name")
+			logger.Info("gen stream request",
+				"key", keyName,
+				"model", req.Model.FQN(),
+			)
+
+			// Set SSE headers
+			w.Header().Set("Content-Type", "text/event-stream")
+			w.Header().Set("Cache-Control", "no-cache")
+			w.Header().Set("Connection", "keep-alive")
+			w.Header().Set("Access-Control-Allow-Origin", "*")
+			w.Header().Set("Access-Control-Allow-Headers", "Cache-Control")
+
+			// Ensure the response is flushed immediately
+			if flusher, ok := w.(http.Flusher); ok {
+				flusher.Flush()
+			}
+
+			// Track metrics
+			var totalInputTokens, totalOutputTokens int
+			var modelName string
+
+			// Process streaming responses
+			for streamResp := range stream {
+				// Handle context cancellation
+				select {
+				case <-r.Context().Done():
+					logger.Info("gen stream cancelled", "key", keyName, "model", req.Model.FQN())
+					return
+				default:
+				}
+
+				// Update metrics
+				if streamResp.Metadata != nil {
+					totalInputTokens = streamResp.Metadata.InputTokens
+					totalOutputTokens = streamResp.Metadata.OutputTokens
+					modelName = streamResp.Metadata.Model
+				}
+
+				// Convert to SSE format
+				data, err := json.Marshal(streamResp)
+				if err != nil {
+					logger.Error("gen stream marshal error", "err", err)
+					continue
+				}
+
+				// Write SSE event
+				_, err = fmt.Fprintf(w, "data: %s\n\n", data)
+				if err != nil {
+					logger.Error("gen stream write error", "err", err)
+					break
+				}
+
+				// Flush the response
+				if flusher, ok := w.(http.Flusher); ok {
+					flusher.Flush()
+				}
+
+				// Check for end of stream
+				if streamResp.Type == gen.TYPE_EOF {
+					break
+				}
+			}
+
+			// Log final metrics
+			logger.Info("gen stream completed",
+				"key", keyName,
+				"model", req.Model.FQN(),
+				"token-input", totalInputTokens,
+				"token-output", totalOutputTokens,
+				"token-total", totalInputTokens+totalOutputTokens,
+			)
+
+			// Update metrics
+			streamReqCounter.WithLabelValues(modelName, keyName.(string)).Inc()
+			streamTokensCounter.WithLabelValues(modelName, keyName.(string), 
"total").Add(float64(totalInputTokens + totalOutputTokens)) + streamTokensCounter.WithLabelValues(modelName, keyName.(string), "input").Add(float64(totalInputTokens)) + streamTokensCounter.WithLabelValues(modelName, keyName.(string), "output").Add(float64(totalOutputTokens)) + }) + } } diff --git a/bellmand/streaming_test.go b/bellmand/streaming_test.go new file mode 100644 index 0000000..33ebe46 --- /dev/null +++ b/bellmand/streaming_test.go @@ -0,0 +1,317 @@ +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/modfin/bellman" + "github.com/modfin/bellman/models" + "github.com/modfin/bellman/models/gen" + "github.com/modfin/bellman/prompt" + "github.com/prometheus/client_golang/prometheus" +) + +// Mock generator that implements the required interface +type mockGen struct { + provider string +} + +func (m *mockGen) Provider() string { + return m.provider +} + +func (m *mockGen) Generator(options ...gen.Option) *gen.Generator { + gen := &gen.Generator{ + Prompter: &mockPrompter{}, + Request: gen.Request{}, + } + for _, op := range options { + gen = op(gen) + } + return gen +} + +// Mock prompter that implements the required interface +type mockPrompter struct { + request gen.Request +} + +func (m *mockPrompter) SetRequest(request gen.Request) { + m.request = request +} + +func (m *mockPrompter) Prompt(conversation ...prompt.Prompt) (*gen.Response, error) { + return nil, fmt.Errorf("not implemented") +} + +func (m *mockPrompter) Stream(conversation ...prompt.Prompt) (<-chan *gen.StreamResponse, error) { + stream := make(chan *gen.StreamResponse, 10) + + go func() { + defer close(stream) + + // Send a few streaming responses + stream <- &gen.StreamResponse{ + Type: gen.TYPE_DELTA, + Role: prompt.AssistantRole, + Content: "Hello", + } + + stream <- &gen.StreamResponse{ + Type: gen.TYPE_DELTA, + Role: prompt.AssistantRole, + Content: " world", + } + + stream <- &gen.StreamResponse{ + Type: gen.TYPE_METADATA, + Metadata: &models.Metadata{ + Model: "test-model", + InputTokens: 10, + OutputTokens: 5, + TotalTokens: 15, + }, + } + + stream <- &gen.StreamResponse{ + Type: gen.TYPE_EOF, + } + }() + + return stream, nil +} + +func TestStreamingEndpoint_Basic(t *testing.T) { + // Initialize logger for testing + logger = slog.Default() + + // Reset prometheus registry to avoid duplicate registration + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + // Create a test configuration + cfg := Config{ + ApiKeys: []string{"test-key"}, + HttpPort: 8080, + } + + // Create a real proxy and register a mock generator + proxy := bellman.NewProxy() + mockGen := &mockGen{provider: "test"} + proxy.RegisterGen(mockGen) + + // Create the router + r := chi.NewRouter() + r.Use(auth(cfg)) + r.Route("/gen", Gen(proxy, cfg)) + + // Create test request + request := gen.FullRequest{ + Request: gen.Request{ + Model: gen.Model{Provider: "test", Name: "test-model"}, + }, + Prompts: []prompt.Prompt{ + {Role: prompt.UserRole, Text: "Hello, how are you?"}, + }, + } + + body, _ := json.Marshal(request) + + // Create HTTP request + req := httptest.NewRequest("POST", "/gen/stream", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer test-key_test-key") + + // Create response recorder + w := httptest.NewRecorder() + + // Serve the request + r.ServeHTTP(w, req) + + // Check response status + if w.Code != http.StatusOK { + 
t.Errorf("expected status 200, got %d", w.Code) + } + + // Check SSE headers + if w.Header().Get("Content-Type") != "text/event-stream" { + t.Errorf("expected Content-Type text/event-stream, got %s", w.Header().Get("Content-Type")) + } + + if w.Header().Get("Cache-Control") != "no-cache" { + t.Errorf("expected Cache-Control no-cache, got %s", w.Header().Get("Cache-Control")) + } + + if w.Header().Get("Connection") != "keep-alive" { + t.Errorf("expected Connection keep-alive, got %s", w.Header().Get("Connection")) + } + + // Check response body contains SSE data + bodyStr := w.Body.String() + if !strings.Contains(bodyStr, "data: ") { + t.Errorf("expected SSE data format, got: %s", bodyStr) + } + + // Check that we have multiple SSE events + lines := strings.Split(bodyStr, "\n") + dataLines := 0 + for _, line := range lines { + if strings.HasPrefix(line, "data: ") { + dataLines++ + } + } + + if dataLines < 2 { + t.Errorf("expected at least 2 SSE data events, got %d", dataLines) + } + + // Verify the content of the SSE events + if !strings.Contains(bodyStr, `"type":"delta"`) { + t.Errorf("expected delta type in SSE data") + } + + if !strings.Contains(bodyStr, `"type":"metadata"`) { + t.Errorf("expected metadata type in SSE data") + } + + if !strings.Contains(bodyStr, `"type":"EOF"`) { + t.Errorf("expected EOF type in SSE data") + } +} + +func TestStreamingEndpoint_Authentication(t *testing.T) { + // Initialize logger for testing + logger = slog.Default() + + // Reset prometheus registry to avoid duplicate registration + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + // Create a test configuration + cfg := Config{ + ApiKeys: []string{"test-key"}, + HttpPort: 8080, + } + + // Create a real proxy + proxy := bellman.NewProxy() + + // Create the router + r := chi.NewRouter() + r.Use(auth(cfg)) + r.Route("/gen", Gen(proxy, cfg)) + + // Create test request without authentication + request := gen.FullRequest{ + Request: gen.Request{ + Model: gen.Model{Provider: "test", Name: "test-model"}, + }, + Prompts: []prompt.Prompt{ + {Role: prompt.UserRole, Text: "Hello"}, + }, + } + + body, _ := json.Marshal(request) + req := httptest.NewRequest("POST", "/gen/stream", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + // No Authorization header + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + // Should return 401 Unauthorized + if w.Code != http.StatusUnauthorized { + t.Errorf("expected status 401, got %d", w.Code) + } +} + +func TestStreamingEndpoint_InvalidRequest(t *testing.T) { + // Initialize logger for testing + logger = slog.Default() + + // Reset prometheus registry to avoid duplicate registration + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + // Create a test configuration + cfg := Config{ + ApiKeys: []string{"test-key"}, + HttpPort: 8080, + } + + // Create a real proxy + proxy := bellman.NewProxy() + + // Create the router + r := chi.NewRouter() + r.Use(auth(cfg)) + r.Route("/gen", Gen(proxy, cfg)) + + // Create invalid JSON request + req := httptest.NewRequest("POST", "/gen/stream", strings.NewReader("invalid json")) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer test-key_test-key") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + // Should return 400 Bad Request + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } +} + +func TestStreamingEndpoint_ProviderNotFound(t *testing.T) { + // Initialize logger for testing + logger = slog.Default() 
+ + // Reset prometheus registry to avoid duplicate registration + prometheus.DefaultRegisterer = prometheus.NewRegistry() + + // Create a test configuration + cfg := Config{ + ApiKeys: []string{"test-key"}, + HttpPort: 8080, + } + + // Create a real proxy without registering any providers + proxy := bellman.NewProxy() + + // Create the router + r := chi.NewRouter() + r.Use(auth(cfg)) + r.Route("/gen", Gen(proxy, cfg)) + + // Create test request with unknown provider + request := gen.FullRequest{ + Request: gen.Request{ + Model: gen.Model{Provider: "unknown", Name: "test-model"}, + }, + Prompts: []prompt.Prompt{ + {Role: prompt.UserRole, Text: "Hello"}, + }, + } + + body, _ := json.Marshal(request) + req := httptest.NewRequest("POST", "/gen/stream", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer test-key_test-key") + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + // Should return 500 Internal Server Error + if w.Code != http.StatusInternalServerError { + t.Errorf("expected status 500, got %d", w.Code) + } + + // Check error message + bodyStr := w.Body.String() + if !strings.Contains(bodyStr, "could not get generator") { + t.Errorf("expected error about generator, got: %s", bodyStr) + } +} From 69cbef4ec3ba04cde174e72852bbf15e47c81ef9 Mon Sep 17 00:00:00 2001 From: ha-m1-top-buddi Date: Sat, 5 Jul 2025 14:23:04 -0700 Subject: [PATCH 5/5] feat: more tests and models --- .gitignore | 1 - agent/agent_test.go | 846 +++++++++++++++++++++++++++++++++++ models/gen/generator_test.go | 232 ++++++++++ models/gen/model_test.go | 27 ++ models/gen/response_test.go | 102 +++++ prompt/prompt_test.go | 147 ++++++ 6 files changed, 1354 insertions(+), 1 deletion(-) create mode 100644 agent/agent_test.go create mode 100644 models/gen/generator_test.go create mode 100644 models/gen/model_test.go create mode 100644 models/gen/response_test.go create mode 100644 prompt/prompt_test.go diff --git a/.gitignore b/.gitignore index b5ca7d1..485dee6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1 @@ .idea -model_test.go diff --git a/agent/agent_test.go b/agent/agent_test.go new file mode 100644 index 0000000..eaf053f --- /dev/null +++ b/agent/agent_test.go @@ -0,0 +1,846 @@ +package agent + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/modfin/bellman/models" + "github.com/modfin/bellman/models/gen" + "github.com/modfin/bellman/prompt" + "github.com/modfin/bellman/schema" + "github.com/modfin/bellman/tools" +) + +// Mock prompter for testing +type mockPrompter struct { + request gen.Request + responses []*gen.Response + responseIndex int + shouldError bool + errorMessage string +} + +func (m *mockPrompter) SetRequest(request gen.Request) { + m.request = request +} + +func (m *mockPrompter) Prompt(prompts ...prompt.Prompt) (*gen.Response, error) { + if m.shouldError { + return nil, errors.New(m.errorMessage) + } + + if m.responseIndex >= len(m.responses) { + return nil, errors.New("no more responses available") + } + + response := m.responses[m.responseIndex] + m.responseIndex++ + return response, nil +} + +func (m *mockPrompter) Stream(prompts ...prompt.Prompt) (<-chan *gen.StreamResponse, error) { + return nil, errors.New("stream not implemented in mock") +} + +// Helper function to create a mock generator +func createMockGenerator(responses []*gen.Response) *gen.Generator { + return &gen.Generator{ + Prompter: &mockPrompter{responses: responses}, + Request: gen.Request{ + Context: context.Background(), + Model: 
gen.Model{ + Provider: "test", + Name: "test-model", + }, + }, + } +} + +// Helper function to create a text response +func createTextResponse(text string, inputTokens, outputTokens int) *gen.Response { + return &gen.Response{ + Texts: []string{text}, + Metadata: models.Metadata{ + Model: "test/test-model", + InputTokens: inputTokens, + OutputTokens: outputTokens, + TotalTokens: inputTokens + outputTokens, + }, + } +} + +// Helper function to create a tool response +func createToolResponse(toolCalls []tools.Call, inputTokens, outputTokens int) *gen.Response { + return &gen.Response{ + Tools: toolCalls, + Metadata: models.Metadata{ + Model: "test/test-model", + InputTokens: inputTokens, + OutputTokens: outputTokens, + TotalTokens: inputTokens + outputTokens, + }, + } +} + +// Helper function to create a tool call +func createToolCall(id, name string, argument []byte, ref *tools.Tool) tools.Call { + return tools.Call{ + ID: id, + Name: name, + Argument: argument, + Ref: ref, + } +} + +// Helper function to create a tool +func createTool(name string, description string, function tools.Function) tools.Tool { + return tools.Tool{ + Name: name, + Description: description, + Function: function, + } +} + +func TestRun_StringResult(t *testing.T) { + tests := []struct { + name string + maxDepth int + parallelism int + responses []*gen.Response + expectedResult string + expectedError bool + expectedDepth int + }{ + { + name: "successful string result on first try", + maxDepth: 3, + parallelism: 1, + responses: []*gen.Response{ + createTextResponse("Hello, World!", 10, 5), + }, + expectedResult: "Hello, World!", + expectedError: false, + expectedDepth: 0, + }, + { + name: "successful string result after tool calls", + maxDepth: 3, + parallelism: 1, + responses: []*gen.Response{ + createToolResponse([]tools.Call{ + createToolCall("call1", "test_tool", []byte(`{"arg": "value"}`), &tools.Tool{ + Name: "test_tool", + Function: func(ctx context.Context, call tools.Call) (string, error) { + return "tool result", nil + }, + }), + }, 10, 5), + createTextResponse("Final result", 15, 8), + }, + expectedResult: "Final result", + expectedError: false, + expectedDepth: 1, + }, + { + name: "max depth reached", + maxDepth: 1, + parallelism: 1, + responses: []*gen.Response{ + createToolResponse([]tools.Call{ + createToolCall("call1", "test_tool", []byte(`{"arg": "value"}`), &tools.Tool{ + Name: "test_tool", + Function: func(ctx context.Context, call tools.Call) (string, error) { + return "tool result", nil + }, + }), + }, 10, 5), + createToolResponse([]tools.Call{ + createToolCall("call2", "test_tool2", []byte(`{"arg": "value2"}`), &tools.Tool{ + Name: "test_tool2", + Function: func(ctx context.Context, call tools.Call) (string, error) { + return "tool result 2", nil + }, + }), + }, 15, 8), + }, + expectedError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := createMockGenerator(tt.responses) + + result, err := Run[string](tt.maxDepth, tt.parallelism, g, prompt.AsUser("test prompt")) + + if tt.expectedError { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if result.Result != tt.expectedResult { + t.Errorf("expected result %q, got %q", tt.expectedResult, result.Result) + } + + if result.Depth != tt.expectedDepth { + t.Errorf("expected depth %d, got %d", tt.expectedDepth, result.Depth) + } + + // Check metadata aggregation + expectedInputTokens := 0 + expectedOutputTokens 
:= 0
+			for _, resp := range tt.responses[:tt.expectedDepth+1] {
+				expectedInputTokens += resp.Metadata.InputTokens
+				expectedOutputTokens += resp.Metadata.OutputTokens
+			}
+
+			if result.Metadata.InputTokens != expectedInputTokens {
+				t.Errorf("expected input tokens %d, got %d", expectedInputTokens, result.Metadata.InputTokens)
+			}
+
+			if result.Metadata.OutputTokens != expectedOutputTokens {
+				t.Errorf("expected output tokens %d, got %d", expectedOutputTokens, result.Metadata.OutputTokens)
+			}
+		})
+	}
+}
+
+func TestRun_StructResult(t *testing.T) {
+	type TestStruct struct {
+		Message string `json:"message"`
+		Count   int    `json:"count"`
+	}
+
+	tests := []struct {
+		name           string
+		maxDepth       int
+		parallelism    int
+		responses      []*gen.Response
+		expectedResult TestStruct
+		expectedError  bool
+	}{
+		{
+			name:        "successful struct result",
+			maxDepth:    3,
+			parallelism: 1,
+			responses: []*gen.Response{
+				createTextResponse(`{"message": "Hello", "count": 42}`, 10, 5),
+			},
+			expectedResult: TestStruct{Message: "Hello", Count: 42},
+			expectedError:  false,
+		},
+		{
+			name:        "invalid JSON for struct",
+			maxDepth:    3,
+			parallelism: 1,
+			responses: []*gen.Response{
+				createTextResponse(`{"message": "Hello", "count": "not a number"}`, 10, 5),
+			},
+			expectedError: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			g := createMockGenerator(tt.responses)
+
+			result, err := Run[TestStruct](tt.maxDepth, tt.parallelism, g, prompt.AsUser("test prompt"))
+
+			if tt.expectedError {
+				if err == nil {
+					t.Errorf("expected error but got none")
+				}
+				return
+			}
+
+			if err != nil {
+				t.Errorf("unexpected error: %v", err)
+				return
+			}
+
+			if result.Result != tt.expectedResult {
+				t.Errorf("expected result %+v, got %+v", tt.expectedResult, result.Result)
+			}
+		})
+	}
+}
+
+func TestRun_ToolValidationErrors(t *testing.T) {
+	tests := []struct {
+		name                  string
+		maxDepth              int
+		parallelism           int
+		responses             []*gen.Response
+		expectedErrorContains string
+	}{
+		{
+			name:        "tool without ref",
+			maxDepth:    3,
+			parallelism: 1,
+			responses: []*gen.Response{
+				createToolResponse([]tools.Call{
+					createToolCall("call1", "test_tool", []byte(`{"arg": "value"}`), nil),
+				}, 10, 5),
+			},
+			expectedErrorContains: "tool test_tool not found in local setup",
+		},
+		{
+			name:        "tool without function",
+			maxDepth:    3,
+			parallelism: 1,
+			responses: []*gen.Response{
+				createToolResponse([]tools.Call{
+					createToolCall("call1", "test_tool", []byte(`{"arg": "value"}`), &tools.Tool{
+						Name: "test_tool",
+						// Function is nil
+					}),
+				}, 10, 5),
+			},
+			expectedErrorContains: "tool test_tool has no callback function attached",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			g := createMockGenerator(tt.responses)
+
+			_, err := Run[string](tt.maxDepth, tt.parallelism, g, prompt.AsUser("test prompt"))
+
+			if err == nil {
+				t.Errorf("expected error but got none")
+				return
+			}
+
+			// errors.Is against a freshly constructed error can never match, so
+			// a plain substring check is the meaningful assertion here.
+			if tt.expectedErrorContains != "" && !contains(err.Error(), tt.expectedErrorContains) {
+				t.Errorf("expected error to contain %q, got %q", tt.expectedErrorContains, err.Error())
+			}
+		})
+	}
+}
+
+func TestRun_ToolExecutionErrors(t *testing.T) {
+	tests := []struct {
+		name                  string
+		maxDepth              int
+		parallelism           int
+		responses             []*gen.Response
+		expectedErrorContains string
+	}{
+		{
+			name:        "tool execution fails",
+			maxDepth:    3,
+			parallelism: 1,
+			responses: []*gen.Response{
+				createToolResponse([]tools.Call{
+					createToolCall("call1", "test_tool", 
[]byte(`{"arg": "value"}`), &tools.Tool{ + Name: "test_tool", + Function: func(ctx context.Context, call tools.Call) (string, error) { + return "", errors.New("tool execution failed") + }, + }), + }, 10, 5), + }, + expectedErrorContains: "tool test_tool failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := createMockGenerator(tt.responses) + + _, err := Run[string](tt.maxDepth, tt.parallelism, g, prompt.AsUser("test prompt")) + + if err == nil { + t.Errorf("expected error but got none") + return + } + + if tt.expectedErrorContains != "" && !contains(err.Error(), tt.expectedErrorContains) { + t.Errorf("expected error to contain %q, got %q", tt.expectedErrorContains, err.Error()) + } + }) + } +} + +func TestRun_PromptError(t *testing.T) { + g := &gen.Generator{ + Prompter: &mockPrompter{ + shouldError: true, + errorMessage: "prompt failed", + }, + Request: gen.Request{ + Context: context.Background(), + Model: gen.Model{ + Provider: "test", + Name: "test-model", + }, + }, + } + + _, err := Run[string](3, 1, g, prompt.AsUser("test prompt")) + + if err == nil { + t.Errorf("expected error but got none") + return + } + + if !contains(err.Error(), "failed to prompt") { + t.Errorf("expected error to contain 'failed to prompt', got %q", err.Error()) + } +} + +func TestRunWithToolsOnly_StringResult(t *testing.T) { + tests := []struct { + name string + maxDepth int + parallelism int + responses []*gen.Response + expectedResult string + expectedError bool + expectedDepth int + }{ + { + name: "successful string result with custom tool", + maxDepth: 3, + parallelism: 1, + responses: []*gen.Response{ + createToolResponse([]tools.Call{ + createToolCall("call1", customResultCalculatedTool, []byte(`"Hello, World!"`), &tools.Tool{ + Name: customResultCalculatedTool, + Function: func(ctx context.Context, call tools.Call) (string, error) { + return "result", nil + }, + }), + }, 10, 5), + }, + expectedResult: "Hello, World!", + expectedError: false, + expectedDepth: 0, + }, + { + name: "successful string result after other tool calls", + maxDepth: 3, + parallelism: 1, + responses: []*gen.Response{ + createToolResponse([]tools.Call{ + createToolCall("call1", "test_tool", []byte(`{"arg": "value"}`), &tools.Tool{ + Name: "test_tool", + Function: func(ctx context.Context, call tools.Call) (string, error) { + return "tool result", nil + }, + }), + }, 10, 5), + createToolResponse([]tools.Call{ + createToolCall("call2", customResultCalculatedTool, []byte(`"Final result"`), &tools.Tool{ + Name: customResultCalculatedTool, + Function: func(ctx context.Context, call tools.Call) (string, error) { + return "result", nil + }, + }), + }, 15, 8), + }, + expectedResult: "Final result", + expectedError: false, + expectedDepth: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := createMockGenerator(tt.responses) + + result, err := RunWithToolsOnly[string](tt.maxDepth, tt.parallelism, g, prompt.AsUser("test prompt")) + + if tt.expectedError { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if result.Result != tt.expectedResult { + t.Errorf("expected result %q, got %q", tt.expectedResult, result.Result) + } + + if result.Depth != tt.expectedDepth { + t.Errorf("expected depth %d, got %d", tt.expectedDepth, result.Depth) + } + }) + } +} + +func TestRunWithToolsOnly_StructResult(t *testing.T) { + type TestStruct struct { + Message string `json:"message"` 
+ Count int `json:"count"` + } + + tests := []struct { + name string + maxDepth int + parallelism int + responses []*gen.Response + expectedResult TestStruct + expectedError bool + }{ + { + name: "successful struct result with custom tool", + maxDepth: 3, + parallelism: 1, + responses: []*gen.Response{ + createToolResponse([]tools.Call{ + createToolCall("call1", customResultCalculatedTool, []byte(`{"message": "Hello", "count": 42}`), &tools.Tool{ + Name: customResultCalculatedTool, + Function: func(ctx context.Context, call tools.Call) (string, error) { + return "result", nil + }, + }), + }, 10, 5), + }, + expectedResult: TestStruct{Message: "Hello", Count: 42}, + expectedError: false, + }, + { + name: "invalid JSON for struct in custom tool", + maxDepth: 3, + parallelism: 1, + responses: []*gen.Response{ + createToolResponse([]tools.Call{ + createToolCall("call1", customResultCalculatedTool, []byte(`{"message": "Hello", "count": "not a number"}`), &tools.Tool{ + Name: customResultCalculatedTool, + Function: func(ctx context.Context, call tools.Call) (string, error) { + return "result", nil + }, + }), + }, 10, 5), + }, + expectedError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := createMockGenerator(tt.responses) + + result, err := RunWithToolsOnly[TestStruct](tt.maxDepth, tt.parallelism, g, prompt.AsUser("test prompt")) + + if tt.expectedError { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if result.Result != tt.expectedResult { + t.Errorf("expected result %+v, got %+v", tt.expectedResult, result.Result) + } + }) + } +} + +func TestRunWithToolsOnly_RemovesCustomToolFromExistingTools(t *testing.T) { + // Create a generator with existing tools including the custom tool + existingTool := createTool("existing_tool", "existing tool", func(ctx context.Context, call tools.Call) (string, error) { + return "existing result", nil + }) + + customTool := createTool(customResultCalculatedTool, "custom tool", func(ctx context.Context, call tools.Call) (string, error) { + return "custom result", nil + }) + + g := &gen.Generator{ + Prompter: &mockPrompter{ + responses: []*gen.Response{ + createToolResponse([]tools.Call{ + createToolCall("call1", customResultCalculatedTool, []byte(`"result"`), &customTool), + }, 10, 5), + }, + }, + Request: gen.Request{ + Context: context.Background(), + Model: gen.Model{ + Provider: "test", + Name: "test-model", + }, + Tools: []tools.Tool{existingTool, customTool}, + }, + } + + result, err := RunWithToolsOnly[string](3, 1, g, prompt.AsUser("test prompt")) + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if result.Result != "result" { + t.Errorf("expected result 'result', got %q", result.Result) + } +} + +func TestExecuteCallbacksSequential(t *testing.T) { + ctx := context.Background() + + tool1 := createTool("tool1", "first tool", func(ctx context.Context, call tools.Call) (string, error) { + return "result1", nil + }) + + tool2 := createTool("tool2", "second tool", func(ctx context.Context, call tools.Call) (string, error) { + return "result2", nil + }) + + callbacks := []tools.Call{ + createToolCall("call1", "tool1", []byte(`{"arg": "value1"}`), &tool1), + createToolCall("call2", "tool2", []byte(`{"arg": "value2"}`), &tool2), + } + + results := executeCallbacksSequential(ctx, callbacks) + + if len(results) != 2 { + t.Errorf("expected 2 results, got %d", len(results)) + return + } + + if 
+func TestExecuteCallbacksSequential(t *testing.T) {
+	ctx := context.Background()
+
+	tool1 := createTool("tool1", "first tool", func(ctx context.Context, call tools.Call) (string, error) {
+		return "result1", nil
+	})
+
+	tool2 := createTool("tool2", "second tool", func(ctx context.Context, call tools.Call) (string, error) {
+		return "result2", nil
+	})
+
+	callbacks := []tools.Call{
+		createToolCall("call1", "tool1", []byte(`{"arg": "value1"}`), &tool1),
+		createToolCall("call2", "tool2", []byte(`{"arg": "value2"}`), &tool2),
+	}
+
+	results := executeCallbacksSequential(ctx, callbacks)
+
+	if len(results) != 2 {
+		t.Errorf("expected 2 results, got %d", len(results))
+		return
+	}
+
+	if results[0].Response != "result1" || results[0].Error != nil {
+		t.Errorf("expected first result to be 'result1' with no error, got %q, %v", results[0].Response, results[0].Error)
+	}
+
+	if results[1].Response != "result2" || results[1].Error != nil {
+		t.Errorf("expected second result to be 'result2' with no error, got %q, %v", results[1].Response, results[1].Error)
+	}
+}
+
+func TestExecuteCallbacksSequential_WithError(t *testing.T) {
+	ctx := context.Background()
+
+	tool1 := createTool("tool1", "first tool", func(ctx context.Context, call tools.Call) (string, error) {
+		return "result1", nil
+	})
+
+	tool2 := createTool("tool2", "second tool", func(ctx context.Context, call tools.Call) (string, error) {
+		return "", errors.New("tool2 failed")
+	})
+
+	callbacks := []tools.Call{
+		createToolCall("call1", "tool1", []byte(`{"arg": "value1"}`), &tool1),
+		createToolCall("call2", "tool2", []byte(`{"arg": "value2"}`), &tool2),
+	}
+
+	results := executeCallbacksSequential(ctx, callbacks)
+
+	if len(results) != 2 {
+		t.Errorf("expected 2 results, got %d", len(results))
+		return
+	}
+
+	if results[0].Response != "result1" || results[0].Error != nil {
+		t.Errorf("expected first result to be 'result1' with no error, got %q, %v", results[0].Response, results[0].Error)
+	}
+
+	if results[1].Response != "" || results[1].Error == nil {
+		t.Errorf("expected second result to have error, got %q, %v", results[1].Response, results[1].Error)
+	}
+}
+
+func TestExecuteCallbacksParallel(t *testing.T) {
+	ctx := context.Background()
+
+	tool1 := createTool("tool1", "first tool", func(ctx context.Context, call tools.Call) (string, error) {
+		time.Sleep(10 * time.Millisecond) // Simulate some work
+		return "result1", nil
+	})
+
+	tool2 := createTool("tool2", "second tool", func(ctx context.Context, call tools.Call) (string, error) {
+		time.Sleep(10 * time.Millisecond) // Simulate some work
+		return "result2", nil
+	})
+
+	callbacks := []tools.Call{
+		createToolCall("call1", "tool1", []byte(`{"arg": "value1"}`), &tool1),
+		createToolCall("call2", "tool2", []byte(`{"arg": "value2"}`), &tool2),
+	}
+
+	results := executeCallbacksParallel(ctx, callbacks, 2)
+
+	if len(results) != 2 {
+		t.Errorf("expected 2 results, got %d", len(results))
+		return
+	}
+
+	// Check that both results are present (order may vary due to parallelism)
+	foundResult1 := false
+	foundResult2 := false
+
+	for _, result := range results {
+		if result.Response == "result1" && result.Error == nil {
+			foundResult1 = true
+		}
+		if result.Response == "result2" && result.Error == nil {
+			foundResult2 = true
+		}
+	}
+
+	if !foundResult1 {
+		t.Errorf("did not find result1 in parallel execution results")
+	}
+
+	if !foundResult2 {
+		t.Errorf("did not find result2 in parallel execution results")
+	}
+}
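+
+// With a concurrency limit of 1, the parallel executor is expected to degrade
+// to sequential execution, so three 50ms tools must take at least 150ms total.
+// Bounding the wall-clock time only from below keeps the check portable: it
+// cannot flake on slow CI machines, only on a broken limiter.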
createToolCall("call1", "tool1", []byte(`{"arg": "value1"}`), &tool1), + createToolCall("call2", "tool2", []byte(`{"arg": "value2"}`), &tool2), + createToolCall("call3", "tool3", []byte(`{"arg": "value3"}`), &tool3), + } + + start := time.Now() + results := executeCallbacksParallel(ctx, callbacks, 1) // Limit to 1 concurrent execution + duration := time.Since(start) + + if len(results) != 3 { + t.Errorf("expected 3 results, got %d", len(results)) + return + } + + // With parallelism=1, execution should be sequential, so it should take at least 150ms + if duration < 150*time.Millisecond { + t.Errorf("expected sequential execution to take at least 150ms, took %v", duration) + } + + // Check that all results are present + expectedResults := map[string]bool{"result1": false, "result2": false, "result3": false} + for _, result := range results { + if result.Error == nil { + expectedResults[result.Response] = true + } + } + + for result, found := range expectedResults { + if !found { + t.Errorf("did not find %s in parallel execution results", result) + } + } +} + +func TestResult_StringRepresentation(t *testing.T) { + result := &Result[string]{ + Prompts: []prompt.Prompt{ + prompt.AsUser("test prompt"), + }, + Result: "test result", + Metadata: models.Metadata{ + Model: "test/test-model", + InputTokens: 10, + OutputTokens: 5, + TotalTokens: 15, + }, + Depth: 1, + } + + // Test that the result can be accessed + if result.Result != "test result" { + t.Errorf("expected result 'test result', got %q", result.Result) + } + + if result.Depth != 1 { + t.Errorf("expected depth 1, got %d", result.Depth) + } + + if len(result.Prompts) != 1 { + t.Errorf("expected 1 prompt, got %d", len(result.Prompts)) + } + + if result.Metadata.InputTokens != 10 { + t.Errorf("expected 10 input tokens, got %d", result.Metadata.InputTokens) + } +} + +func TestRun_WithOutputSchema(t *testing.T) { + type TestStruct struct { + Message string `json:"message"` + } + + g := &gen.Generator{ + Prompter: &mockPrompter{ + responses: []*gen.Response{ + createTextResponse(`{"message": "Hello"}`, 10, 5), + }, + }, + Request: gen.Request{ + Context: context.Background(), + Model: gen.Model{ + Provider: "test", + Name: "test-model", + }, + OutputSchema: schema.From(TestStruct{}), + }, + } + + result, err := Run[TestStruct](3, 1, g, prompt.AsUser("test prompt")) + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if result.Result.Message != "Hello" { + t.Errorf("expected message 'Hello', got %q", result.Result.Message) + } +} + +// Helper function to check if a string contains a substring +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(substr) == 0 || + (len(s) > len(substr) && (s[:len(substr)] == substr || + s[len(s)-len(substr):] == substr || + func() bool { + for i := 1; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false + }()))) +} diff --git a/models/gen/generator_test.go b/models/gen/generator_test.go new file mode 100644 index 0000000..b8a156e --- /dev/null +++ b/models/gen/generator_test.go @@ -0,0 +1,232 @@ +package gen + +import ( + "context" + "reflect" + "testing" + + "github.com/modfin/bellman/prompt" + "github.com/modfin/bellman/schema" + "github.com/modfin/bellman/tools" +) + +// fakePrompter implements the Prompter interface for testing. 
+
+// fakePrompter implements the Prompter interface for testing.
+type fakePrompter struct {
+	lastRequest Request
+	prompts     []prompt.Prompt
+	retResp     *Response
+	retChan     chan *StreamResponse
+	retErr      error
+}
+
+func (f *fakePrompter) SetRequest(r Request) {
+	f.lastRequest = r
+}
+func (f *fakePrompter) Prompt(prompts ...prompt.Prompt) (*Response, error) {
+	f.prompts = prompts
+	return f.retResp, f.retErr
+}
+func (f *fakePrompter) Stream(prompts ...prompt.Prompt) (<-chan *StreamResponse, error) {
+	f.prompts = prompts
+	if f.retChan == nil {
+		return nil, f.retErr
+	}
+	return f.retChan, f.retErr
+}
+
+func TestFloatInt(t *testing.T) {
+	f := Float(3.14)
+	if f == nil || *f != 3.14 {
+		t.Errorf("Float(3.14) = %v, want pointer to 3.14", f)
+	}
+	i := Int(42)
+	if i == nil || *i != 42 {
+		t.Errorf("Int(42) = %v, want pointer to 42", i)
+	}
+}
+
+func TestSetConfigAndClone(t *testing.T) {
+	orig := &Generator{Request: Request{SystemPrompt: "a"}}
+	req := Request{SystemPrompt: "b"}
+	newG := orig.SetConfig(req)
+	if newG.Request.SystemPrompt != "b" {
+		t.Errorf("SetConfig didn't set SystemPrompt, got %q", newG.Request.SystemPrompt)
+	}
+	if orig.Request.SystemPrompt != "a" {
+		t.Errorf("SetConfig mutated original, orig SystemPrompt = %q", orig.Request.SystemPrompt)
+	}
+	if newG == orig {
+		t.Error("SetConfig returned same Generator pointer, want new one")
+	}
+}
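+
+// clone is expected to deep-copy every pointer and slice field of the Request
+// while sharing the Context, so a derived generator cannot mutate its parent's
+// configuration. Each field is verified individually below.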
+func TestCloneDeepCopy(t *testing.T) {
+	trueVal := true
+	tbVal := 5
+	fVal := 1.23
+	iVal := 7
+
+	origSchema := &schema.JSON{Description: "desc"}
+	ctx := context.WithValue(context.Background(), "key", "val")
+	toolA := tools.NewTool("A")
+	toolB := tools.NewTool("B")
+	original := &Generator{
+		Request: Request{
+			Context:          ctx,
+			Stream:           true,
+			Model:            Model{Provider: "p", Name: "n"},
+			SystemPrompt:     "sys",
+			OutputSchema:     origSchema,
+			StrictOutput:     true,
+			Tools:            []tools.Tool{toolA},
+			ToolConfig:       &toolB,
+			ThinkingBudget:   &tbVal,
+			ThinkingParts:    &trueVal,
+			TopP:             &fVal,
+			TopK:             &iVal,
+			Temperature:      &fVal,
+			MaxTokens:        &iVal,
+			FrequencyPenalty: &fVal,
+			PresencePenalty:  &fVal,
+			StopSequences:    []string{"s1", "s2"},
+		},
+	}
+	cloned := original.clone()
+	if cloned == original {
+		t.Fatal("clone returned same pointer")
+	}
+	// context is shallow copied
+	if cloned.Request.Context != original.Request.Context {
+		t.Error("Context not copied correctly")
+	}
+	// OutputSchema deep copy
+	if cloned.Request.OutputSchema == original.Request.OutputSchema {
+		t.Error("OutputSchema pointer not deep-copied")
+	}
+	// Note: OutputSchema contains maps which cannot be compared with reflect.DeepEqual
+	// We just verify the pointer is different
+
+	// ToolConfig deep copy
+	if cloned.Request.ToolConfig == original.Request.ToolConfig {
+		t.Error("ToolConfig pointer not deep-copied")
+	}
+	// Note: ToolConfig contains functions which cannot be compared with reflect.DeepEqual
+	// We just verify the pointer is different
+	// Tools slice deep copy
+	if &cloned.Request.Tools[0] == &original.Request.Tools[0] {
+		t.Error("Tools slice not deep-copied")
+	}
+	if !reflect.DeepEqual(cloned.Request.Tools, original.Request.Tools) {
+		t.Error("Tools content mismatch")
+	}
+	// Pointer fields deep copy
+	if cloned.Request.PresencePenalty == original.Request.PresencePenalty {
+		t.Error("PresencePenalty pointer not deep-copied")
+	}
+	if *cloned.Request.PresencePenalty != *original.Request.PresencePenalty {
+		t.Error("PresencePenalty value mismatch")
+	}
+	if cloned.Request.FrequencyPenalty == original.Request.FrequencyPenalty {
+		t.Error("FrequencyPenalty pointer not deep-copied")
+	}
+	if *cloned.Request.FrequencyPenalty != *original.Request.FrequencyPenalty {
+		t.Error("FrequencyPenalty value mismatch")
+	}
+	if cloned.Request.Temperature == original.Request.Temperature {
+		t.Error("Temperature pointer not deep-copied")
+	}
+	if *cloned.Request.Temperature != *original.Request.Temperature {
+		t.Error("Temperature value mismatch")
+	}
+	if cloned.Request.TopP == original.Request.TopP {
+		t.Error("TopP pointer not deep-copied")
+	}
+	if *cloned.Request.TopP != *original.Request.TopP {
+		t.Error("TopP value mismatch")
+	}
+	if cloned.Request.TopK == original.Request.TopK {
+		t.Error("TopK pointer not deep-copied")
+	}
+	if *cloned.Request.TopK != *original.Request.TopK {
+		t.Error("TopK value mismatch")
+	}
+	if cloned.Request.MaxTokens == original.Request.MaxTokens {
+		t.Error("MaxTokens pointer not deep-copied")
+	}
+	if *cloned.Request.MaxTokens != *original.Request.MaxTokens {
+		t.Error("MaxTokens value mismatch")
+	}
+	if cloned.Request.ThinkingBudget == original.Request.ThinkingBudget {
+		t.Error("ThinkingBudget pointer not deep-copied")
+	}
+	if *cloned.Request.ThinkingBudget != *original.Request.ThinkingBudget {
+		t.Error("ThinkingBudget value mismatch")
+	}
+	if cloned.Request.ThinkingParts == original.Request.ThinkingParts {
+		t.Error("ThinkingParts pointer not deep-copied")
+	}
+	if *cloned.Request.ThinkingParts != *original.Request.ThinkingParts {
+		t.Error("ThinkingParts value mismatch")
+	}
+	// StopSequences deep copy
+	if &cloned.Request.StopSequences[0] == &original.Request.StopSequences[0] {
+		t.Error("StopSequences slice not deep-copied")
+	}
+	if !reflect.DeepEqual(cloned.Request.StopSequences, original.Request.StopSequences) {
+		t.Error("StopSequences content mismatch")
+	}
+}
+
+func TestGeneratorPromptAndStream(t *testing.T) {
+	// Prompt error when no prompter
+	g := &Generator{}
+	_, err := g.Prompt(prompt.AsUser("hi"))
+	if err == nil || err.Error() != "prompter is required" {
+		t.Errorf("expected prompter required error, got %v", err)
+	}
+	// Stream error when no prompter
+	_, err = g.Stream(prompt.AsUser("hi"))
+	if err == nil || err.Error() != "prompter is required" {
+		t.Errorf("expected prompter required error, got %v", err)
+	}
+	// Prompt success
+	fp := &fakePrompter{retResp: &Response{Texts: []string{"ok"}}}
+	g2 := &Generator{Prompter: fp}
+	resp, err := g2.Prompt(prompt.AsAssistant("hello"))
+	if err != nil {
+		t.Errorf("Prompt unexpected error: %v", err)
+	}
+	if resp != fp.retResp {
+		t.Error("Prompt did not return expected response")
+	}
+	if len(fp.prompts) != 1 || fp.prompts[0].Text != "hello" {
+		t.Errorf("Prompt did not receive prompts, got %v", fp.prompts)
+	}
+	if !reflect.DeepEqual(fp.lastRequest, g2.Request) {
+		t.Errorf("Prompt did not receive request, got %v want %v", fp.lastRequest, g2.Request)
+	}
+	// Stream success
+	ch := make(chan *StreamResponse, 1)
+	ch <- &StreamResponse{Type: TYPE_DELTA, Content: "abc"}
+	close(ch)
+	fp2 := &fakePrompter{retChan: ch}
+	g3 := &Generator{Prompter: fp2}
+	outCh, err := g3.Stream(prompt.AsUser("x"))
+	if err != nil {
+		t.Errorf("Stream unexpected error: %v", err)
+	}
+	select {
+	case r, ok := <-outCh:
+		if !ok || r.Content != "abc" {
+			t.Errorf("Stream channel content mismatch, got %v", r)
+		}
+	default:
+		t.Error("Stream channel empty")
+	}
+	if len(fp2.prompts) != 1 || fp2.prompts[0].Text != "x" {
+		t.Errorf("Stream did not receive prompts, got %v", fp2.prompts)
+	}
+	if !fp2.lastRequest.Stream {
+		t.Error("Stream request not set with Stream=true")
+	}
+}
diff --git a/models/gen/model_test.go b/models/gen/model_test.go
new file mode 100644
index 0000000..badc593
--- /dev/null
+++ b/models/gen/model_test.go
@@ -0,0 +1,27 @@
+package gen
+
+import "testing"
+
+func TestModelStringFQN(t *testing.T) {
+	m := Model{Provider: "prov", Name: "nm"}
+	want := "prov/nm"
+	if s := m.String(); s != want {
+		t.Errorf("String() = %q, want %q", s, want)
+	}
+	if f := m.FQN(); f != want {
+		t.Errorf("FQN() = %q, want %q", f, want)
+	}
+}
+
+func TestToModel(t *testing.T) {
+	m, err := ToModel("p/n")
+	if err != nil {
+		t.Fatalf("ToModel unexpected error: %v", err)
+	}
+	if m.Provider != "p" || m.Name != "n" {
+		t.Errorf("ToModel = %+v, want Provider=\"p\", Name=\"n\"", m)
+	}
+	if _, err := ToModel("invalid"); err == nil {
+		t.Error("ToModel expected error for invalid input, got nil")
+	}
+}
\ No newline at end of file
diff --git a/models/gen/response_test.go b/models/gen/response_test.go
new file mode 100644
index 0000000..1e4fd3e
--- /dev/null
+++ b/models/gen/response_test.go
@@ -0,0 +1,102 @@
+package gen
+
+import (
+	"context"
+	"errors"
+	"testing"
+
+	"github.com/modfin/bellman/models"
+	"github.com/modfin/bellman/tools"
+)
+
+func TestStreamResponseError(t *testing.T) {
+	sr := StreamResponse{Type: TYPE_ERROR, Content: "oops"}
+	err := sr.Error()
+	want := "streaming response error: oops"
+	if err == nil || err.Error() != want {
+		t.Errorf("Error() = %v, want %q", err, want)
+	}
+	sr2 := StreamResponse{Type: TYPE_DELTA, Content: "data"}
+	if e := sr2.Error(); e != nil {
+		t.Errorf("Error() for non-error type = %v, want nil", e)
+	}
+}
+
+func TestResponseTextAndTools(t *testing.T) {
+	r := &Response{Texts: []string{"t1"}}
+	if !r.IsText() {
+		t.Error("IsText() = false, want true")
+	}
+	if r.IsTools() {
+		t.Error("IsTools() = true, want false")
+	}
+	text, err := r.AsText()
+	if err != nil || text != "t1" {
+		t.Errorf("AsText() = %q, %v, want %q, nil", text, err, "t1")
+	}
+	if _, err := r.AsTools(); err == nil {
+		t.Error("AsTools() = nil, want error")
+	}
+
+	// Tools only
+	tool := tools.Tool{Name: "x", Function: func(context.Context, tools.Call) (string, error) { return "", nil }}
+	r2 := &Response{Tools: []tools.Call{{Name: "x", Ref: &tool}}}
+	if !r2.IsTools() {
+		t.Error("IsTools() = false, want true")
+	}
+	if r2.IsText() {
+		t.Error("IsText() = true, want false")
+	}
+	calls, err := r2.AsTools()
+	if err != nil {
+		t.Errorf("AsTools() unexpected error: %v", err)
+	}
+	if len(calls) != 1 || calls[0].Name != "x" {
+		t.Errorf("AsTools() = %v, want call x", calls)
+	}
+}
+
+func TestResponseUnmarshal(t *testing.T) {
+	var obj struct{ Foo string }
+	r := &Response{Texts: []string{`{"Foo":"bar"}`}}
+	if err := r.Unmarshal(&obj); err != nil {
+		t.Errorf("Unmarshal() error: %v", err)
+	}
+	if obj.Foo != "bar" {
+		t.Errorf("Unmarshal wrote %+v, want Foo=\"bar\"", obj)
+	}
+	r2 := &Response{Texts: []string{"invalid"}}
+	if err := r2.Unmarshal(&obj); err == nil {
+		t.Error("Unmarshal() = nil, want error for invalid JSON")
+	}
+}
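+
+// Eval, as exercised here, must reject calls whose Ref or Ref.Function is
+// missing, surface callback errors, and succeed only when every tool call can
+// be executed.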
+func TestResponseEval(t *testing.T) {
+	// missing Ref
+	r := &Response{Tools: []tools.Call{{Name: "t1"}}}
+	if err := r.Eval(context.Background()); err == nil {
+		t.Error("Eval() = nil, want error for missing Ref")
+	}
+	// missing Function
+	tool2 := tools.Tool{Name: "t2"}
+	r2 := &Response{Tools: []tools.Call{{Name: "t2", Ref: &tool2}}}
+	if err := r2.Eval(context.Background()); err == nil {
+		t.Error("Eval() = nil, want error for missing Function")
+	}
+	// callback error
+	fn := func(ctx context.Context, call tools.Call) (string, error) {
+		return "", errors.New("fail")
+	}
+	tool3 := tools.Tool{Name: "t3", Function: fn}
+	r3 := &Response{Tools: []tools.Call{{Name: "t3", Ref: &tool3}}}
+	if err := r3.Eval(context.Background()); err == nil {
+		t.Error("Eval() = nil, want error for callback failure")
+	}
+	// success
+	fn2 := func(ctx context.Context, call tools.Call) (string, error) { return "ok", nil }
+	tool4 := tools.Tool{Name: "t4", Function: fn2}
+	r4 := &Response{Tools: []tools.Call{{Name: "t4", Ref: &tool4}}, Metadata: models.Metadata{}}
+	if err := r4.Eval(context.Background()); err != nil {
+		t.Errorf("Eval() error: %v", err)
+	}
+}
\ No newline at end of file
diff --git a/prompt/prompt_test.go b/prompt/prompt_test.go
new file mode 100644
index 0000000..3f64d44
--- /dev/null
+++ b/prompt/prompt_test.go
@@ -0,0 +1,147 @@
+package prompt
+
+import (
+	"encoding/base64"
+	"reflect"
+	"testing"
+)
+
+func TestAsAssistant(t *testing.T) {
+	p := AsAssistant("hello")
+	if p.Role != AssistantRole {
+		t.Errorf("expected role %q, got %q", AssistantRole, p.Role)
+	}
+	if p.Text != "hello" {
+		t.Errorf("expected text 'hello', got %q", p.Text)
+	}
+}
+
+func TestAsUser(t *testing.T) {
+	p := AsUser("hi")
+	if p.Role != UserRole {
+		t.Errorf("expected role %q, got %q", UserRole, p.Role)
+	}
+	if p.Text != "hi" {
+		t.Errorf("expected text 'hi', got %q", p.Text)
+	}
+}
+
+func TestAsUserWithData(t *testing.T) {
+	data := []byte("somedata")
+	p := AsUserWithData(MimeApplicationPDF, data)
+	if p.Role != UserRole {
+		t.Errorf("expected role %q, got %q", UserRole, p.Role)
+	}
+	if p.Payload == nil {
+		t.Fatal("expected payload, got nil")
+	}
+	if p.Payload.Mime != MimeApplicationPDF {
+		t.Errorf("expected mime %q, got %q", MimeApplicationPDF, p.Payload.Mime)
+	}
+	decoded, err := base64.StdEncoding.DecodeString(p.Payload.Data)
+	if err != nil {
+		t.Fatalf("failed to decode base64: %v", err)
+	}
+	if string(decoded) != string(data) {
+		t.Errorf("expected data %q, got %q", data, decoded)
+	}
+}
+
+func TestAsUserWithURI(t *testing.T) {
+	uri := "file:///tmp/test.pdf"
+	p := AsUserWithURI(MimeApplicationPDF, uri)
+	if p.Role != UserRole {
+		t.Errorf("expected role %q, got %q", UserRole, p.Role)
+	}
+	if p.Payload == nil {
+		t.Fatal("expected payload, got nil")
+	}
+	if p.Payload.Mime != MimeApplicationPDF {
+		t.Errorf("expected mime %q, got %q", MimeApplicationPDF, p.Payload.Mime)
+	}
+	if p.Payload.Uri != uri {
+		t.Errorf("expected uri %q, got %q", uri, p.Payload.Uri)
+	}
+}
+
+func TestAsToolCall(t *testing.T) {
+	args := []byte(`{"foo":42}`)
+	p := AsToolCall("id123", "myfunc", args)
+	if p.Role != ToolCallRole {
+		t.Errorf("expected role %q, got %q", ToolCallRole, p.Role)
+	}
+	if p.ToolCall == nil {
+		t.Fatal("expected ToolCall, got nil")
+	}
+	if p.ToolCall.ToolCallID != "id123" {
+		t.Errorf("expected ToolCallID 'id123', got %q", p.ToolCall.ToolCallID)
+	}
+	if p.ToolCall.Name != "myfunc" {
+		t.Errorf("expected Name 'myfunc', got %q", p.ToolCall.Name)
+	}
+	if string(p.ToolCall.Arguments) != string(args) {
+		t.Errorf("expected Arguments %q, got %q", args, p.ToolCall.Arguments)
+	}
+}
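+
+// AsToolCall and AsToolResponse form the two halves of a tool round-trip; as
+// the shared ToolCallID in these tests suggests, a response is matched to its
+// originating call by that ID:
+//
+//	call := AsToolCall("id456", "myfunc2", []byte(`{}`))
+//	resp := AsToolResponse("id456", "myfunc2", "result")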
+ t.Errorf("expected Name 'myfunc2', got %q", p.ToolResponse.Name) + } + if p.ToolResponse.Response != "result" { + t.Errorf("expected Response 'result', got %q", p.ToolResponse.Response) + } +} + +func TestMIMEMaps(t *testing.T) { + // Images + for k, v := range MIMEImages { + if !v { + t.Errorf("expected MIMEImages[%q] to be true", k) + } + } + // Audio + for k, v := range MIMEAudio { + if !v { + t.Errorf("expected MIMEAudio[%q] to be true", k) + } + } + // Video + for k, v := range MIMEVideo { + if !v { + t.Errorf("expected MIMEVideo[%q] to be true", k) + } + } +} + +func TestPromptStructTags(t *testing.T) { + p := Prompt{ + Role: UserRole, + Text: "foo", + Payload: &Payload{Mime: MimeTextPlain, Data: "bar", Uri: "baz"}, + ToolCall: &ToolCall{ToolCallID: "id", Name: "n", Arguments: []byte("{}")}, + ToolResponse: &ToolResponse{ToolCallID: "id2", Name: "n2", Response: "r"}, + } + // Just check that fields are settable and readable + if p.Role != UserRole || p.Text != "foo" || p.Payload.Mime != MimeTextPlain { + t.Errorf("Prompt struct fields not set/read correctly") + } + if p.ToolCall.Name != "n" || p.ToolResponse.Name != "n2" { + t.Errorf("ToolCall/ToolResponse fields not set/read correctly") + } + // Check reflect tags + typ := reflect.TypeOf(Prompt{}) + if typ.NumField() != 5 { + t.Errorf("Prompt struct should have 5 fields, got %d", typ.NumField()) + } +}