Skip to content

Advanced GraphRAG examples #162

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion neo4j-example/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<groupId>dev.langchain4j</groupId>
<artifactId>neo4j-example</artifactId>
<version>1.0.0-beta4</version>
<version>1.1.0-beta7</version>
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
Expand Down Expand Up @@ -48,6 +48,12 @@
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-community-llm-graph-transformer</artifactId>
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
Expand Down
129 changes: 129 additions & 0 deletions neo4j-example/src/main/java/KnowledgeGraphWriterExample.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import dev.langchain4j.community.data.document.graph.GraphDocument;
import dev.langchain4j.community.data.document.transformer.graph.LLMGraphTransformer;
import dev.langchain4j.community.rag.content.retriever.neo4j.KnowledgeGraphWriter;
import dev.langchain4j.community.rag.content.retriever.neo4j.Neo4jGraph;
import dev.langchain4j.data.document.DefaultDocument;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.model.openai.OpenAiChatModel;
import org.testcontainers.containers.Neo4jContainer;

import java.util.List;

import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_O_MINI;

/**
 * Example that converts plain-text documents into graph documents with
 * {@link LLMGraphTransformer} and stores them in Neo4j through {@link KnowledgeGraphWriter}.
 *
 * <p>The example runs against a Neo4j Testcontainers instance started with the APOC labs plugin.
 * The OpenAI API key is read from the {@code OPENAI_API_KEY} environment variable; when it is
 * absent the key {@code "demo"} and the langchain4j demo base URL are used instead.
 */
public class KnowledgeGraphWriterExample {

    /** Example triples (JSON) handed to the transformer via {@code examples(...)}. */
    private static final String EXAMPLES_PROMPT =
            """
            [
            {
            "tail":"Microsoft",
            "head":"Adam",
            "head_type":"Person",
            "text":"Adam is a software engineer in Microsoft since 2009, and last year he got an award as the Best Talent",
            "relation":"WORKS_FOR",
            "tail_type":"Company"
            },
            {
            "tail":"Best Talent",
            "head":"Adam",
            "head_type":"Person",
            "text":"Adam is a software engineer in Microsoft since 2009, and last year he got an award as the Best Talent",
            "relation":"HAS_AWARD",
            "tail_type":"Award"
            },
            {
            "tail":"Microsoft",
            "head":"Microsoft Word",
            "head_type":"Product",
            "text":"Microsoft is a tech company that provide several products such as Microsoft Word",
            "relation":"PRODUCED_BY",
            "tail_type":"Company"
            },
            {
            "tail":"lightweight app",
            "head":"Microsoft Word",
            "head_type":"Product",
            "text":"Microsoft Word is a lightweight app that accessible offline",
            "relation":"HAS_CHARACTERISTIC",
            "tail_type":"Characteristic"
            },
            {
            "tail":"accessible offline",
            "head":"Microsoft Word",
            "head_type":"Product",
            "text":"Microsoft Word is a lightweight app that accessible offline",
            "relation":"HAS_CHARACTERISTIC",
            "tail_type":"Characteristic"
            }
            ]
            """;

    /** Text of the first sample document. */
    public static final String CAT_ON_THE_TABLE = "Sylvester the cat is on the table";

    /** Text of the second sample document. */
    public static final String KEANU_REEVES_ACTED = "Keanu Reeves acted in Matrix";

    /** OpenAI API key taken from the environment, falling back to {@code "demo"}. */
    public static final String OPENAI_API_KEY = getOrDefault(System.getenv("OPENAI_API_KEY"), "demo");

    /** Demo endpoint when the demo key is in use; {@code null} otherwise. */
    public static final String OPENAI_BASE_URL = "demo".equals(OPENAI_API_KEY) ? "http://langchain4j.dev/demo/openai/v1" : null;

    /**
     * Transforms the two sample documents into graph documents and writes them to Neo4j twice:
     * once including the source document nodes and once without them.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        final OpenAiChatModel model = OpenAiChatModel.builder()
                .apiKey(OPENAI_API_KEY)
                .baseUrl(OPENAI_BASE_URL)
                .modelName(GPT_4_O_MINI)
                .build();

        LLMGraphTransformer graphTransformer = LLMGraphTransformer.builder()
                .model(model)
                .examples(EXAMPLES_PROMPT)
                .build();

        Document docKeanu = new DefaultDocument(KEANU_REEVES_ACTED);
        Document docCat = new DefaultDocument(CAT_ON_THE_TABLE);
        List<Document> documents = List.of(docCat, docKeanu);

        List<GraphDocument> graphDocuments = graphTransformer.transformAll(documents);

        try (Neo4jContainer<?> neo4jContainer = new Neo4jContainer<>("neo4j:5.26")
                .withAdminPassword("admin1234")
                .withLabsPlugins("apoc")) {
            neo4jContainer.start();

            // try-with-resources so the graph is closed even when a write below throws
            try (Neo4jGraph graph = Neo4jGraph.builder()
                    .withBasicAuth(neo4jContainer.getBoltUrl(), "neo4j", neo4jContainer.getAdminPassword())
                    .build()) {

                KnowledgeGraphWriter writer = KnowledgeGraphWriter.builder()
                        .graph(graph)
                        .label("Entity")
                        .relType("MENTIONS")
                        .idProperty("id")
                        .textProperty("text")
                        .build();

                // `graphDocuments` obtained from LLMGraphTransformer
                writer.addGraphDocuments(graphDocuments, true); // set to true to include document source

                /*
                The above KnowledgeGraphWriter will add paths like:
                (:Document {id: UUID, text: 'Sylvester the cat is on the table'})-[:MENTIONS]->(:Entity:Animal {id: 'Sylvester the cat'})-[:IS_ON]->(:Entity:Object {id: 'table'})
                (:Document {id: UUID, text: 'Keanu Reeves acted in Matrix'})-[:MENTIONS]->(:Entity:Person {id: 'Keanu Reeves'})-[:ACTED_IN]->(:Entity:Movie {id: 'Matrix'})
                */

                KnowledgeGraphWriter writerWithoutDocs = KnowledgeGraphWriter.builder()
                        .graph(graph)
                        .label("FooBar")
                        .relType("MENTIONS")
                        .idProperty("id")
                        .textProperty("text")
                        .build();

                // `graphDocuments` obtained from LLMGraphTransformer
                writerWithoutDocs.addGraphDocuments(graphDocuments, false); // set to false to omit the document source

                /*
                The above KnowledgeGraphWriter will add paths like:
                (:FooBar:Animal {id: 'Sylvester the cat'})-[:IS_ON]->(:FooBar:Object {id: 'table'})
                (:FooBar:Person {id: 'Keanu Reeves'})-[:ACTED_IN]->(:FooBar:Movie {id: 'Matrix'})
                */
            }
        }
    }
}
108 changes: 91 additions & 17 deletions neo4j-example/src/main/java/Neo4jContentRetrieverExample.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dev.langchain4j.community.rag.content.retriever.neo4j.Neo4jGraph;
import dev.langchain4j.community.rag.content.retriever.neo4j.Neo4jText2CypherRetriever;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.query.Query;
Expand Down Expand Up @@ -33,27 +34,100 @@ public static void main(String[] args) {
neo4jContainer.start();
try (Driver driver = GraphDatabase.driver(neo4jContainer.getBoltUrl(), AuthTokens.none())) {
try (Neo4jGraph graph = Neo4jGraph.builder().driver(driver).build()) {
try (Session session = driver.session()) {
session.run("CREATE (book:Book {title: 'Dune'})<-[:WROTE]-(author:Person {name: 'Frank Herbert'})");
}
// The refreshSchema is needed only if we execute write operation after the `Neo4jGraph` instance,
// in this case `CREATE (book:Book...`
// If CREATE (and in general write operations to the db) are performed externally before Neo4jGraph.builder(),
// the refreshSchema() is not needed
graph.refreshSchema();
contentRetrieverWithMinimalConfig(driver, graph, chatLanguageModel);

Neo4jText2CypherRetriever retriever = Neo4jText2CypherRetriever.builder()
.graph(graph)
.chatModel(chatLanguageModel)
.build();
contentRetrieverWithExamples(graph, chatLanguageModel);

Query query = new Query("Who is the author of the book 'Dune'?");

List<Content> contents = retriever.retrieve(query);

System.out.println(contents.get(0).textSegment().text()); // "Frank Herbert"
contentRetrieverWithoutRetries(graph, chatLanguageModel);
}
}
}
}

/**
 * Creates a small book/author graph, refreshes the cached schema and answers a natural-language
 * question with a {@link Neo4jText2CypherRetriever} built from only a graph and a chat model.
 *
 * @param driver            driver used to write the sample data
 * @param graph             graph whose schema is refreshed and then queried
 * @param chatLanguageModel chat model handed to the retriever
 */
private static void contentRetrieverWithMinimalConfig(Driver driver, Neo4jGraph graph, ChatModel chatLanguageModel) {
    // tag::retrieve-text2cypher[]
    try (Session writeSession = driver.session()) {
        writeSession.run("CREATE (book:Book {title: 'Dune'})<-[:WROTE]-(author:Person {name: 'Frank Herbert'})");
    }
    // refreshSchema() is required only because the write above (`CREATE (book:Book...`) happens
    // after the `Neo4jGraph` instance was built.
    // When the data is written externally before Neo4jGraph.builder() is called,
    // refreshSchema() can be omitted.
    graph.refreshSchema();

    final Neo4jText2CypherRetriever text2Cypher = Neo4jText2CypherRetriever.builder()
            .graph(graph)
            .chatModel(chatLanguageModel)
            .build();

    final Query question = new Query("Who is the author of the book 'Dune'?");
    final List<Content> results = text2Cypher.retrieve(question);

    final String answer = results.get(0).textSegment().text();
    System.out.println(answer); // "Frank Herbert"
    // end::retrieve-text2cypher[]
}

/**
 * Answers a natural-language question with a {@link Neo4jText2CypherRetriever} that is
 * configured with question/Cypher example pairs.
 *
 * <p>The first retrieved content is printed; when the retriever returns no content a short
 * notice is printed instead of failing on {@code contents.get(0)}.
 *
 * @param graph             graph to query
 * @param chatLanguageModel chat model handed to the retriever
 */
private static void contentRetrieverWithExamples(Neo4jGraph graph, ChatModel chatLanguageModel) {
    // tag::retrieve-text2cypher-examples[]
    List<String> examples = List.of(
            """
            # Which streamer has the most followers?
            MATCH (s:Stream)
            RETURN s.name AS streamer
            ORDER BY s.followers DESC LIMIT 1
            """,
            """
            # How many streamers are from Norway?
            MATCH (s:Stream)-[:HAS_LANGUAGE]->(:Language {{name: 'Norwegian'}})
            RETURN count(s) AS streamers
            """);

    Neo4jText2CypherRetriever neo4jContentRetriever = Neo4jText2CypherRetriever.builder()
            .graph(graph)
            .chatModel(chatLanguageModel)
            // add the above examples
            .examples(examples)
            .build();

    final String textQuery = "Which streamer from Italy has the most followers?";
    Query query = new Query(textQuery);
    List<Content> contents = neo4jContentRetriever.retrieve(query);
    // This method writes no streamer data itself, so the result may be empty:
    // guard the access instead of letting get(0) abort the whole example.
    if (contents.isEmpty()) {
        System.out.println("No content retrieved for: " + textQuery);
    } else {
        System.out.println(contents.get(0).textSegment().text());
        // output: "The most followed Italian streamer"
    }
    // end::retrieve-text2cypher-examples[]
}

/**
 * Answers the 'Dune' author question with a retriever whose retry logic is switched off
 * via {@code maxRetries(0)}, and prints the first retrieved text.
 *
 * @param graph             graph to query
 * @param chatLanguageModel chat model handed to the retriever
 */
private static void contentRetrieverWithoutRetries(Neo4jGraph graph, ChatModel chatLanguageModel) {
    final Neo4jText2CypherRetriever noRetryRetriever = Neo4jText2CypherRetriever.builder()
            .graph(graph)
            .chatModel(chatLanguageModel)
            // zero retries: retry logic is disabled
            .maxRetries(0)
            .build();

    final Query question = new Query("Who is the author of the book 'Dune'?");
    final List<Content> results = noRetryRetriever.retrieve(question);

    final String firstText = results.get(0).textSegment().text();
    System.out.println(firstText); // "Frank Herbert"
}

/**
 * Answers the 'Dune' author question using a dedicated {@link Neo4jGraph} built with
 * {@code sample(3L)} and {@code maxRels(8L)}; the graph is closed when the method returns.
 *
 * @param chatLanguageModel chat model handed to the retriever
 * @param driver            driver the dedicated graph is built on
 */
private static void contentRetrieverWithSamplesAndMaxRels(ChatModel chatLanguageModel, Driver driver) {
    // sample(3L): sample up to 3 example paths from the graph schema
    // maxRels(8L): explore a maximum of 8 relationships from the start node
    try (Neo4jGraph sampledGraph = Neo4jGraph.builder()
            .driver(driver)
            .sample(3L)
            .maxRels(8L)
            .build()) {
        // tag::retrieve-text2cypher-sample-max-rels[]
        final Neo4jText2CypherRetriever text2Cypher = Neo4jText2CypherRetriever.builder()
                .graph(sampledGraph)
                .chatModel(chatLanguageModel)
                .build();

        final Query question = new Query("Who is the author of the book 'Dune'?");
        final List<Content> results = text2Cypher.retrieve(question);

        final String answer = results.get(0).textSegment().text();
        System.out.println(answer); // "Frank Herbert"
        // end::retrieve-text2cypher-sample-max-rels[]
    }
}
}
17 changes: 11 additions & 6 deletions neo4j-example/src/main/java/Neo4jEmbeddingStoreExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public static void main(String[] args) {
searchEmbeddingsWithAddAllWithMetadataMaxResultsAndMinScore();

// custom embeddingStore
// tag::custom-embedding-store[]
Neo4jEmbeddingStore customEmbeddingStore = Neo4jEmbeddingStore.builder()
.withBasicAuth(neo4j.getBoltUrl(), "neo4j", neo4j.getAdminPassword())
.dimension(embeddingModel.dimension())
Expand All @@ -38,16 +39,17 @@ public static void main(String[] args) {
.idProperty("customId")
.textProperty("customText")
.build();
// end::custom-embedding-store[]
searchEmbeddingsWithSingleMaxResult(customEmbeddingStore);
}
}

private static void searchEmbeddingsWithSingleMaxResult(EmbeddingStore<TextSegment> minimalEmbedding) {

// tag::add-single-embedding[]
TextSegment segment1 = TextSegment.from("I like football.");
Embedding embedding1 = embeddingModel.embed(segment1).content();
minimalEmbedding.add(embedding1, segment1);

// end::add-single-embedding[]
TextSegment segment2 = TextSegment.from("The weather is good today.");
Embedding embedding2 = embeddingModel.embed(segment2).content();
minimalEmbedding.add(embedding2, segment2);
Expand All @@ -65,7 +67,7 @@ private static void searchEmbeddingsWithSingleMaxResult(EmbeddingStore<TextSegme
}

private static void searchEmbeddingsWithAddAllAndSingleMaxResult() {

// tag::add-multiple-embeddings[]
TextSegment segment1 = TextSegment.from("I like football.");
Embedding embedding1 = embeddingModel.embed(segment1).content();

Expand All @@ -78,7 +80,8 @@ private static void searchEmbeddingsWithAddAllAndSingleMaxResult() {
List.of(embedding1, embedding2, embedding3),
List.of(segment1, segment2, segment3)
);

// end::add-multiple-embeddings[]

Embedding queryEmbedding = embeddingModel.embed("What are your favourites sport?").content();
final EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
.queryEmbedding(queryEmbedding)
Expand All @@ -94,7 +97,8 @@ private static void searchEmbeddingsWithAddAllAndSingleMaxResult() {
}

private static void searchEmbeddingsWithAddAllWithMetadataMaxResultsAndMinScore() {


// tag::add-multiple-embeddings-metadata[]
TextSegment segment1 = TextSegment.from("I like football.", Metadata.from("test-key-1", "test-value-1"));
Embedding embedding1 = embeddingModel.embed(segment1).content();

Expand All @@ -107,7 +111,8 @@ private static void searchEmbeddingsWithAddAllWithMetadataMaxResultsAndMinScore(
List.of(embedding1, embedding2, embedding3),
List.of(segment1, segment2, segment3)
);

// end::add-multiple-embeddings-metadata[]

Embedding queryEmbedding = embeddingModel.embed("What are your favourite sports?").content();
final EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
.queryEmbedding(queryEmbedding)
Expand Down
Loading