Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module github.com/tmc/langgraphgo

go 1.22

require github.com/tmc/langchaingo v0.1.7
require github.com/tmc/langchaingo v0.1.9

require (
github.com/dlclark/regexp2 v1.10.0 // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tmc/langchaingo v0.1.7 h1:Jx3/KEUAkCxU0hcNo+WZcXDnCUG/PfjcrW7N+f3ohOw=
github.com/tmc/langchaingo v0.1.7/go.mod h1:lPpWPoAud+yQowJNRZhdtRbQCSHKF+jRxd0gU58GDHU=
github.com/tmc/langchaingo v0.1.9 h1:6dtKgK52u2+9ksTTzTNvIpS5MiT5IfxtAVR4gDCxfn0=
github.com/tmc/langchaingo v0.1.9/go.mod h1:MJpoh929t7a3JkbCW2cXTWwInjdaY2NMBDU4JeetwFo=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
29 changes: 23 additions & 6 deletions graph/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,18 @@ type MessageGraph struct {
// edges is a slice of Edge objects representing the connections between nodes.
edges []Edge

// conditionalEdges contains a map between "From" node, while "To" node is derived based on the condition.
conditionalEdges map[string]func(ctx context.Context, state []llms.MessageContent) string

// entryPoint is the name of the entry point node in the graph.
entryPoint string
}

// NewMessageGraph creates a new instance of MessageGraph.
func NewMessageGraph() *MessageGraph {
return &MessageGraph{
nodes: make(map[string]Node),
nodes: make(map[string]Node),
conditionalEdges: make(map[string]func(ctx context.Context, state []llms.MessageContent) string),
}
}

Expand All @@ -76,6 +80,11 @@ func (g *MessageGraph) AddEdge(from, to string) {
})
}

// AddConditionalEdge adds a new edge in which "from" node is identified based on the "condition".
func (g *MessageGraph) AddConditionalEdge(from string, condition func(ctx context.Context, state []llms.MessageContent) string) {
g.conditionalEdges[from] = condition
}

// SetEntryPoint sets the entry point node name for the message graph.
func (g *MessageGraph) SetEntryPoint(name string) {
g.entryPoint = name
Expand Down Expand Up @@ -124,11 +133,19 @@ func (r *Runnable) Invoke(ctx context.Context, messages []llms.MessageContent) (
}

foundNext := false
for _, edge := range r.graph.edges {
if edge.From == currentNode {
currentNode = edge.To
foundNext = true
break
nextNodeFn, ok := r.graph.conditionalEdges[currentNode]
if ok {
currentNode = nextNodeFn(ctx, state)
foundNext = true
}

if !foundNext {
for _, edge := range r.graph.edges {
if edge.From == currentNode {
currentNode = edge.To
foundNext = true
break
}
}
}

Expand Down
53 changes: 44 additions & 9 deletions graph/graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ import (
"context"
"errors"
"fmt"
"strings"
"testing"

"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/openai"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langgraphgo/graph"
)

Expand All @@ -26,7 +26,7 @@ func ExampleMessageGraph() {
return nil, err
}
return append(state,
llms.TextParts(schema.ChatMessageTypeAI, r.Choices[0].Content),
llms.TextParts(llms.ChatMessageTypeAI, r.Choices[0].Content),
), nil
})
g.AddNode(graph.END, func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
Expand All @@ -44,7 +44,7 @@ func ExampleMessageGraph() {
ctx := context.Background()
// Let's run it!
res, err := runnable.Invoke(ctx, []llms.MessageContent{
llms.TextParts(schema.ChatMessageTypeHuman, "What is 1 + 1?"),
llms.TextParts(llms.ChatMessageTypeHuman, "What is 1 + 1?"),
})
if err != nil {
panic(err)
Expand All @@ -56,6 +56,7 @@ func ExampleMessageGraph() {
// [{human [{What is 1 + 1?}]} {ai [{1 + 1 equals 2.}]}]
}

//nolint:funlen,gocognit,cyclop
func TestMessageGraph(t *testing.T) {
t.Parallel()
testCases := []struct {
Expand All @@ -70,21 +71,21 @@ func TestMessageGraph(t *testing.T) {
buildGraph: func() *graph.MessageGraph {
g := graph.NewMessageGraph()
g.AddNode("node1", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Node 1")), nil
return append(state, llms.TextParts(llms.ChatMessageTypeAI, "Node 1")), nil
})
g.AddNode("node2", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Node 2")), nil
return append(state, llms.TextParts(llms.ChatMessageTypeAI, "Node 2")), nil
})
g.AddEdge("node1", "node2")
g.AddEdge("node2", graph.END)
g.SetEntryPoint("node1")
return g
},
inputMessages: []llms.MessageContent{llms.TextParts(schema.ChatMessageTypeHuman, "Input")},
inputMessages: []llms.MessageContent{llms.TextParts(llms.ChatMessageTypeHuman, "Input")},
expectedOutput: []llms.MessageContent{
llms.TextParts(schema.ChatMessageTypeHuman, "Input"),
llms.TextParts(schema.ChatMessageTypeAI, "Node 1"),
llms.TextParts(schema.ChatMessageTypeAI, "Node 2"),
llms.TextParts(llms.ChatMessageTypeHuman, "Input"),
llms.TextParts(llms.ChatMessageTypeAI, "Node 1"),
llms.TextParts(llms.ChatMessageTypeAI, "Node 2"),
},
expectedError: nil,
},
Expand Down Expand Up @@ -137,6 +138,40 @@ func TestMessageGraph(t *testing.T) {
},
expectedError: errors.New("error in node node1: node error"),
},
{
name: "Conditional edge - condition for edge fulfilled",
buildGraph: func() *graph.MessageGraph {
g := graph.NewMessageGraph()
g.AddNode("node1", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
return append(state, llms.TextParts(llms.ChatMessageTypeAI, "function calling: use calculator")), nil
})
g.AddNode("node2", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
return append(state, llms.TextParts(llms.ChatMessageTypeAI, "Node 2")), nil
})
g.AddNode("calculator", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
return append(state, llms.TextParts(llms.ChatMessageTypeTool, "1+1=2")), nil
})
g.AddConditionalEdge("node1", func(_ context.Context, state []llms.MessageContent) string {
if content, ok := state[len(state)-1].Parts[0].(llms.TextContent); ok {
if strings.Contains(content.Text, "calculator") {
return "calculator"
}
}
return "node2"
})
g.AddEdge("node2", graph.END)
g.AddEdge("calculator", graph.END)
g.SetEntryPoint("node1")
return g
},
inputMessages: []llms.MessageContent{llms.TextParts(llms.ChatMessageTypeHuman, "what is 1+1?")},
expectedOutput: []llms.MessageContent{
llms.TextParts(llms.ChatMessageTypeHuman, "what is 1+1?"),
llms.TextParts(llms.ChatMessageTypeAI, "function calling: use calculator"),
llms.TextParts(llms.ChatMessageTypeTool, "1+1=2"),
},
expectedError: nil,
},
}

for _, tc := range testCases {
Expand Down