diff --git a/go.mod b/go.mod index a644a6e..5e9c4a1 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module github.com/tmc/langgraphgo go 1.22 -require github.com/tmc/langchaingo v0.1.7 +require github.com/tmc/langchaingo v0.1.9 require ( github.com/dlclark/regexp2 v1.10.0 // indirect diff --git a/go.sum b/go.sum index 0d584cb..bbfc6da 100644 --- a/go.sum +++ b/go.sum @@ -10,7 +10,7 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/tmc/langchaingo v0.1.7 h1:Jx3/KEUAkCxU0hcNo+WZcXDnCUG/PfjcrW7N+f3ohOw= -github.com/tmc/langchaingo v0.1.7/go.mod h1:lPpWPoAud+yQowJNRZhdtRbQCSHKF+jRxd0gU58GDHU= +github.com/tmc/langchaingo v0.1.9 h1:6dtKgK52u2+9ksTTzTNvIpS5MiT5IfxtAVR4gDCxfn0= +github.com/tmc/langchaingo v0.1.9/go.mod h1:MJpoh929t7a3JkbCW2cXTWwInjdaY2NMBDU4JeetwFo= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/graph/graph.go b/graph/graph.go index 421b69e..7cfbabc 100644 --- a/graph/graph.go +++ b/graph/graph.go @@ -49,6 +49,9 @@ type MessageGraph struct { // edges is a slice of Edge objects representing the connections between nodes. edges []Edge + // conditionalEdges contains a map between "From" node, while "To" node is derived based on the condition. + conditionalEdges map[string]func(ctx context.Context, state []llms.MessageContent) string + // entryPoint is the name of the entry point node in the graph. entryPoint string } @@ -56,7 +59,8 @@ type MessageGraph struct { // NewMessageGraph creates a new instance of MessageGraph. func NewMessageGraph() *MessageGraph { return &MessageGraph{ - nodes: make(map[string]Node), + nodes: make(map[string]Node), + conditionalEdges: make(map[string]func(ctx context.Context, state []llms.MessageContent) string), } } @@ -76,6 +80,11 @@ func (g *MessageGraph) AddEdge(from, to string) { }) } +// AddConditionalEdge adds a new edge in which "from" node is identified based on the "condition". +func (g *MessageGraph) AddConditionalEdge(from string, condition func(ctx context.Context, state []llms.MessageContent) string) { + g.conditionalEdges[from] = condition +} + // SetEntryPoint sets the entry point node name for the message graph. func (g *MessageGraph) SetEntryPoint(name string) { g.entryPoint = name @@ -124,11 +133,19 @@ func (r *Runnable) Invoke(ctx context.Context, messages []llms.MessageContent) ( } foundNext := false - for _, edge := range r.graph.edges { - if edge.From == currentNode { - currentNode = edge.To - foundNext = true - break + nextNodeFn, ok := r.graph.conditionalEdges[currentNode] + if ok { + currentNode = nextNodeFn(ctx, state) + foundNext = true + } + + if !foundNext { + for _, edge := range r.graph.edges { + if edge.From == currentNode { + currentNode = edge.To + foundNext = true + break + } } } diff --git a/graph/graph_test.go b/graph/graph_test.go index 1e04d8a..53a1293 100644 --- a/graph/graph_test.go +++ b/graph/graph_test.go @@ -4,11 +4,11 @@ import ( "context" "errors" "fmt" + "strings" "testing" "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms/openai" - "github.com/tmc/langchaingo/schema" "github.com/tmc/langgraphgo/graph" ) @@ -26,7 +26,7 @@ func ExampleMessageGraph() { return nil, err } return append(state, - llms.TextParts(schema.ChatMessageTypeAI, r.Choices[0].Content), + llms.TextParts(llms.ChatMessageTypeAI, r.Choices[0].Content), ), nil }) g.AddNode(graph.END, func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { @@ -44,7 +44,7 @@ func ExampleMessageGraph() { ctx := context.Background() // Let's run it! res, err := runnable.Invoke(ctx, []llms.MessageContent{ - llms.TextParts(schema.ChatMessageTypeHuman, "What is 1 + 1?"), + llms.TextParts(llms.ChatMessageTypeHuman, "What is 1 + 1?"), }) if err != nil { panic(err) @@ -56,6 +56,7 @@ func ExampleMessageGraph() { // [{human [{What is 1 + 1?}]} {ai [{1 + 1 equals 2.}]}] } +//nolint:funlen,gocognit,cyclop func TestMessageGraph(t *testing.T) { t.Parallel() testCases := []struct { @@ -70,21 +71,21 @@ func TestMessageGraph(t *testing.T) { buildGraph: func() *graph.MessageGraph { g := graph.NewMessageGraph() g.AddNode("node1", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { - return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Node 1")), nil + return append(state, llms.TextParts(llms.ChatMessageTypeAI, "Node 1")), nil }) g.AddNode("node2", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { - return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Node 2")), nil + return append(state, llms.TextParts(llms.ChatMessageTypeAI, "Node 2")), nil }) g.AddEdge("node1", "node2") g.AddEdge("node2", graph.END) g.SetEntryPoint("node1") return g }, - inputMessages: []llms.MessageContent{llms.TextParts(schema.ChatMessageTypeHuman, "Input")}, + inputMessages: []llms.MessageContent{llms.TextParts(llms.ChatMessageTypeHuman, "Input")}, expectedOutput: []llms.MessageContent{ - llms.TextParts(schema.ChatMessageTypeHuman, "Input"), - llms.TextParts(schema.ChatMessageTypeAI, "Node 1"), - llms.TextParts(schema.ChatMessageTypeAI, "Node 2"), + llms.TextParts(llms.ChatMessageTypeHuman, "Input"), + llms.TextParts(llms.ChatMessageTypeAI, "Node 1"), + llms.TextParts(llms.ChatMessageTypeAI, "Node 2"), }, expectedError: nil, }, @@ -137,6 +138,40 @@ func TestMessageGraph(t *testing.T) { }, expectedError: errors.New("error in node node1: node error"), }, + { + name: "Conditional edge - condition for edge fulfilled", + buildGraph: func() *graph.MessageGraph { + g := graph.NewMessageGraph() + g.AddNode("node1", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return append(state, llms.TextParts(llms.ChatMessageTypeAI, "function calling: use calculator")), nil + }) + g.AddNode("node2", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return append(state, llms.TextParts(llms.ChatMessageTypeAI, "Node 2")), nil + }) + g.AddNode("calculator", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return append(state, llms.TextParts(llms.ChatMessageTypeTool, "1+1=2")), nil + }) + g.AddConditionalEdge("node1", func(_ context.Context, state []llms.MessageContent) string { + if content, ok := state[len(state)-1].Parts[0].(llms.TextContent); ok { + if strings.Contains(content.Text, "calculator") { + return "calculator" + } + } + return "node2" + }) + g.AddEdge("node2", graph.END) + g.AddEdge("calculator", graph.END) + g.SetEntryPoint("node1") + return g + }, + inputMessages: []llms.MessageContent{llms.TextParts(llms.ChatMessageTypeHuman, "what is 1+1?")}, + expectedOutput: []llms.MessageContent{ + llms.TextParts(llms.ChatMessageTypeHuman, "what is 1+1?"), + llms.TextParts(llms.ChatMessageTypeAI, "function calling: use calculator"), + llms.TextParts(llms.ChatMessageTypeTool, "1+1=2"), + }, + expectedError: nil, + }, } for _, tc := range testCases {