diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/etl-pipeline.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/etl-pipeline.adoc index 58f461d457e..a4f808837ef 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/etl-pipeline.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/etl-pipeline.adoc @@ -709,18 +709,26 @@ class MyKeywordEnricher { } List enrichDocuments(List documents) { - KeywordMetadataEnricher enricher = new KeywordMetadataEnricher(this.chatModel, 5); + KeywordMetadataEnricher enricher = KeywordMetadataEnricher.builder(chatModel) + .keywordCount(5) + .build(); + + // Or use custom templates + KeywordMetadataEnricher enricher = KeywordMetadataEnricher.builder(chatModel) + .keywordsTemplate(YOUR_CUSTOM_TEMPLATE) + .build(); + return enricher.apply(documents); } } ---- -==== Constructor +==== Constructor Options -The `KeywordMetadataEnricher` constructor takes two parameters: +The `KeywordMetadataEnricher` provides two constructor options: -1. `ChatModel chatModel`: The AI model used for generating keywords. -2. `int keywordCount`: The number of keywords to extract for each document. +1. `KeywordMetadataEnricher(ChatModel chatModel, int keywordCount)`: To use the default template and extract a specified number of keywords. +2. `KeywordMetadataEnricher(ChatModel chatModel, PromptTemplate keywordsTemplate)`: To use a custom template for keyword extraction. ==== Behavior @@ -734,7 +742,8 @@ The `KeywordMetadataEnricher` processes documents as follows: ==== Customization -The keyword extraction prompt can be customized by modifying the `KEYWORDS_TEMPLATE` constant in the class. The default template is: +You can use the default template or customize the template through the keywordsTemplate parameter. +The default template is: [source,java] ---- @@ -748,7 +757,14 @@ Where `+{context_str}+` is replaced with the document content, and `%s` is repla [source,java] ---- ChatModel chatModel = // initialize your chat model -KeywordMetadataEnricher enricher = new KeywordMetadataEnricher(chatModel, 5); +KeywordMetadataEnricher enricher = KeywordMetadataEnricher.builder(chatModel) + .keywordCount(5) + .build(); + +// Or use custom templates +KeywordMetadataEnricher enricher = KeywordMetadataEnricher.builder(chatModel) + .keywordsTemplate(new PromptTemplate("Extract 5 important keywords from the following text and separate them with commas:\n{context_str}")) + .build(); Document doc = new Document("This is a document about artificial intelligence and its applications in modern technology."); @@ -766,6 +782,7 @@ System.out.println("Extracted keywords: " + keywords); * The enricher adds the "excerpt_keywords" metadata field to each processed document. * The generated keywords are returned as a comma-separated string. * This enricher is particularly useful for improving document searchability and for generating tags or categories for documents. +* In the Builder pattern, if the `keywordsTemplate` parameter is set, the `keywordCount` parameter will be ignored. === SummaryMetadataEnricher The `SummaryMetadataEnricher` is a `DocumentTransformer` that uses a generative AI model to create summaries for documents and add them as metadata. It can generate summaries for the current document, as well as adjacent documents (previous and next). diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/transformer/KeywordMetadataEnricher.java b/spring-ai-model/src/main/java/org/springframework/ai/model/transformer/KeywordMetadataEnricher.java index 838fc9fe5b8..bf9b38fca70 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/transformer/KeywordMetadataEnricher.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/transformer/KeywordMetadataEnricher.java @@ -19,6 +19,9 @@ import java.util.List; import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; @@ -30,16 +33,19 @@ * Keyword extractor that uses generative to extract 'excerpt_keywords' metadata field. * * @author Christian Tzolov + * @author YunKui Lu */ public class KeywordMetadataEnricher implements DocumentTransformer { + private static final Logger logger = LoggerFactory.getLogger(KeywordMetadataEnricher.class); + public static final String CONTEXT_STR_PLACEHOLDER = "context_str"; public static final String KEYWORDS_TEMPLATE = """ {context_str}. Give %s unique keywords for this document. Format as comma separated. Keywords: """; - private static final String EXCERPT_KEYWORDS_METADATA_KEY = "excerpt_keywords"; + public static final String EXCERPT_KEYWORDS_METADATA_KEY = "excerpt_keywords"; /** * Model predictor @@ -47,28 +53,93 @@ public class KeywordMetadataEnricher implements DocumentTransformer { private final ChatModel chatModel; /** - * The number of keywords to extract. + * The prompt template to use for keyword extraction. */ - private final int keywordCount; + private final PromptTemplate keywordsTemplate; + /** + * Create a new {@link KeywordMetadataEnricher} instance. + * @param chatModel the model predictor to use for keyword extraction. + * @param keywordCount the number of keywords to extract. + */ public KeywordMetadataEnricher(ChatModel chatModel, int keywordCount) { - Assert.notNull(chatModel, "ChatModel must not be null"); - Assert.isTrue(keywordCount >= 1, "Document count must be >= 1"); + Assert.notNull(chatModel, "chatModel must not be null"); + Assert.isTrue(keywordCount >= 1, "keywordCount must be >= 1"); + + this.chatModel = chatModel; + this.keywordsTemplate = new PromptTemplate(String.format(KEYWORDS_TEMPLATE, keywordCount)); + } + + /** + * Create a new {@link KeywordMetadataEnricher} instance. + * @param chatModel the model predictor to use for keyword extraction. + * @param keywordsTemplate the prompt template to use for keyword extraction. + */ + public KeywordMetadataEnricher(ChatModel chatModel, PromptTemplate keywordsTemplate) { + Assert.notNull(chatModel, "chatModel must not be null"); + Assert.notNull(keywordsTemplate, "keywordsTemplate must not be null"); this.chatModel = chatModel; - this.keywordCount = keywordCount; + this.keywordsTemplate = keywordsTemplate; } @Override public List apply(List documents) { for (Document document : documents) { - - var template = new PromptTemplate(String.format(KEYWORDS_TEMPLATE, this.keywordCount)); - Prompt prompt = template.create(Map.of(CONTEXT_STR_PLACEHOLDER, document.getText())); + Prompt prompt = this.keywordsTemplate.create(Map.of(CONTEXT_STR_PLACEHOLDER, document.getText())); String keywords = this.chatModel.call(prompt).getResult().getOutput().getText(); - document.getMetadata().putAll(Map.of(EXCERPT_KEYWORDS_METADATA_KEY, keywords)); + document.getMetadata().put(EXCERPT_KEYWORDS_METADATA_KEY, keywords); } return documents; } + // Exposed for testing purposes + PromptTemplate getKeywordsTemplate() { + return this.keywordsTemplate; + } + + public static Builder builder(ChatModel chatModel) { + return new Builder(chatModel); + } + + public static class Builder { + + private final ChatModel chatModel; + + private int keywordCount; + + private PromptTemplate keywordsTemplate; + + public Builder(ChatModel chatModel) { + Assert.notNull(chatModel, "The chatModel must not be null"); + this.chatModel = chatModel; + } + + public Builder keywordCount(int keywordCount) { + Assert.isTrue(keywordCount >= 1, "The keywordCount must be >= 1"); + this.keywordCount = keywordCount; + return this; + } + + public Builder keywordsTemplate(PromptTemplate keywordsTemplate) { + Assert.notNull(keywordsTemplate, "The keywordsTemplate must not be null"); + this.keywordsTemplate = keywordsTemplate; + return this; + } + + public KeywordMetadataEnricher build() { + if (this.keywordsTemplate != null) { + + if (this.keywordCount != 0) { + logger.warn("keywordCount will be ignored as keywordsTemplate is set."); + } + + return new KeywordMetadataEnricher(this.chatModel, this.keywordsTemplate); + } + + return new KeywordMetadataEnricher(this.chatModel, this.keywordCount); + } + + } + } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/transformer/KeywordMetadataEnricherTest.java b/spring-ai-model/src/test/java/org/springframework/ai/model/transformer/KeywordMetadataEnricherTest.java new file mode 100644 index 00000000000..a708671d405 --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/transformer/KeywordMetadataEnricherTest.java @@ -0,0 +1,168 @@ +package org.springframework.ai.model.transformer; + +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.document.Document; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.*; +import static org.springframework.ai.model.transformer.KeywordMetadataEnricher.*; + +/** + * @author YunKui Lu + */ +@ExtendWith(MockitoExtension.class) +class KeywordMetadataEnricherTest { + + @Mock + private ChatModel chatModel; + + @Captor + private ArgumentCaptor promptCaptor; + + private final String CUSTOM_TEMPLATE = "Custom template: {context_str}"; + + @Test + void testUseWithDefaultTemplate() { + // 1. Prepare test data + // @formatter:off + List documents = List.of( + new Document("content1"), + new Document("content2"), + new Document("content3"));// @formatter:on + int keywordCount = 3; + + // 2. Mock + given(chatModel.call(any(Prompt.class))).willReturn( + new ChatResponse(List.of(new Generation(new AssistantMessage("keyword1-1, keyword1-2, keyword1-3")))), + new ChatResponse(List.of(new Generation(new AssistantMessage("keyword2-1, keyword2-2, keyword2-3")))), + new ChatResponse(List.of(new Generation(new AssistantMessage("keyword3-1, keyword3-2, keyword3-3"))))); + + // 3. Create instance + KeywordMetadataEnricher keywordMetadataEnricher = new KeywordMetadataEnricher(chatModel, keywordCount); + + // 4. Apply + keywordMetadataEnricher.apply(documents); + + // 5. Assert + verify(chatModel, times(3)).call(promptCaptor.capture()); + + assertThat(promptCaptor.getAllValues().get(0).getUserMessage().getText()) + .isEqualTo(getDefaultTemplatePromptText(keywordCount, "content1")); + assertThat(promptCaptor.getAllValues().get(1).getUserMessage().getText()) + .isEqualTo(getDefaultTemplatePromptText(keywordCount, "content2")); + assertThat(promptCaptor.getAllValues().get(2).getUserMessage().getText()) + .isEqualTo(getDefaultTemplatePromptText(keywordCount, "content3")); + + assertThat(documents.get(0).getMetadata()).containsEntry(EXCERPT_KEYWORDS_METADATA_KEY, + "keyword1-1, keyword1-2, keyword1-3"); + assertThat(documents.get(1).getMetadata()).containsEntry(EXCERPT_KEYWORDS_METADATA_KEY, + "keyword2-1, keyword2-2, keyword2-3"); + assertThat(documents.get(2).getMetadata()).containsEntry(EXCERPT_KEYWORDS_METADATA_KEY, + "keyword3-1, keyword3-2, keyword3-3"); + } + + @Test + void testUseCustomTemplate() { + // 1. Prepare test data + // @formatter:off + List documents = List.of( + new Document("content1"), + new Document("content2"), + new Document("content3"));// @formatter:on + PromptTemplate promptTemplate = new PromptTemplate(CUSTOM_TEMPLATE); + + // 2. Mock + given(chatModel.call(any(Prompt.class))).willReturn( + new ChatResponse(List.of(new Generation(new AssistantMessage("keyword1-1, keyword1-2, keyword1-3")))), + new ChatResponse(List.of(new Generation(new AssistantMessage("keyword2-1, keyword2-2, keyword2-3")))), + new ChatResponse(List.of(new Generation(new AssistantMessage("keyword3-1, keyword3-2, keyword3-3"))))); + + // 3. Create instance + KeywordMetadataEnricher keywordMetadataEnricher = new KeywordMetadataEnricher(this.chatModel, promptTemplate); + + // 4. Apply + keywordMetadataEnricher.apply(documents); + + // 5. Assert + verify(chatModel, times(documents.size())).call(promptCaptor.capture()); + + assertThat(promptCaptor.getAllValues().get(0).getUserMessage().getText()) + .isEqualTo("Custom template: content1"); + assertThat(promptCaptor.getAllValues().get(1).getUserMessage().getText()) + .isEqualTo("Custom template: content2"); + assertThat(promptCaptor.getAllValues().get(2).getUserMessage().getText()) + .isEqualTo("Custom template: content3"); + + assertThat(documents.get(0).getMetadata()).containsEntry(EXCERPT_KEYWORDS_METADATA_KEY, + "keyword1-1, keyword1-2, keyword1-3"); + assertThat(documents.get(1).getMetadata()).containsEntry(EXCERPT_KEYWORDS_METADATA_KEY, + "keyword2-1, keyword2-2, keyword2-3"); + assertThat(documents.get(2).getMetadata()).containsEntry(EXCERPT_KEYWORDS_METADATA_KEY, + "keyword3-1, keyword3-2, keyword3-3"); + } + + @Test + void testConstructorThrowsException() { + assertThrows(IllegalArgumentException.class, () -> new KeywordMetadataEnricher(null, 3), + "chatModel must not be null"); + + assertThrows(IllegalArgumentException.class, () -> new KeywordMetadataEnricher(chatModel, 0), + "keywordCount must be >= 1"); + + assertThrows(IllegalArgumentException.class, () -> new KeywordMetadataEnricher(chatModel, null), + "keywordsTemplate must not be null"); + } + + @Test + void testBuilderThrowsException() { + assertThrows(IllegalArgumentException.class, () -> KeywordMetadataEnricher.builder(null), + "The chatModel must not be null"); + + Builder builder = builder(chatModel); + assertThrows(IllegalArgumentException.class, () -> builder.keywordCount(0), "The keywordCount must be >= 1"); + + assertThrows(IllegalArgumentException.class, () -> builder.keywordsTemplate(null), + "The keywordsTemplate must not be null"); + } + + @Test + void testBuilderWithKeywordCount() { + int keywordCount = 3; + KeywordMetadataEnricher enricher = builder(chatModel).keywordCount(keywordCount).build(); + + assertThat(enricher.getKeywordsTemplate().getTemplate()) + .isEqualTo(String.format(KEYWORDS_TEMPLATE, keywordCount)); + } + + @Test + void testBuilderWithKeywordsTemplate() { + PromptTemplate template = new PromptTemplate(CUSTOM_TEMPLATE); + KeywordMetadataEnricher enricher = builder(chatModel).keywordsTemplate(template).build(); + + assertThat(enricher).extracting("chatModel", "keywordsTemplate").containsExactly(chatModel, template); + } + + private String getDefaultTemplatePromptText(int keywordCount, String documentContent) { + PromptTemplate promptTemplate = new PromptTemplate(String.format(KEYWORDS_TEMPLATE, keywordCount)); + Prompt prompt = promptTemplate.create(Map.of(CONTEXT_STR_PLACEHOLDER, documentContent)); + return prompt.getContents(); + } + +}