From a6279d7101458d33c9afdf744a8222e1067d97f8 Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Fri, 1 Nov 2024 20:44:06 -0400 Subject: [PATCH] fix: ensure the parent call frame is of category none Signed-off-by: Donnie Adams --- frame.go | 11 +++++++++++ run.go | 28 ++++++++++------------------ 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/frame.go b/frame.go index 39d8ba9..5207e43 100644 --- a/frame.go +++ b/frame.go @@ -49,6 +49,17 @@ type RunFrame struct { Type EventType `json:"type"` } +type CallFrames map[string]CallFrame + +func (c CallFrames) ParentCallFrame() CallFrame { + for _, call := range c { + if call.ParentID == "" && call.ToolCategory == NoCategory { + return call + } + } + return CallFrame{} +} + type CallFrame struct { CallContext `json:",inline"` diff --git a/run.go b/run.go index a111045..4a5865d 100644 --- a/run.go +++ b/run.go @@ -36,15 +36,14 @@ type Run struct { wait func() basicCommand bool - program *Program - callsLock sync.RWMutex - calls map[string]CallFrame - parentCallFrameID string - rawOutput map[string]any - output, errput string - events chan Frame - lock sync.Mutex - responseCode int + program *Program + callsLock sync.RWMutex + calls CallFrames + rawOutput map[string]any + output, errput string + events chan Frame + lock sync.Mutex + responseCode int } // Text returns the text output of the gptscript. It blocks until the output is ready. @@ -104,7 +103,7 @@ func (r *Run) RespondingTool() Tool { } // Calls will return a flattened array of the calls for this run. -func (r *Run) Calls() map[string]CallFrame { +func (r *Run) Calls() CallFrames { r.callsLock.RLock() defer r.callsLock.RUnlock() return maps.Clone(r.calls) @@ -115,11 +114,7 @@ func (r *Run) ParentCallFrame() (CallFrame, bool) { r.callsLock.RLock() defer r.callsLock.RUnlock() - if r.parentCallFrameID == "" { - return CallFrame{}, false - } - - return r.calls[r.parentCallFrameID], true + return r.calls.ParentCallFrame(), true } // ErrorOutput returns the stderr output of the gptscript. @@ -394,9 +389,6 @@ func (r *Run) request(ctx context.Context, payload any) (err error) { if event.Call != nil { r.callsLock.Lock() r.calls[event.Call.ID] = *event.Call - if r.parentCallFrameID == "" && event.Call.ParentID == "" { - r.parentCallFrameID = event.Call.ID - } r.callsLock.Unlock() } else if event.Run != nil { if event.Run.Type == EventTypeRunStart {