diff --git a/backend/prompt_studio/prompt_profile_manager_v2/migrations/0002_alter_profilemanager_is_summarize_llm.py b/backend/prompt_studio/prompt_profile_manager_v2/migrations/0002_alter_profilemanager_is_summarize_llm.py deleted file mode 100644 index dc635f08f3..0000000000 --- a/backend/prompt_studio/prompt_profile_manager_v2/migrations/0002_alter_profilemanager_is_summarize_llm.py +++ /dev/null @@ -1,20 +0,0 @@ -# Generated by Django 4.2.1 on 2025-07-28 12:19 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("prompt_profile_manager_v2", "0001_initial"), - ] - - operations = [ - migrations.AlterField( - model_name="profilemanager", - name="is_summarize_llm", - field=models.BooleanField( - db_comment="DEPRECATED: Default LLM Profile used for summarizing. Use CustomTool.summarize_llm_adapter instead.", - default=False, - ), - ), - ] diff --git a/backend/prompt_studio/prompt_profile_manager_v2/migrations/0002_alter_profilemanager_retrieval_strategy.py b/backend/prompt_studio/prompt_profile_manager_v2/migrations/0002_merged_retrieval_and_summarize.py similarity index 68% rename from backend/prompt_studio/prompt_profile_manager_v2/migrations/0002_alter_profilemanager_retrieval_strategy.py rename to backend/prompt_studio/prompt_profile_manager_v2/migrations/0002_merged_retrieval_and_summarize.py index 56c8f8f92e..e99706c4f7 100644 --- a/backend/prompt_studio/prompt_profile_manager_v2/migrations/0002_alter_profilemanager_retrieval_strategy.py +++ b/backend/prompt_studio/prompt_profile_manager_v2/migrations/0002_merged_retrieval_and_summarize.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.1 on 2025-07-30 09:11 +# Generated by Django 4.2.1 on 2025-08-05 10:27 from django.db import migrations, models @@ -9,11 +9,20 @@ class Migration(migrations.Migration): ] operations = [ + migrations.AlterField( + model_name="profilemanager", + name="is_summarize_llm", + field=models.BooleanField( + db_comment="DEPRECATED: Default LLM Profile used for summarizing. 
Use CustomTool.summarize_llm_adapter instead.", + default=False, + ), + ), migrations.AlterField( model_name="profilemanager", name="retrieval_strategy", field=models.TextField( blank=True, + default="simple", choices=[ ("simple", "Simple retrieval"), ("subquestion", "Subquestion retrieval"), diff --git a/backend/prompt_studio/prompt_profile_manager_v2/migrations/0003_alter_profilemanager_retrieval_strategy.py b/backend/prompt_studio/prompt_profile_manager_v2/migrations/0003_alter_profilemanager_retrieval_strategy.py deleted file mode 100644 index f6fbca8ed4..0000000000 --- a/backend/prompt_studio/prompt_profile_manager_v2/migrations/0003_alter_profilemanager_retrieval_strategy.py +++ /dev/null @@ -1,30 +0,0 @@ -# Generated by Django 4.2.1 on 2025-08-01 10:04 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("prompt_profile_manager_v2", "0002_alter_profilemanager_retrieval_strategy"), - ] - - operations = [ - migrations.AlterField( - model_name="profilemanager", - name="retrieval_strategy", - field=models.TextField( - blank=True, - choices=[ - ("simple", "Simple retrieval"), - ("subquestion", "Subquestion retrieval"), - ("fusion", "Fusion retrieval"), - ("recursive", "Recursive retrieval"), - ("router", "Router retrieval"), - ("keyword_table", "Keyword table retrieval"), - ("automerging", "Auto-merging retrieval"), - ], - db_comment="Field to store the retrieval strategy for prompts", - default="simple", - ), - ), - ] diff --git a/unstract/prompt-service-helpers/__init__.py b/unstract/prompt-service-helpers/__init__.py new file mode 100644 index 0000000000..2ac55449aa --- /dev/null +++ b/unstract/prompt-service-helpers/__init__.py @@ -0,0 +1,3 @@ +"""Unstract Prompt Service Helpers - Agentic RAG Implementation""" + +__version__ = "0.1.0" diff --git a/unstract/prompt-service-helpers/agentic_extraction/__init__.py b/unstract/prompt-service-helpers/agentic_extraction/__init__.py new file mode 100644 index 0000000000..067eaafd35 --- /dev/null +++ b/unstract/prompt-service-helpers/agentic_extraction/__init__.py @@ -0,0 +1,11 @@ +"""Agentic extraction module for multi-agent data extraction using Autogen GraphFlow. +This module provides RAG-enabled agents for document data extraction. +""" + +from .agent_factory import AgentFactory +from .agentic_extraction_task import execute_agentic_extraction + +__all__ = [ + "execute_agentic_extraction", + "AgentFactory", +] diff --git a/unstract/prompt-service-helpers/agentic_extraction/agent_factory.py b/unstract/prompt-service-helpers/agentic_extraction/agent_factory.py new file mode 100644 index 0000000000..b788a04fd4 --- /dev/null +++ b/unstract/prompt-service-helpers/agentic_extraction/agent_factory.py @@ -0,0 +1,480 @@ +"""Agent factory for creating Autogen agents with RAG tool integration. +This factory creates specialized extraction agents that use RAG for document retrieval, +following the answer_prompt format from the current prompt service. +""" + +import logging +from typing import Any + +from autogen_agentchat.agents import AssistantAgent + +from .tools.rag_tool import RAGTool, RetrievalStrategy + +logger = logging.getLogger(__name__) + + +class AgentFactory: + """Factory for creating Autogen agents with RAG tool integration. + Focuses on creating agents that can use RAG for document-based extraction, + using answer_prompt format field configurations. 
+ """ + + def __init__(self, doc_id: str, platform_key: str | None = None): + """Initialize agent factory for creating field-specific RAG tools. + + Args: + doc_id: Document identifier for retrieval + platform_key: Platform API key + """ + self.doc_id = doc_id + self.platform_key = platform_key + + def create_agent( + self, + agent_config: dict[str, Any], + field_config: dict[str, Any] | None = None, + ) -> AssistantAgent: + """Create an Autogen agent based on configuration with field-specific RAG. + + Args: + agent_config: Agent configuration from digraph generation + field_config: Field configuration from answer_prompt format (optional) + + Returns: + AssistantAgent instance with field-specific RAG tool integration + """ + agent_name = agent_config.get("name", "extraction_agent") + agent_type = agent_config.get("agent_type", "AssistantAgent") + system_message = agent_config.get("system_message", "") + tools = agent_config.get("tools", []) + + # Create field-specific RAG tool if needed + rag_tool = None + if "rag" in tools and field_config: + rag_tool = self._create_field_specific_rag_tool(field_config) + elif "rag" in tools: + # Default RAG tool + rag_tool = RAGTool( + doc_id=self.doc_id, + platform_key=self.platform_key, + retrieval_strategy=RetrievalStrategy.SIMPLE, + ) + + # Enhance system message with RAG tool instructions + enhanced_system_message = self._enhance_system_message_with_rag( + system_message, agent_name, tools, field_config + ) + + # Create LLM config with field-specific settings + llm_config = self._create_llm_config(field_config) + + # Add RAG tool to function calling if agent uses RAG + if rag_tool: + llm_config["functions"] = [rag_tool.to_autogen_function()] + + # Create the agent + agent = AssistantAgent( + name=agent_name, + system_message=enhanced_system_message, + llm_config=llm_config, + ) + + logger.info( + f"Created agent: {agent_name} with tools: {tools} for field: {field_config.get('name', 'unknown') if field_config else 'none'}" + ) + return agent + + def _create_field_specific_rag_tool(self, field_config: dict[str, Any]) -> RAGTool: + """Create a RAG tool with field-specific configuration from answer_prompt format. + + Args: + field_config: Field configuration from answer_prompt + + Returns: + Configured RAG tool instance + """ + # Extract field-specific parameters + chunk_size = field_config.get("chunk-size", None) + chunk_overlap = field_config.get("chunk-overlap", None) + top_k = field_config.get("similarity-top-k", 5) + strategy = field_config.get("retrieval-strategy", "simple") + embedding_id = field_config.get("embedding", None) + vector_db_id = field_config.get("vector-db", None) + + # Map strategy string to enum + try: + retrieval_strategy = RetrievalStrategy(strategy) + except ValueError: + logger.warning(f"Unknown retrieval strategy: {strategy}, using simple") + retrieval_strategy = RetrievalStrategy.SIMPLE + + # Create RAG tool with field-specific configuration + return RAGTool( + doc_id=self.doc_id, + platform_key=self.platform_key, + embedding_instance_id=embedding_id, + vector_db_instance_id=vector_db_id, + top_k=top_k, + retrieval_strategy=retrieval_strategy, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + + def _create_llm_config( + self, field_config: dict[str, Any] | None = None + ) -> dict[str, Any]: + """Create LLM configuration with field-specific settings. 
+ + Args: + field_config: Field configuration from answer_prompt + + Returns: + LLM configuration dictionary + """ + llm_config = { + "model": "gpt-4", + "temperature": 0.1, + "timeout": 300, + } + + if field_config: + # Use field-specific LLM if provided + llm_instance_id = field_config.get("llm") + if llm_instance_id: + llm_config["llm_instance_id"] = llm_instance_id + + return llm_config + + def _enhance_system_message_with_rag( + self, + base_system_message: str, + agent_name: str, + tools: list[str], + field_config: dict[str, Any] | None = None, + ) -> str: + """Enhance system message with field-specific RAG tool instructions. + + Args: + base_system_message: Original system message + agent_name: Name of the agent + tools: List of tools available to the agent + field_config: Field configuration from answer_prompt + + Returns: + Enhanced system message with RAG instructions + """ + if "rag" not in tools: + return base_system_message + + # Get field-specific details + field_name = field_config.get("name", "unknown") if field_config else "unknown" + field_type = field_config.get("type", "text") if field_config else "text" + retrieval_strategy = ( + field_config.get("retrieval-strategy", "simple") if field_config else "simple" + ) + chunk_size = ( + field_config.get("chunk-size", "default") if field_config else "default" + ) + + rag_instructions = f""" + +RAG Tool Instructions: +You have access to a RAG (Retrieval-Augmented Generation) tool for document retrieval. +Document ID: {self.doc_id} +Field: {field_name} (Type: {field_type}) +Retrieval Strategy: {retrieval_strategy} +Chunk Size: {chunk_size} + +Use the RAG tool to: +1. Search for relevant information: rag_search(query="your search query") +2. Get context for your specific field: Use queries related to {field_name} +3. Verify your extractions: Search for confirming or contradicting information + +RAG Tool Usage for {field_type} field: +- Always search for relevant content before making extractions +- Use specific queries related to {field_name} +- Include multiple search queries if needed to gather comprehensive information +- The retrieval strategy is optimized for {retrieval_strategy} search patterns + +Example RAG usage for {field_name}: +- Primary search: rag_search(query="{field_name}") +- Context search: rag_search(query="{field_name} context information") +- Verification search: rag_search(query="confirm {field_name} details") + +For {field_type} fields: +- Focus on extracting precise {field_type} data +- Use the configured retrieval strategy ({retrieval_strategy}) for optimal results +- Chunk size is set to {chunk_size} for this field + +Remember: The RAG tool retrieves actual content from the document using answer_prompt compatible retrieval. +""" + + return base_system_message + rag_instructions + + def create_generic_extraction_agent( + self, + field_config: dict[str, Any], + required: bool = False, + ) -> AssistantAgent: + """Create a generic data extraction agent with field-specific RAG from answer_prompt format. + + Args: + field_config: Field configuration from answer_prompt outputs + required: Whether the field is required + + Returns: + AssistantAgent for generic extraction + """ + field_name = field_config.get("name", "unknown_field") + field_prompt = field_config.get("prompt", "") + field_type = field_config.get("type", "text") + + system_message = f"""You are a generic data extraction agent. Your task is to extract the field '{field_name}' from the document. 
+ +Field: {field_name} +Type: {field_type} +Prompt: {field_prompt} +Required: {required} +Document ID: {self.doc_id} + +Instructions: +1. Use the RAG tool to search for relevant information about this field +2. Search with multiple queries if needed to gather comprehensive information +3. Extract the specific information requested for this field +4. Ensure accuracy and completeness based on the field type ({field_type}) +5. If information is not found, clearly state it's unavailable + +Extraction Process: +1. Start by searching for content related to the field: rag_search(query="{field_name}") +2. Use field-specific queries: rag_search(query="{field_prompt}") +3. Analyze the retrieved information +4. Extract the specific value requested +5. Verify your extraction with additional searches if uncertain +6. Provide the final extracted value + +Output format: Return only the extracted value for the field, formatted according to type {field_type}.""" + + agent_config = { + "name": f"generic_extraction_agent_{field_name}", + "agent_type": "AssistantAgent", + "system_message": system_message, + "tools": ["rag"], + } + + return self.create_agent(agent_config, field_config) + + def create_table_extraction_agent( + self, + field_config: dict[str, Any], + required: bool = False, + ) -> AssistantAgent: + """Create a table data extraction agent with field-specific RAG from answer_prompt format. + + Args: + field_config: Field configuration from answer_prompt outputs + required: Whether the field is required + + Returns: + AssistantAgent for table extraction + """ + field_name = field_config.get("name", "unknown_table") + field_prompt = field_config.get("prompt", "") + table_settings = field_config.get("table_settings", {}) + + system_message = f"""You are a table data extraction agent. Your task is to extract the field '{field_name}' which contains tabular data. + +Field: {field_name} +Type: table +Prompt: {field_prompt} +Required: {required} +Table Settings: {table_settings} +Document ID: {self.doc_id} + +Instructions: +1. Use the RAG tool to search for tables and tabular data in the document +2. Look for structured data, tables, lists, or formatted information +3. Search with queries like "table", "data", the field name, and related terms +4. Extract and preserve the table structure +5. Format the output according to table_settings if provided + +Table Extraction Process: +1. Search for table-related content: rag_search(query="table {field_name}") +2. Search for structured data: rag_search(query="data list {field_name}") +3. Look for specific table elements mentioned in the field description +4. Use the field prompt for targeted searches: rag_search(query="{field_prompt}") +5. Combine and structure the found tabular information +6. Present in a clear, structured format + +Output format: Return the table data in a structured format (JSON, CSV-like, or clear text structure) based on table_settings.""" + + agent_config = { + "name": f"table_extraction_agent_{field_name}", + "agent_type": "AssistantAgent", + "system_message": system_message, + "tools": ["rag"], + } + + return self.create_agent(agent_config, field_config) + + def create_omniparse_extraction_agent( + self, + field_config: dict[str, Any], + required: bool = False, + ) -> AssistantAgent: + """Create an omniparse data extraction agent with field-specific RAG from answer_prompt format. 
+ + Args: + field_config: Field configuration from answer_prompt outputs + required: Whether the field is required + + Returns: + AssistantAgent for complex extraction + """ + field_name = field_config.get("name", "unknown_field") + field_prompt = field_config.get("prompt", "") + field_type = field_config.get("type", "text") + + system_message = f"""You are an omniparse data extraction agent specialized in complex document formats. Your task is to extract the field '{field_name}'. + +Field: {field_name} +Type: {field_type} +Prompt: {field_prompt} +Required: {required} +Document ID: {self.doc_id} + +Instructions: +1. Use the RAG tool to search for complex or visual content related to this field +2. Handle non-standard formats, visual elements, or complex layouts +3. Search with multiple approaches to find the information +4. Look for information that might be in charts, diagrams, or complex structures +5. Use comprehensive search strategies with the configured retrieval method + +Complex Extraction Process: +1. Broad search: rag_search(query="{field_name}") +2. Prompt-based search: rag_search(query="{field_prompt}") +3. Visual/format search: rag_search(query="chart diagram figure {field_name}") +4. Context search: rag_search(query="visual image content {field_name}") +5. Structure search: rag_search(query="layout format {field_name}") +6. Combine information from multiple sources +7. Extract and interpret the complex data + +Output format: Return the extracted information with clear indication of its source and format, formatted as {field_type}.""" + + agent_config = { + "name": f"omniparse_extraction_agent_{field_name}", + "agent_type": "AssistantAgent", + "system_message": system_message, + "tools": ["rag"], + } + + return self.create_agent(agent_config, field_config) + + def create_challenger_agent(self, fields: list[dict[str, Any]]) -> AssistantAgent: + """Create a challenger agent for validation with RAG using answer_prompt field configurations. + + Args: + fields: List of field configurations from answer_prompt outputs + + Returns: + AssistantAgent for validation + """ + field_names = [f.get("name", "") for f in fields] + field_types = {f.get("name", ""): f.get("type", "text") for f in fields} + + system_message = f"""You are a challenger agent responsible for validating extracted data quality using RAG. + +Fields to validate: {', '.join(field_names)} +Field types: {field_types} +Document ID: {self.doc_id} + +Your role: +1. Review all extracted field values from other agents +2. Use RAG to verify each extraction against the source document +3. Challenge incorrect, incomplete, or inconsistent extractions +4. Verify that required fields are properly extracted +5. Check for logical consistency between related fields +6. Validate data types match expected field types + +Validation process: +1. For each extracted field, use RAG to search for confirming evidence +2. Use queries like: rag_search(query="[field_name] [extracted_value]") +3. Look for contradictory information +4. Verify completeness and accuracy according to field type +5. Check format consistency (e.g., dates, numbers, emails) +6. Provide specific feedback for corrections + +If you find issues, clearly state what needs to be corrected and why. +If extractions are accurate and properly formatted, approve them for final collation. 
+ +Output format: For each field, state "APPROVED: [field_name]" or "REJECTED: [field_name] - [specific reason with evidence]" """ + + agent_config = { + "name": "challenger_agent", + "agent_type": "AssistantAgent", + "system_message": system_message, + "tools": ["rag"], + } + + return self.create_agent(agent_config) + + def create_collation_agent(self, fields: list[dict[str, Any]]) -> AssistantAgent: + """Create a data collation agent using answer_prompt field configurations. + + Args: + fields: List of field configurations from answer_prompt outputs + + Returns: + AssistantAgent for collation + """ + field_names = [f.get("name", "") for f in fields] + field_types = {f.get("name", ""): f.get("type", "text") for f in fields} + field_json_structure = ",\n".join( + [ + f' "{name}": "extracted_value" // Type: {field_types.get(name, "text")}' + for name in field_names + ] + ) + + system_message = f"""You are a data collation agent responsible for combining all validated field values into the final output. + +Fields to collate: {', '.join(field_names)} +Field types: {field_types} + +Your role: +1. Collect all validated field values from extraction agents +2. Resolve any remaining conflicts between extractions +3. Format the final output as a structured JSON object +4. Ensure all required fields are included +5. Apply any final formatting or transformations based on field types +6. Validate data types before final output + +Output format: +{{ +{field_json_structure} +}} + +Instructions: +- Use the most recent validated values for each field +- If multiple values exist for a field, use your judgment to select the best one +- Ensure the output JSON is properly formatted +- Include null values for fields that couldn't be extracted +- Respect field types when formatting values: + * text: string values + * number: numeric values + * date: ISO date format + * email: valid email format + * boolean: true/false + * json: valid JSON object + * table: structured table format +- Do not use RAG - only work with the extracted values provided by other agents + +Note: You do not have access to RAG tool - focus on organizing and formatting the extracted data.""" + + agent_config = { + "name": "data_collation_agent", + "agent_type": "AssistantAgent", + "system_message": system_message, + "tools": [], # No RAG for collation agent + } + + return self.create_agent(agent_config) diff --git a/unstract/prompt-service-helpers/agentic_extraction/agentic_extraction_task.py b/unstract/prompt-service-helpers/agentic_extraction/agentic_extraction_task.py new file mode 100644 index 0000000000..5ce9c57525 --- /dev/null +++ b/unstract/prompt-service-helpers/agentic_extraction/agentic_extraction_task.py @@ -0,0 +1,361 @@ +"""Celery task for agentic data extraction using Autogen GraphFlow. +This task takes the generated digraph and executes the multi-agent extraction workflow. 
+""" + +import json +import logging +from typing import Any + +# Autogen imports +from autogen_agentchat.agents import AssistantAgent +from autogen_agentchat.teams import DiGraphBuilder, GraphFlow +from celery import shared_task + +from .agent_factory import AgentFactory +from .tools.rag_tool import RAGTool + +logger = logging.getLogger(__name__) + + +@shared_task(bind=True, name="execute_agentic_extraction") +def execute_agentic_extraction( + self, + digraph_output: dict[str, Any], + answer_prompt_payload: dict[str, Any], + doc_id: str, + extraction_task: str | None = None, + platform_key: str | None = None, +) -> dict[str, Any]: + """Execute agentic data extraction using Autogen GraphFlow with answer_prompt format. + + Args: + digraph_output: Output from digraph generation containing graph structure + answer_prompt_payload: Answer prompt payload with outputs and tool_settings + doc_id: Document ID for RAG access + extraction_task: Custom extraction task description + platform_key: Platform API key for SDK operations + + Returns: + Dict containing: + - final_output: Final extracted data + - agent_results: Individual agent results + - execution_metadata: Execution details + - performance_metrics: Performance information + """ + task_id = self.request.id + logger.info(f"[Task {task_id}] Starting agentic data extraction") + + try: + # Step 1: Initialize RAG tool + rag_tool = RAGTool(doc_id=doc_id) + + # Step 2: Create agent factory with RAG tool + agent_factory = AgentFactory(rag_tool=rag_tool) + + # Step 3: Recreate the GraphFlow from digraph output + graph_flow = recreate_graph_flow_from_digraph( + digraph_output, agent_factory, extraction_spec + ) + + # Step 4: Define extraction task + if not extraction_task: + fields = extraction_spec.get("fields", []) + field_names = [f.get("name", "") for f in fields] + extraction_task = f""" + Extract the following fields from the document: + {', '.join(field_names)} + + Document ID: {doc_id} + + Each agent should: + 1. Focus on their assigned field(s) + 2. Use available tools (RAG, Calculator, String operations) as needed + 3. Provide accurate and complete extractions + 4. Pass results to the next agent in the workflow + + Final output should be a structured JSON with all extracted fields. + """ + + # Step 5: Execute the GraphFlow + logger.info( + f"[Task {task_id}] Executing GraphFlow with {len(graph_flow.participants)} agents" + ) + execution_results = execute_graph_flow_team(graph_flow, extraction_task) + + # Step 6: Process and format results + final_output = process_execution_results(execution_results, extraction_spec) + + # Step 7: Calculate performance metrics + performance_metrics = calculate_performance_metrics(execution_results, task_id) + + # Prepare final result + result = { + "final_output": final_output, + "agent_results": execution_results.get("all_results", {}), + "execution_metadata": { + "task_id": task_id, + "doc_id": doc_id, + "agent_count": len(graph_flow.participants), + "execution_status": execution_results.get("execution_status", "unknown"), + }, + "performance_metrics": performance_metrics, + } + + logger.info(f"[Task {task_id}] Agentic extraction completed successfully") + return result + + except Exception as e: + logger.error(f"[Task {task_id}] Error in agentic extraction: {str(e)}") + raise + + +def recreate_graph_flow_from_digraph( + digraph_output: dict[str, Any], + agent_factory: AgentFactory, + extraction_spec: dict[str, Any], +) -> GraphFlow: + """Recreate GraphFlow from digraph generation output. 
+
+
+def recreate_graph_flow_from_digraph(
+    digraph_output: dict[str, Any],
+    agent_factory: AgentFactory,
+    extraction_spec: dict[str, Any],
+) -> GraphFlow:
+    """Recreate GraphFlow from digraph generation output.
+
+    Args:
+        digraph_output: Output from digraph generation
+        agent_factory: Factory for creating agents with tools
+        extraction_spec: Original extraction specification
+
+    Returns:
+        GraphFlow instance ready for execution
+    """
+    # Extract agent configurations from digraph output
+    agent_configs = digraph_output.get("agents", [])
+    execution_plan = digraph_output.get("execution_plan", {})
+
+    # Create new DiGraphBuilder
+    builder = DiGraphBuilder()
+
+    # Create agents using the factory, pairing each agent with the field
+    # configuration it was generated for (matched by field name); answer_prompt
+    # payloads name the field list "outputs", parsed specs use "fields"
+    fields = extraction_spec.get("outputs") or extraction_spec.get("fields", [])
+    agents = []
+    for agent_config in agent_configs:
+        agent_name = agent_config.get("name", "")
+        field_config = next(
+            (f for f in fields if f.get("name") and f["name"] in agent_name),
+            None,
+        )
+        agent = agent_factory.create_agent(agent_config, field_config)
+        agents.append(agent)
+        builder.add_node(agent)
+
+    # Recreate edges based on execution plan
+    recreate_edges_from_plan(builder, agents, execution_plan, extraction_spec)
+
+    # Build graph and create GraphFlow
+    graph = builder.build()
+    graph_flow = GraphFlow(participants=builder.get_participants(), graph=graph)
+
+    return graph_flow
+
+
+def recreate_edges_from_plan(
+    builder: DiGraphBuilder,
+    agents: list[AssistantAgent],
+    execution_plan: dict[str, Any],
+    extraction_spec: dict[str, Any],
+) -> None:
+    """Recreate edges in the DiGraphBuilder based on execution plan.
+
+    Args:
+        builder: DiGraphBuilder instance
+        agents: List of created agents
+        execution_plan: Execution plan from digraph generation
+        extraction_spec: Original extraction specification
+    """
+    # Create agent lookup
+    agent_map = {agent.name: agent for agent in agents}
+
+    # Get dependencies from extraction spec
+    dependencies = extraction_spec.get("dependencies", {})
+    tool_settings = extraction_spec.get("tool_settings", {})
+
+    # Add dependency edges
+    for field_name, dep_fields in dependencies.items():
+        target_agent = None
+        for agent in agents:
+            if field_name in agent.name:
+                target_agent = agent
+                break
+
+        if target_agent:
+            for dep_field in dep_fields:
+                source_agent = None
+                for agent in agents:
+                    if dep_field in agent.name:
+                        source_agent = agent
+                        break
+
+                if source_agent:
+                    builder.add_edge(source_agent, target_agent)
+
+    # Add workflow edges (extraction → challenger → collation)
+    challenger_agent = None
+    collation_agent = None
+
+    for agent in agents:
+        if "challenger_agent" in agent.name:
+            challenger_agent = agent
+        elif "data_collation_agent" in agent.name:
+            collation_agent = agent
+
+    # Connect extraction agents to challenger or collation
+    if challenger_agent and tool_settings.get("enable_challenge", False):
+        # All extraction agents → challenger
+        for agent in agents:
+            if (
+                agent != challenger_agent
+                and agent != collation_agent
+                and "extraction_agent" in agent.name
+            ):
+                builder.add_edge(agent, challenger_agent)
+
+        # Challenger → collation with approval condition
+        if collation_agent:
+            builder.add_edge(
+                challenger_agent,
+                collation_agent,
+                condition=lambda msg: (
+                    "approved" in msg.content.lower()
+                    or "validated" in msg.content.lower()
+                    or "accepted" in msg.content.lower()
+                ),
+            )
+    elif collation_agent:
+        # Direct extraction agents → collation
+        for agent in agents:
+            if agent != collation_agent and "extraction_agent" in agent.name:
+                builder.add_edge(agent, collation_agent)
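The challenger-to-collation hand-off above is the only conditional edge in the rebuilt graph. Stripped of the surrounding lookups, and assuming `DiGraphBuilder.add_edge` accepts a callable condition exactly as used above, the gating pattern reduces to this sketch (agent names are illustrative):

```python
# Two-node approval gate: the collator only runs once the challenger approves.
builder = DiGraphBuilder()
builder.add_node(challenger).add_node(collator)
builder.add_edge(
    challenger,
    collator,
    # Traverse this edge only when the challenger's message signals approval.
    condition=lambda msg: "approved" in msg.content.lower(),
)
graph = builder.build()  # validates and freezes the graph structure
```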
+
+
+def execute_graph_flow_team(
+    graph_flow: GraphFlow,
+    extraction_task: str,
+) -> dict[str, Any]:
+    """Execute the GraphFlow team and collect results.
+
+    Args:
+        graph_flow: GraphFlow instance to execute
+        extraction_task: Task description for the team
+
+    Returns:
+        Execution results
+    """
+    results = {}
+    final_output = None
+    events = []
+
+    try:
+        logger.info("Starting GraphFlow execution...")
+
+        # GraphFlow.run_stream is an async generator; drive it to completion
+        # from this synchronous celery task with a dedicated event loop.
+        async def _consume_stream() -> None:
+            nonlocal final_output
+            async for event in graph_flow.run_stream(task=extraction_task):
+                source = getattr(event, "source", None)
+                content = getattr(event, "content", None)
+
+                # Log event details (the final TaskResult has no source/content)
+                logger.info(
+                    f"Event: {getattr(event, 'type', type(event).__name__)}, "
+                    f"Agent: {source}"
+                )
+
+                # Store event for analysis
+                events.append(
+                    {
+                        "type": getattr(event, "type", type(event).__name__),
+                        "source": source,
+                        "content": content[:200] if isinstance(content, str) else None,
+                        "timestamp": str(event.timestamp)
+                        if hasattr(event, "timestamp")
+                        else None,
+                    }
+                )
+
+                # Store agent results
+                if source and content:
+                    results[source] = content
+
+                # Check for final collation result
+                if source == "data_collation_agent" and isinstance(content, str):
+                    try:
+                        final_output = json.loads(content)
+                    except json.JSONDecodeError:
+                        final_output = content
+
+        asyncio.run(_consume_stream())
+
+        execution_status = "completed" if final_output else "incomplete"
+
+    except Exception as e:
+        logger.error(f"Error during GraphFlow execution: {str(e)}")
+        execution_status = "error"
+        final_output = None
+
+    return {
+        "final_output": final_output,
+        "all_results": results,
+        "execution_status": execution_status,
+        "events": events,
+    }
+
+
+def process_execution_results(
+    execution_results: dict[str, Any],
+    extraction_spec: dict[str, Any],
+) -> dict[str, Any]:
+    """Process and validate execution results.
+
+    Args:
+        execution_results: Raw execution results from GraphFlow
+        extraction_spec: Original extraction specification
+
+    Returns:
+        Processed final output
+    """
+    final_output = execution_results.get("final_output")
+
+    if not final_output:
+        # Try to construct output from individual agent results; answer_prompt
+        # payloads name the field list "outputs", parsed specs use "fields"
+        agent_results = execution_results.get("all_results", {})
+        fields = extraction_spec.get("outputs") or extraction_spec.get("fields", [])
+
+        constructed_output = {}
+        for field in fields:
+            field_name = field.get("name", "")
+            # Look for results from agents that handled this field
+            for agent_name, result in agent_results.items():
+                if field_name in agent_name:
+                    try:
+                        # Try to extract structured data; fall back to the raw
+                        # result if the parsed JSON lacks this field
+                        if isinstance(result, str) and result.strip().startswith("{"):
+                            parsed_result = json.loads(result)
+                            constructed_output[field_name] = parsed_result.get(
+                                field_name, result
+                            )
+                        else:
+                            constructed_output[field_name] = result
+                    except (json.JSONDecodeError, KeyError):
+                        constructed_output[field_name] = result
+                    break
+
+        if constructed_output:
+            final_output = constructed_output
+
+    return final_output or {}
+
+
+def calculate_performance_metrics(
+    execution_results: dict[str, Any],
+    task_id: str,
+) -> dict[str, Any]:
+    """Calculate performance metrics for the execution.
+ + Args: + execution_results: Execution results + task_id: Task identifier + + Returns: + Performance metrics + """ + events = execution_results.get("events", []) + agent_results = execution_results.get("all_results", {}) + + return { + "total_events": len(events), + "agents_executed": len(agent_results), + "execution_status": execution_results.get("execution_status", "unknown"), + "has_final_output": execution_results.get("final_output") is not None, + "task_id": task_id, + } diff --git a/unstract/prompt-service-helpers/agentic_extraction/tools/__init__.py b/unstract/prompt-service-helpers/agentic_extraction/tools/__init__.py new file mode 100644 index 0000000000..7daa41aef5 --- /dev/null +++ b/unstract/prompt-service-helpers/agentic_extraction/tools/__init__.py @@ -0,0 +1,9 @@ +"""Tools module for agentic extraction agents. +Currently implements RAG tool for document retrieval and search. +""" + +from .rag_tool import RAGTool + +__all__ = [ + "RAGTool", +] diff --git a/unstract/prompt-service-helpers/agentic_extraction/tools/rag_tool.py b/unstract/prompt-service-helpers/agentic_extraction/tools/rag_tool.py new file mode 100644 index 0000000000..263eab5d3d --- /dev/null +++ b/unstract/prompt-service-helpers/agentic_extraction/tools/rag_tool.py @@ -0,0 +1,503 @@ +"""RAG (Retrieval-Augmented Generation) tool for Autogen agents. +This tool enables agents to search and retrieve relevant information from the document +using the same retrieval techniques as the current prompt service with LlamaIndex integration. +""" + +import logging +from enum import Enum +from typing import Any + +from llama_index.core.query_engine import RouterQueryEngine, SubQuestionQueryEngine +from llama_index.core.retrievers import QueryFusionRetriever +from llama_index.core.selectors import LLMSingleSelector +from llama_index.core.tools import QueryEngineTool + +# LlamaIndex components for retrieval strategies +from llama_index.core.vector_stores import ExactMatchFilter, MetadataFilters + +# Import Unstract SDK components for RAG functionality +from unstract.sdk.embedding import Embedding +from unstract.sdk.index import Index +from unstract.sdk.tool.base import BaseTool +from unstract.sdk.vector_db import VectorDB + +logger = logging.getLogger(__name__) + + +class RetrievalStrategy(str, Enum): + """Retrieval strategies matching current prompt service.""" + + SIMPLE = "simple" + SUBQUESTION = "subquestion" + FUSION = "fusion" + RECURSIVE = "recursive" + ROUTER = "router" + KEYWORD_TABLE = "keyword_table" + AUTOMERGING = "automerging" + + +class RAGTool: + """RAG tool for Autogen agents to retrieve relevant document content. + Integrates with Unstract SDK and uses LlamaIndex retrieval strategies + matching the current prompt service implementation. + """ + + def __init__( + self, + doc_id: str, + platform_key: str | None = None, + embedding_instance_id: str | None = None, + vector_db_instance_id: str | None = None, + top_k: int = 5, + retrieval_strategy: RetrievalStrategy = RetrievalStrategy.SIMPLE, + chunk_size: int | None = None, + chunk_overlap: int | None = None, + ): + """Initialize RAG tool with retrieval strategies. 
+ + Args: + doc_id: Document identifier for retrieval + platform_key: Platform API key + embedding_instance_id: Embedding adapter instance ID + vector_db_instance_id: Vector DB adapter instance ID + top_k: Number of top results to retrieve + retrieval_strategy: Retrieval strategy to use + chunk_size: Chunk size for retrieval (0 = full document) + chunk_overlap: Chunk overlap for retrieval + """ + self.doc_id = doc_id + self.platform_key = platform_key or "default" + self.embedding_instance_id = embedding_instance_id or "default_embedding" + self.vector_db_instance_id = vector_db_instance_id or "default_vectordb" + self.top_k = top_k + self.retrieval_strategy = retrieval_strategy + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + + # Initialize SDK components + self._initialize_sdk_components() + + # Initialize LlamaIndex components + self._initialize_llama_index_retrievers() + + def _initialize_sdk_components(self): + """Initialize Unstract SDK components for RAG.""" + try: + # Create base tool for SDK operations + self.tool = BaseTool(platform_key=self.platform_key) + + # Initialize embedding + self.embedding = Embedding( + tool=self.tool, + adapter_instance_id=self.embedding_instance_id, + ) + + # Initialize vector DB + self.vector_db = VectorDB( + tool=self.tool, + adapter_instance_id=self.vector_db_instance_id, + embedding=self.embedding, + ) + + # Initialize index for querying + self.index = Index( + tool=self.tool, + run_id=f"rag_session_{self.doc_id}", + capture_metrics=True, + ) + + logger.info(f"RAG tool initialized for doc_id: {self.doc_id}") + + except Exception as e: + logger.error(f"Error initializing RAG tool: {str(e)}") + raise + + def _initialize_llama_index_retrievers(self): + """Initialize LlamaIndex retriever components.""" + try: + # Get vector store index from SDK + self.vector_query_engine = self.vector_db.get_vector_store_index() + + # Create document filter for this specific document + self.doc_filter = MetadataFilters( + filters=[ExactMatchFilter(key="doc_id", value=self.doc_id)] + ) + + # Initialize retrievers based on strategy + self._setup_retriever_strategy() + + logger.info( + f"LlamaIndex retrievers initialized with strategy: {self.retrieval_strategy}" + ) + + except Exception as e: + logger.error(f"Error initializing LlamaIndex retrievers: {str(e)}") + # Fallback to SDK-only approach if LlamaIndex fails + self.use_llamaindex = False + + def _setup_retriever_strategy(self): + """Setup retriever based on selected strategy.""" + if self.retrieval_strategy == RetrievalStrategy.SIMPLE: + self.retriever = self.vector_query_engine.as_retriever( + similarity_top_k=self.top_k, filters=self.doc_filter + ) + + elif self.retrieval_strategy == RetrievalStrategy.FUSION: + # Multi-query retriever with fusion + base_retriever = self.vector_query_engine.as_retriever( + similarity_top_k=self.top_k, filters=self.doc_filter + ) + self.retriever = QueryFusionRetriever( + retrievers=[base_retriever], + similarity_top_k=self.top_k, + num_queries=4, # Generate 4 query variations + mode="reciprocal_rerank", + use_async=True, + ) + + elif self.retrieval_strategy == RetrievalStrategy.SUBQUESTION: + # Sub-question query engine + query_engine = self.vector_query_engine.as_query_engine( + similarity_top_k=self.top_k, filters=self.doc_filter + ) + self.query_engine = SubQuestionQueryEngine.from_defaults( + query_engine_tools=[ + QueryEngineTool.from_defaults( + query_engine=query_engine, + description=f"Useful for retrieving specific facts about document {self.doc_id}", 
+ ) + ] + ) + + elif self.retrieval_strategy == RetrievalStrategy.ROUTER: + # Router query engine with multiple search strategies + vector_query_engine = self.vector_query_engine.as_query_engine( + similarity_top_k=self.top_k, filters=self.doc_filter + ) + + # Create multiple query engines for routing + query_engine_tools = [ + QueryEngineTool.from_defaults( + query_engine=vector_query_engine, + description="Useful for semantic search and finding related content", + ), + ] + + self.query_engine = RouterQueryEngine( + selector=LLMSingleSelector.from_defaults(), + query_engine_tools=query_engine_tools, + ) + + else: + # Default to simple retriever + self.retriever = self.vector_query_engine.as_retriever( + similarity_top_k=self.top_k, filters=self.doc_filter + ) + + def search(self, query: str, context: str | None = None) -> list[dict[str, Any]]: + """Search for relevant content using configured retrieval strategy. + + Args: + query: Search query + context: Additional context for the search + + Returns: + List of relevant content chunks in answer_prompt format + """ + try: + # Handle chunk_size = 0 case (full document retrieval) + if self.chunk_size == 0: + return self._retrieve_full_document() + + # Enhance query with context if provided + if context: + enhanced_query = f"{context} {query}" + else: + enhanced_query = query + + # Use appropriate retrieval method based on strategy + if self.retrieval_strategy in [ + RetrievalStrategy.SUBQUESTION, + RetrievalStrategy.ROUTER, + ]: + results = self._query_with_engine(enhanced_query) + else: + results = self._retrieve_with_llamaindex(enhanced_query) + + # Format results consistently + formatted_results = self._format_search_results(results, query) + + logger.info( + f"RAG search completed: {len(formatted_results)} results for query: {query[:50]}..." 
+ ) + return formatted_results + + except Exception as e: + logger.error(f"Error in RAG search: {str(e)}") + # Fallback to SDK search if LlamaIndex fails + return self._fallback_sdk_search(query, context) + + def _retrieve_with_llamaindex(self, query: str) -> Any: + """Retrieve using LlamaIndex retriever.""" + if hasattr(self, "retriever"): + return self.retriever.retrieve(query) + else: + # Fallback to simple retriever + retriever = self.vector_query_engine.as_retriever( + similarity_top_k=self.top_k, filters=self.doc_filter + ) + return retriever.retrieve(query) + + def _query_with_engine(self, query: str) -> Any: + """Query using LlamaIndex query engine.""" + if hasattr(self, "query_engine"): + response = self.query_engine.query(query) + return response.source_nodes if hasattr(response, "source_nodes") else [] + else: + return self._retrieve_with_llamaindex(query) + + def _retrieve_full_document(self) -> list[dict[str, Any]]: + """Retrieve full document content when chunk_size = 0.""" + try: + # Use SDK to get all document content + results = self.index.query_index( + embedding_instance_id=self.embedding_instance_id, + vector_db_instance_id=self.vector_db_instance_id, + doc_id=self.doc_id, + usage_kwargs={ + "query": "*", + "top_k": 1000, + }, # Large top_k for full retrieval + ) + return self._format_search_results(results, "full_document") + except Exception as e: + logger.error(f"Error retrieving full document: {str(e)}") + return [] + + def _fallback_sdk_search( + self, query: str, context: str | None = None + ) -> list[dict[str, Any]]: + """Fallback to SDK search when LlamaIndex fails.""" + try: + enhanced_query = f"{context} {query}" if context else query + + results = self.index.query_index( + embedding_instance_id=self.embedding_instance_id, + vector_db_instance_id=self.vector_db_instance_id, + doc_id=self.doc_id, + usage_kwargs={"query": enhanced_query, "top_k": self.top_k}, + ) + + return self._format_search_results(results, query) + except Exception as e: + logger.error(f"Error in SDK fallback search: {str(e)}") + return [] + + def _format_search_results(self, results: Any, query: str) -> list[dict[str, Any]]: + """Format search results to match answer_prompt context format. 
+ + Args: + results: Raw results from retrieval + query: Original search query + + Returns: + Formatted results list matching prompt service format + """ + formatted_results = [] + + try: + # Handle LlamaIndex node results + if hasattr(results, "__iter__") and not isinstance(results, str): + for idx, result in enumerate(results[: self.top_k]): + if hasattr(result, "text") and hasattr(result, "score"): + # LlamaIndex node format + formatted_result = { + "chunk_id": getattr(result, "id_", f"chunk_{idx}"), + "content": result.text, + "score": getattr(result, "score", 1.0), + "metadata": getattr(result, "metadata", {}), + "section": getattr(result, "metadata", {}).get( + "section", "unknown" + ), + } + formatted_results.append(formatted_result) + elif hasattr(result, "node"): + # Handle nested node structure + node = result.node + formatted_result = { + "chunk_id": getattr(node, "id_", f"chunk_{idx}"), + "content": getattr(node, "text", str(node)), + "score": getattr(result, "score", 1.0), + "metadata": getattr(node, "metadata", {}), + "section": getattr(node, "metadata", {}).get( + "section", "unknown" + ), + } + formatted_results.append(formatted_result) + elif isinstance(result, dict): + # Already formatted result + formatted_results.append(result) + else: + # Generic result + formatted_results.append( + { + "chunk_id": f"chunk_{idx}", + "content": str(result), + "score": 1.0, + "metadata": {}, + "section": "unknown", + } + ) + + # Handle SDK results + elif hasattr(results, "nodes") and results.nodes: + for idx, node in enumerate(results.nodes[: self.top_k]): + formatted_result = { + "chunk_id": getattr(node, "id_", f"chunk_{idx}"), + "content": getattr(node, "text", str(node)), + "score": getattr(node, "score", 1.0), + "metadata": getattr(node, "metadata", {}), + "section": getattr(node, "metadata", {}).get( + "section", "unknown" + ), + } + formatted_results.append(formatted_result) + + except Exception as e: + logger.error(f"Error formatting search results: {str(e)}") + + return formatted_results + + def get_context_for_field(self, field_name: str, field_prompt: str) -> str: + """Get relevant context for a specific field extraction using configured strategy. + + Args: + field_name: Name of the field to extract + field_prompt: Prompt/description for the field + + Returns: + Formatted context string matching answer_prompt format + """ + # Create field-specific query + query = f"{field_name}: {field_prompt}" + results = self.search(query, context=f"extracting {field_name}") + + # Format context like answer_prompt service + context_parts = [] + for result in results: + content = result.get("content", "") + score = result.get("score", 0) + section = result.get("section", "unknown") + + if content and score > 0.3: # Lower threshold for field context + # Add section information if available + if section != "unknown": + context_parts.append(f"[Section: {section}]\n{content}") + else: + context_parts.append(content) + + if context_parts: + context = "\n\n---------------\n\n".join(context_parts) + return f"Context:\n---------------\n{context}\n-----------------" + else: + return f"No specific context found for {field_name}. Please extract from available document content." + + def verify_extraction( + self, field_name: str, extracted_value: str, confidence_threshold: float = 0.7 + ) -> dict[str, Any]: + """Verify an extracted value against the document using RAG. 
+ + Args: + field_name: Name of the extracted field + extracted_value: The extracted value to verify + confidence_threshold: Minimum confidence for verification + + Returns: + Verification results with evidence + """ + # Search for content that might contradict or confirm the extraction + verification_query = f"{field_name} {extracted_value}" + results = self.search(verification_query, context="verification") + + verification_score = 0.0 + supporting_evidence = [] + contradicting_evidence = [] + + for result in results: + content = result.get("content", "").lower() + score = result.get("score", 0) + + # Enhanced verification logic + if extracted_value.lower() in content: + verification_score += score + supporting_evidence.append(result) + elif any( + keyword in content + for keyword in [ + "not", + "incorrect", + "wrong", + "different", + "instead", + "rather", + ] + ): + contradicting_evidence.append(result) + + is_verified = verification_score >= confidence_threshold + + return { + "is_verified": is_verified, + "confidence": min(verification_score, 1.0), + "supporting_evidence": supporting_evidence, + "contradicting_evidence": contradicting_evidence, + "verification_query": verification_query, + "strategy_used": self.retrieval_strategy.value, + } + + def get_tool_description(self) -> str: + """Get description of the RAG tool for agents.""" + return f""" +RAG Tool - Retrieval-Augmented Generation (Strategy: {self.retrieval_strategy.value}) +Document ID: {self.doc_id} +Chunk Size: {self.chunk_size or 'default'} +Top-K Results: {self.top_k} + +Functions: +- search(query, context=None): Search for relevant content using {self.retrieval_strategy.value} strategy +- get_context_for_field(field_name, field_description): Get formatted context for field extraction +- verify_extraction(field_name, extracted_value): Verify extracted values with evidence + +Usage examples: +- search("company financial data"): Find financial information +- get_context_for_field("revenue", "annual revenue amount"): Get context for revenue extraction +- verify_extraction("company_name", "Acme Corp"): Verify if company name is correct + +Returns structured results with content, relevance scores, sections, and metadata. + """.strip() + + def to_autogen_function(self) -> dict[str, Any]: + """Convert RAG tool to Autogen function format. + + Returns: + Function definition for Autogen agents + """ + return { + "name": "rag_search", + "description": self.get_tool_description(), + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query for finding relevant document content", + }, + "context": { + "type": "string", + "description": "Additional context to improve search relevance (optional)", + }, + }, + "required": ["query"], + }, + "function": self.search, + } diff --git a/unstract/prompt-service-helpers/digraph_generation/README.md b/unstract/prompt-service-helpers/digraph_generation/README.md new file mode 100644 index 0000000000..cecae4c039 --- /dev/null +++ b/unstract/prompt-service-helpers/digraph_generation/README.md @@ -0,0 +1,231 @@ +# Digraph Generation Module - Prompt Service Helpers + +This module generates directed graphs (digraphs) using Microsoft Autogen's DiGraphBuilder and GraphFlow for orchestrating sequential multi-agent data extraction workflows. 
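For orientation, the smallest sequential flow of the kind this module produces looks roughly like the sketch below. It assumes autogen-agentchat's published `DiGraphBuilder`/`GraphFlow` API; the model client and system messages are illustrative and not part of this module:

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import DiGraphBuilder, GraphFlow
from autogen_ext.models.openai import OpenAIChatCompletionClient

client = OpenAIChatCompletionClient(model="gpt-4o")  # illustrative model choice

extractor = AssistantAgent(
    "extractor", model_client=client,
    system_message="Extract the requested field from the document.",
)
collator = AssistantAgent(
    "collator", model_client=client,
    system_message="Format the extracted values as JSON.",
)

builder = DiGraphBuilder()
builder.add_node(extractor).add_node(collator)
builder.add_edge(extractor, collator)  # extractor runs first, then collator

flow = GraphFlow(participants=builder.get_participants(), graph=builder.build())
result = asyncio.run(flow.run(task="Extract the company name"))
```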
+ +## Autogen Integration + +This module directly uses Autogen components: +- **`autogen_agentchat.agents.AssistantAgent`** - For creating individual extraction agents +- **`autogen_agentchat.teams.DiGraphBuilder`** - For building directed graph structure +- **`autogen_agentchat.teams.GraphFlow`** - For sequential team execution + +The implementation creates actual Autogen objects (not code generation) for immediate execution. + +## Overview + +The digraph generation module takes extraction specifications from the answer_prompt payload and creates executable Autogen GraphFlow code. It analyzes field dependencies, determines required agents, and generates a workflow that coordinates specialized extraction agents. + +## Components + +### 1. **digraph_task.py** +Main celery task that generates Autogen GraphFlow code. + +**Key Features:** +- Parses extraction specifications from answer_prompt format +- Creates specialized agents based on field types and requirements +- Generates dependency graphs with proper execution order +- Outputs executable Python code using Autogen's DiGraphBuilder +- Supports parallel execution where possible + +**Task Name:** `generate_extraction_digraph` + +**Inputs:** +- `extraction_spec`: Extraction specification containing: + - `outputs`: List of fields to extract (same as "prompts" in answer_prompt) + - `tool_settings`: Configuration including `enable_challenge` + - `dependencies`: Field dependencies (optional) +- `previous_stage_output`: Output from chunking/embedding stage containing: + - `doc_id`: Document identifier for RAG access + +**Outputs:** +- `autogen_code`: Executable Python code for Autogen GraphFlow +- `agents`: List of agent configurations +- `edges`: List of edge configurations +- `execution_plan`: Execution flow information +- `metadata`: Additional processing information + +### 2. **spec_parser.py** +Parser for extraction specifications with predefined agent configurations. + +**Predefined Agent Types:** +1. **Generic Data Extraction Agent** + - Tools: RAG, Calculator + - Purpose: General text field extraction + +2. **Table Data Extraction Agent** + - Tools: Calculator + - Purpose: Tabular data extraction and processing + +3. **Omniparse Data Extraction Agent** + - Tools: Calculator + - Purpose: Complex layouts, visual elements, non-standard formats + +4. **Challenger Agent** (optional) + - Tools: RAG, Calculator + - Purpose: Validation and quality assurance + +5. **Data Collation Agent** (always included) + - Tools: String concatenation + - Purpose: Combining and formatting final output + +**Key Methods:** +- `parse(extraction_spec)`: Parse answer_prompt format to structured spec +- `create_agent_system_message()`: Generate system messages for agents +- `validate_spec()`: Validate the parsed specification + +## Agent Selection Logic + +The module automatically determines which agents to use based on field characteristics: + +- **Table fields**: Uses `table_data_extraction_agent` +- **Visual/complex fields**: Uses `omniparse_data_extraction_agent` +- **Standard text fields**: Uses `generic_data_extraction_agent` +- **Challenger**: Added if `tool_settings.enable_challenge = true` +- **Collation**: Always added as the final agent + +## Generated Autogen Code Structure + +The output follows the official Autogen DiGraphBuilder pattern: + +```python +from autogen_agentchat.agents import AssistantAgent +from autogen_agentchat.teams import DiGraphBuilder, GraphFlow + +# 1. 
Define agents +agent_field1 = AssistantAgent("agent_field1", system_message="...") +agent_field2 = AssistantAgent("agent_field2", system_message="...") +challenger_agent = AssistantAgent("challenger_agent", system_message="...") +collation_agent = AssistantAgent("collation_agent", system_message="...") + +# 2. Build the graph +builder = DiGraphBuilder() +builder.add_node(agent_field1).add_node(agent_field2).add_node(challenger_agent).add_node(collation_agent) + +# 3. Add edges with conditions +builder.add_edge(agent_field1, challenger_agent) +builder.add_edge(agent_field2, challenger_agent) +builder.add_edge(challenger_agent, collation_agent, condition=lambda msg: "approved" in msg.content.lower()) + +# 4. Build and run +graph = builder.build() +flow = GraphFlow(participants=builder.get_participants(), graph=graph) +stream = flow.run_stream(task="Extract data from document") +``` + +## Usage Example + +```python +from celery import Celery +from digraph_generation.digraph_task import generate_extraction_digraph + +# Example extraction specification (answer_prompt format) +extraction_spec = { + "outputs": [ + { + "name": "company_name", + "prompt": "Extract the company name from the document", + "type": "text", + "required": True + }, + { + "name": "financial_table", + "prompt": "Extract the financial data table", + "type": "table", + "required": True + } + ], + "tool_settings": { + "enable_challenge": True + }, + "dependencies": { + "financial_table": ["company_name"] # Table extraction depends on company name + } +} + +# Previous stage output from chunking/embedding +previous_output = { + "doc_id": "abc123-def456", + "chunk_count": 25, + "embedding_count": 25 +} + +# Generate the digraph +result = generate_extraction_digraph.delay( + extraction_spec=extraction_spec, + previous_stage_output=previous_output +) + +# Get the result +output = result.get() +print("Generated Autogen code:") +print(output["autogen_code"]) + +# Execute the generated code +exec(output["autogen_code"]) +``` + +## Dependency Analysis + +The module analyzes field dependencies in two ways: + +1. **Explicit Dependencies**: Specified in the `dependencies` field +2. **Implicit Dependencies**: Detected from variable references in prompts (e.g., `{{other_field}}`) + +Dependencies determine the execution order and edge conditions in the generated graph. + +## Integration with Celery Chain + +This task fits into the extraction pipeline as follows: + +``` +Text Extraction → Chunking & Embedding → Digraph Generation → Agent Execution +``` + +The `doc_id` from the chunking stage enables RAG functionality in the extraction agents. 
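Concretely, the hand-off between the last two stages can be sketched as follows, using the task signatures documented in this module; the `doc_id` value and the synchronous `.get()` calls are illustrative:

```python
# Two-step dispatch: the digraph generated from the spec feeds the execution task.
digraph = generate_extraction_digraph.delay(
    extraction_spec=extraction_spec,
    previous_stage_output={"doc_id": "abc123-def456"},  # from chunking/embedding
).get(timeout=600)

extraction = execute_agentic_extraction.delay(
    digraph_output=digraph,
    answer_prompt_payload=extraction_spec,
    doc_id="abc123-def456",
).get(timeout=600)
```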
+ +## Agent System Messages + +Each agent type has a specialized system message template: + +- **Extraction agents**: Field-specific instructions with RAG/calculator guidance +- **Challenger agent**: Validation instructions with quality criteria +- **Collation agent**: JSON formatting and output structure instructions + +## Graph Features + +The generated graphs support: + +- **Parallel Execution**: Independent fields can be extracted simultaneously +- **Conditional Edges**: Challenger approval gates and dependency conditions +- **Proper Termination**: Well-defined endpoints through collation agent +- **Error Handling**: Built into agent system messages and conditions + +## Configuration + +Environment variables for customization: +- `DIGRAPH_QUEUE`: Celery queue for digraph tasks (default: "processing_queue") +- `DIGRAPH_TIMEOUT`: Task timeout in seconds (default: 600) +- `ENABLE_CHALLENGER_DEFAULT`: Default challenger setting (default: false) + +## Dependencies + +Required packages: +- `celery>=5.3.0` +- `autogen-agentchat` (for Autogen GraphFlow) +- Standard Python libraries + +## Future Enhancements + +- Support for custom agent types +- Dynamic tool assignment based on document characteristics +- Advanced condition logic for complex workflows +- Integration with Unstract's tool registry for dynamic tool discovery +- Support for loops and iterative refinement workflows + +## Validation + +The module includes validation for: +- Field name uniqueness +- Dependency consistency (no circular dependencies) +- Agent configuration completeness +- Graph connectivity and termination conditions diff --git a/unstract/prompt-service-helpers/digraph_generation/__init__.py b/unstract/prompt-service-helpers/digraph_generation/__init__.py new file mode 100644 index 0000000000..281fb0b02b --- /dev/null +++ b/unstract/prompt-service-helpers/digraph_generation/__init__.py @@ -0,0 +1,36 @@ +"""Digraph generation module for Autogen GraphFlow. + +This module provides celery tasks for generating directed graphs that can be used +with Microsoft Autogen's GraphFlow for orchestrating multi-agent data extraction. +""" + +from .digraph_task import execute_graph_flow, generate_extraction_digraph +from .executor import execute_extraction_workflow, validate_graph_flow_data +from .spec_parser import ExtractionSpecParser + +__all__ = [ + "generate_extraction_digraph", + "execute_graph_flow", + "ExtractionSpecParser", + "execute_extraction_workflow", + "validate_graph_flow_data", +] + +# Celery task registration information +CELERY_TASKS = { + "generate_extraction_digraph": { + "task": generate_extraction_digraph, + "name": "generate_extraction_digraph", + "queue": "processing_queue", + "routing_key": "digraph.generation", + "priority": 6, # Higher than chunking + "rate_limit": "50/m", + "time_limit": 600, # 10 minutes + "soft_time_limit": 540, # 9 minutes + } +} + +# Module metadata +__version__ = "0.1.0" +__author__ = "Unstract Team" +__description__ = "Autogen digraph generation for data extraction workflows" diff --git a/unstract/prompt-service-helpers/digraph_generation/digraph_task.py b/unstract/prompt-service-helpers/digraph_generation/digraph_task.py new file mode 100644 index 0000000000..c47a860ad7 --- /dev/null +++ b/unstract/prompt-service-helpers/digraph_generation/digraph_task.py @@ -0,0 +1,928 @@ +"""Celery task for generating Autogen GraphFlow for sequential data extraction. 
+This task creates actual Autogen DiGraphBuilder instances and GraphFlow +to orchestrate extraction agents in a sequential/dependency-based execution pattern. + +The implementation uses: +- autogen_agentchat.agents.AssistantAgent for individual extraction agents +- autogen_agentchat.teams.DiGraphBuilder to create directed graph structure +- autogen_agentchat.teams.GraphFlow for sequential team execution + +This is designed for sequential team processing where agents execute in a specific +order based on field dependencies and workflow requirements. +""" + +import json +import logging +from typing import Any + +# Autogen imports - using the correct components for sequential teams +from autogen_agentchat.agents import AssistantAgent +from autogen_agentchat.teams import DiGraphBuilder, GraphFlow +from celery import shared_task + +# Note: For sequential team processing, we use DiGraphBuilder which creates a directed graph +# that can be executed by GraphFlow for proper sequential/dependency-based execution +from .spec_parser import ExtractionSpecParser + +logger = logging.getLogger(__name__) + + +@shared_task(bind=True, name="generate_extraction_digraph") +def generate_extraction_digraph( + self, + extraction_spec: dict[str, Any], + previous_stage_output: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Generate Autogen GraphFlow for data extraction. + + Args: + extraction_spec: Data extraction specification containing: + - outputs: List of fields to extract with prompts + - tool_settings: Configuration for tools and agents + - dependencies: Field dependencies (optional) + previous_stage_output: Output from chunking/embedding stage containing: + - doc_id: Document identifier for RAG access + + Returns: + Dict containing: + - graph_flow: Serialized GraphFlow object + - agents: List of created agents + - graph: DiGraph structure + - execution_plan: Execution flow information + - metadata: Additional information + """ + task_id = self.request.id + logger.info(f"[Task {task_id}] Starting Autogen digraph generation") + + try: + # Step 1: Parse the extraction specification + parser = ExtractionSpecParser() + parsed_spec = parser.parse(extraction_spec) + + # Add doc_id from previous stage if available + doc_id = None + if previous_stage_output and "doc_id" in previous_stage_output: + doc_id = previous_stage_output["doc_id"] + logger.info(f"[Task {task_id}] Using doc_id: {doc_id}") + + # Step 2: Get field and agent information + fields = parsed_spec.get("fields", []) + required_agents = parsed_spec.get("required_agents", []) + dependencies = parsed_spec.get("dependencies", {}) + + # Step 3: Generate agent configurations using the parser + agents = generate_agent_configs( + fields=fields, + required_agents=required_agents, + tool_settings=parsed_spec.get("tool_settings", {}), + parser=parser, + doc_id=doc_id, + ) + + # Step 4: Create actual Autogen agents + autogen_agents = create_autogen_agents(agents, doc_id) + + # Step 5: Create Autogen DiGraphBuilder for sequential team processing + # DiGraphBuilder creates a directed graph that enforces execution order + builder = DiGraphBuilder() + + # Add all agents as nodes to the DiGraphBuilder + for agent in autogen_agents: + builder.add_node(agent) + + # Step 6: Add edges to define sequential execution order based on dependencies + add_edges_to_builder( + builder, autogen_agents, dependencies, parsed_spec.get("tool_settings", {}) + ) + + # Step 7: Build the directed graph using Autogen's DiGraphBuilder + # This creates the actual graph structure that 
GraphFlow will execute + graph = builder.build() + + # Step 8: Create GraphFlow with the built graph for sequential execution + # GraphFlow will execute agents in the order defined by the directed graph + graph_flow = GraphFlow(participants=builder.get_participants(), graph=graph) + + # Verify we're using proper Autogen components + verify_autogen_components(builder, graph_flow) + + # Step 9: Create execution plan + execution_plan = create_execution_plan(agents, dependencies) + + # Prepare result + result = { + "graph_flow": serialize_graph_flow(graph_flow), + "agents": [serialize_agent(agent) for agent in autogen_agents], + "graph": serialize_graph(graph), + "execution_plan": execution_plan, + "metadata": { + "task_id": task_id, + "total_agents": len(autogen_agents), + "doc_id": doc_id, + "extraction_spec": parsed_spec, + }, + } + + logger.info( + f"[Task {task_id}] Autogen digraph generated successfully with " + f"{len(autogen_agents)} agents" + ) + return result + + except Exception as e: + logger.error(f"[Task {task_id}] Error generating Autogen digraph: {str(e)}") + raise + + +def analyze_field_dependencies( + fields: list[dict[str, Any]], explicit_dependencies: dict[str, list[str]] +) -> dict[str, list[str]]: + """Analyze dependencies between fields based on variable references. + + Args: + fields: List of field specifications + explicit_dependencies: Explicitly defined dependencies + + Returns: + Dictionary mapping field names to their dependencies + """ + dependencies = {} + + # Initialize with explicit dependencies + for field_name, deps in explicit_dependencies.items(): + dependencies[field_name] = list(deps) + + # Analyze implicit dependencies from variable references + for field in fields: + field_name = field.get("name", "") + prompt_text = field.get("prompt", "") + + # Find variable references in the prompt (e.g., {{other_field}}) + import re + + variable_pattern = r"\{\{(\w+)\}\}" + referenced_fields = re.findall(variable_pattern, prompt_text) + + if field_name not in dependencies: + dependencies[field_name] = [] + + # Add referenced fields as dependencies + for ref_field in referenced_fields: + if ref_field != field_name and ref_field not in dependencies[field_name]: + dependencies[field_name].append(ref_field) + + return dependencies + + +def generate_agent_configs( + fields: list[dict[str, Any]], + required_agents: list[str], + tool_settings: dict[str, Any], + parser: ExtractionSpecParser, + doc_id: str | None = None, +) -> list[dict[str, Any]]: + """Generate agent configurations based on required agents and fields. 
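+
+    Each returned entry is shaped roughly as follows (values illustrative):
+
+        {"name": "generic_data_extraction_agent_company_name",
+         "agent_type": "AssistantAgent",
+         "system_message": "...",
+         "tools": ["rag", "calculator"],
+         "field_config": {"field_name": "company_name",
+                          "extraction_agent_type": "generic_data_extraction_agent",
+                          "doc_id": "abc123"}}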
+
+    Args:
+        fields: List of field specifications
+        required_agents: List of required agent types
+        tool_settings: Tool configuration settings
+        parser: ExtractionSpecParser instance
+        doc_id: Document ID for RAG access
+
+    Returns:
+        List of agent configurations
+    """
+    agents = []
+
+    # Group fields by agent type
+    fields_by_agent = {}
+    for field in fields:
+        agent_type = field.get("agent_type", "generic_data_extraction_agent")
+        if agent_type not in fields_by_agent:
+            fields_by_agent[agent_type] = []
+        fields_by_agent[agent_type].append(field)
+
+    # Create extraction agents
+    for agent_type in required_agents:
+        if agent_type in ["challenger_agent", "data_collation_agent"]:
+            continue  # Handle these separately
+
+        # Create one agent instance per field for this agent type
+        if agent_type in fields_by_agent:
+            for field in fields_by_agent[agent_type]:
+                field_name = field.get("name", "")
+                system_message = parser.create_agent_system_message(agent_type, field)
+
+                agent_config = {
+                    "name": f"{agent_type}_{field_name}",
+                    "agent_type": "AssistantAgent",
+                    "system_message": system_message,
+                    "tools": parser.get_agent_config(agent_type).get("tools", []),
+                    "field_config": {
+                        "field_name": field_name,
+                        "extraction_agent_type": agent_type,
+                        "doc_id": doc_id,
+                    },
+                }
+                agents.append(agent_config)
+
+    # Add challenger agent if enabled
+    if "challenger_agent" in required_agents:
+        challenger_system_message = parser.create_agent_system_message("challenger_agent")
+        challenger_agent = {
+            "name": "challenger_agent",
+            "agent_type": "AssistantAgent",
+            "system_message": challenger_system_message,
+            "tools": parser.get_agent_config("challenger_agent").get("tools", []),
+            "field_config": {
+                "agent_role": "challenger",
+                "doc_id": doc_id,
+            },
+        }
+        agents.append(challenger_agent)
+
+    # Add collation agent (always needed)
+    if "data_collation_agent" in required_agents:
+        collation_system_message = parser.create_agent_system_message(
+            "data_collation_agent", all_fields=fields
+        )
+        collation_agent = {
+            "name": "data_collation_agent",
+            "agent_type": "AssistantAgent",
+            "system_message": collation_system_message,
+            "tools": parser.get_agent_config("data_collation_agent").get("tools", []),
+            "field_config": {
+                "agent_role": "collation",
+                "output_fields": [f.get("name") for f in fields],
+            },
+        }
+        agents.append(collation_agent)
+
+    return agents
+
+
+def determine_agent_config(
+    field: dict[str, Any],
+    tool_settings: dict[str, Any],
+    doc_id: str | None = None,
+) -> tuple[str, str]:
+    """Determine agent type and system message for a field.
+
+    Args:
+        field: Field specification
+        tool_settings: Tool settings
+        doc_id: Document ID for RAG
+
+    Returns:
+        Tuple of (agent_type, system_message)
+    """
+    field_name = field.get("name", "")
+    prompt = field.get("prompt", "")
+    field_type = field.get("type", "text").lower()
+
+    # Determine if this is a table extraction
+    if field_type == "table" or "table" in prompt.lower():
+        agent_type = "AssistantAgent"
+        system_message = f"""You are a table extraction specialist. Your task is to extract the field '{field_name}' from the document.
+
+Field Description: {prompt}
+
+Instructions:
+1. Search for tabular data in the document that matches the field description
+2. Extract the table data accurately, preserving structure
+3. Format the output as requested
+4. Use the calculator tool for any numerical checks or aggregations
+5. 
Be precise and only extract what is explicitly requested + +Tools available: calculator""" + + elif field_type in ["image", "chart", "diagram"] or any( + keyword in prompt.lower() for keyword in ["image", "chart", "diagram", "visual"] + ): + agent_type = "AssistantAgent" + system_message = f"""You are a visual content extraction specialist. Your task is to extract the field '{field_name}' from the document. + +Field Description: {prompt} + +Instructions: +1. Analyze visual elements in the document (charts, diagrams, images) +2. Extract the requested information from visual content +3. Provide accurate descriptions and data from visual elements +4. Use calculation tools if needed for data processing +5. Be thorough in your visual analysis + +Tools available: calculator""" + + else: + # Generic text extraction + agent_type = "AssistantAgent" + rag_instruction = ( + "5. Use the RAG tool to search for relevant context if needed" + if doc_id + else "5. Work with the provided document content" + ) + + system_message = f"""You are a data extraction specialist. Your task is to extract the field '{field_name}' from the document. + +Field Description: {prompt} +Field Type: {field_type} +Required: {field.get('required', False)} + +Instructions: +1. Carefully read and understand the field description +2. Search through the document for information matching this field +3. Extract only the specific information requested +4. Ensure accuracy and completeness +{rag_instruction} +6. If the information is not found, clearly state that it's not available + +Tools available: {"rag, calculator" if doc_id else "calculator"}""" + + return agent_type, system_message + + +def create_challenger_system_message( + fields: list[dict[str, Any]], doc_id: str | None +) -> str: + """Create system message for the challenger agent.""" + field_names = [f.get("name") for f in fields] + required_fields = [f.get("name") for f in fields if f.get("required", False)] + + rag_instruction = ( + "- Use RAG to verify information against the source document" + if doc_id + else "- Verify against the provided document content" + ) + + return f"""You are a quality assurance specialist responsible for validating extracted data. + +Your role: +1. Review all extracted field values from other agents +2. Challenge incorrect, incomplete, or inconsistent extractions +3. Verify that required fields are properly extracted +4. Check for logical consistency between related fields + +Fields to validate: {', '.join(field_names)} +Required fields: {', '.join(required_fields)} + +Validation approach: +- Check accuracy against source material +{rag_instruction} +- Ensure completeness for required fields +- Identify inconsistencies or errors +- Provide specific feedback for corrections + +If you find issues, clearly state what needs to be corrected and why. +If extractions are accurate, approve them for final collation. + +Tools available: {"rag, calculator" if doc_id else "calculator"}""" + + +def create_collation_system_message(fields: list[dict[str, Any]]) -> str: + """Create system message for the collation agent.""" + field_names = [f.get("name") for f in fields] + + return f"""You are a data collation specialist responsible for combining all extracted field values into the final output. + +Your role: +1. Collect all validated field values from extraction agents +2. Resolve any remaining conflicts between extractions +3. Format the final output as a structured JSON object +4. Ensure all required fields are included +5. 
Apply any final formatting or transformations
+
+Fields to collate: {', '.join(field_names)}
+
+Output format:
+{{
+{', '.join([f'    "{name}": "extracted_value"' for name in field_names])}
+}}
+
+Instructions:
+- Use the most recent validated values for each field
+- If multiple values exist for a field, use your judgment to select the best one
+- Ensure the output JSON is properly formatted
+- Include null values for fields that couldn't be extracted
+
+Tools available: string_operations"""
+
+
+def generate_edge_configs(
+    agents: list[dict[str, Any]],
+    dependencies: dict[str, list[str]],
+    tool_settings: dict[str, Any],
+) -> list[dict[str, Any]]:
+    """Generate edge configurations based on dependencies.
+
+    Args:
+        agents: List of agent configurations
+        dependencies: Field dependencies
+        tool_settings: Tool settings
+
+    Returns:
+        List of edge configurations
+    """
+    edges = []
+    agent_map = {
+        agent["field_config"].get("field_name"): agent["name"]
+        for agent in agents
+        if "field_name" in agent.get("field_config", {})
+    }
+
+    # Add dependency edges
+    for field_name, dep_fields in dependencies.items():
+        if field_name in agent_map:
+            target_agent = agent_map[field_name]
+
+            for dep_field in dep_fields:
+                if dep_field in agent_map:
+                    source_agent = agent_map[dep_field]
+
+                    edge = {
+                        "source": source_agent,
+                        "target": target_agent,
+                        "condition": None,  # No condition for dependency edges
+                    }
+                    edges.append(edge)
+
+    # Add edges to challenger (if enabled)
+    has_challenger = tool_settings.get("enable_challenge", False)
+    if has_challenger:
+        extraction_agents = [
+            agent["name"]
+            for agent in agents
+            if agent.get("field_config", {}).get("agent_role") != "challenger"
+            and agent.get("field_config", {}).get("agent_role") != "collation"
+        ]
+
+        for agent_name in extraction_agents:
+            edge = {
+                "source": agent_name,
+                "target": "challenger_agent",
+                "condition": None,
+            }
+            edges.append(edge)
+
+    # Add edges to collation (the collation agent is named
+    # "data_collation_agent" in generate_agent_configs)
+    if has_challenger:
+        # Collation depends on challenger
+        edge = {
+            "source": "challenger_agent",
+            "target": "data_collation_agent",
+            "condition": 'lambda msg: "approved" in msg.content.lower() or "validated" in msg.content.lower()',
+        }
+        edges.append(edge)
+    else:
+        # Collation depends directly on extraction agents
+        extraction_agents = [
+            agent["name"]
+            for agent in agents
+            if agent.get("field_config", {}).get("agent_role") != "collation"
+        ]
+
+        for agent_name in extraction_agents:
+            edge = {
+                "source": agent_name,
+                "target": "data_collation_agent",
+                "condition": None,
+            }
+            edges.append(edge)
+
+    return edges
+
+
+def create_autogen_agents(
+    agent_configs: list[dict[str, Any]],
+    doc_id: str | None = None,
+) -> list[AssistantAgent]:
+    """Create actual Autogen AssistantAgent instances.
+
+    Args:
+        agent_configs: List of agent configuration dictionaries
+        doc_id: Document ID for RAG context
+
+    Returns:
+        List of AssistantAgent instances
+    """
+    agents = []
+
+    # LLM configuration
+    # NOTE: this dict follows the classic pyautogen AssistantAgent signature;
+    # the autogen-agentchat releases that ship DiGraphBuilder/GraphFlow expect
+    # a `model_client` instance instead, so adapt this when upgrading.
+    llm_config = {
+        "model": "gpt-4",
+        "temperature": 0.1,
+        "timeout": 300,
+    }
+
+    for config in agent_configs:
+        # Create AssistantAgent
+        agent = AssistantAgent(
+            name=config["name"],
+            system_message=config["system_message"],
+            llm_config=llm_config,
+        )
+
+        agents.append(agent)
+
+    return agents
+
+
+def add_edges_to_builder(
+    builder: DiGraphBuilder,
+    agents: list[AssistantAgent],
+    dependencies: dict[str, list[str]],
+    tool_settings: dict[str, Any],
+) -> None:
+    """Add edges to the DiGraphBuilder based on dependencies. 
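+
+    For example, with two extraction agents and the challenger enabled, the
+    resulting wiring is (illustrative):
+
+        extraction_agent_a ─┐
+                            ├─> challenger_agent ─(approved/validated)─> data_collation_agent
+        extraction_agent_b ─┘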
+ + Args: + builder: DiGraphBuilder instance + agents: List of AssistantAgent instances + dependencies: Field dependencies + tool_settings: Tool settings + """ + # Create agent lookup by name + agent_map = {agent.name: agent for agent in agents} + + # Add dependency edges + for field_name, dep_fields in dependencies.items(): + # Find target agent for this field + target_agent = None + for agent in agents: + if field_name in agent.name: + target_agent = agent + break + + if target_agent: + for dep_field in dep_fields: + # Find source agent for dependency + source_agent = None + for agent in agents: + if dep_field in agent.name: + source_agent = agent + break + + if source_agent: + builder.add_edge(source_agent, target_agent) + + # Add edges to challenger if enabled + has_challenger = tool_settings.get("enable_challenge", False) + challenger_agent = None + collation_agent = None + + # Find special agents + for agent in agents: + if "challenger_agent" in agent.name: + challenger_agent = agent + elif "data_collation_agent" in agent.name: + collation_agent = agent + + if has_challenger and challenger_agent: + # All extraction agents → challenger + for agent in agents: + if agent != challenger_agent and agent != collation_agent: + builder.add_edge(agent, challenger_agent) + + # Challenger → collation with condition + if collation_agent: + builder.add_edge( + challenger_agent, + collation_agent, + condition=lambda msg: "approved" in msg.content.lower() + or "validated" in msg.content.lower(), + ) + else: + # Direct extraction agents → collation + if collation_agent: + for agent in agents: + if agent != collation_agent: + builder.add_edge(agent, collation_agent) + + +def serialize_graph_flow(graph_flow: GraphFlow) -> dict[str, Any]: + """Serialize GraphFlow object for JSON storage. + + Args: + graph_flow: GraphFlow instance + + Returns: + Serialized representation + """ + return { + "type": "GraphFlow", + "participants": [agent.name for agent in graph_flow.participants], + "graph_info": "Graph structure serialized separately", + } + + +def serialize_agent(agent: AssistantAgent) -> dict[str, Any]: + """Serialize AssistantAgent for JSON storage. + + Args: + agent: AssistantAgent instance + + Returns: + Serialized representation + """ + return { + "name": agent.name, + "type": "AssistantAgent", + "system_message": agent.system_message, + "llm_config": getattr(agent, "llm_config", {}), + } + + +def serialize_graph(graph) -> dict[str, Any]: + """Serialize graph structure for JSON storage. + + Args: + graph: Graph object from DiGraphBuilder + + Returns: + Serialized representation + """ + return { + "type": "DiGraph", + "info": "Graph structure from Autogen DiGraphBuilder", + } + + +def verify_autogen_components(builder: DiGraphBuilder, graph_flow: GraphFlow) -> None: + """Verify that we're using proper Autogen components. 
+
+    Args:
+        builder: DiGraphBuilder instance
+        graph_flow: GraphFlow instance
+    """
+    logger.info("Verifying Autogen components:")
+    logger.info(f"DiGraphBuilder type: {type(builder)}")
+    logger.info(f"GraphFlow type: {type(graph_flow)}")
+    logger.info(f"Participants count: {len(graph_flow.participants)}")
+
+    # Confirm we're using the right Autogen classes
+    assert isinstance(
+        builder, DiGraphBuilder
+    ), f"Expected DiGraphBuilder, got {type(builder)}"
+    assert isinstance(
+        graph_flow, GraphFlow
+    ), f"Expected GraphFlow, got {type(graph_flow)}"
+
+    logger.info(
+        "✓ Successfully using Autogen DiGraphBuilder and GraphFlow for sequential team processing"
+    )
+
+
+def execute_graph_flow(
+    graph_flow: GraphFlow,
+    task: str = "Extract all specified fields from the document accurately.",
+) -> dict[str, Any]:
+    """Execute the GraphFlow and return results.
+
+    Args:
+        graph_flow: GraphFlow instance to execute
+        task: Task description for the agents
+
+    Returns:
+        Execution results
+    """
+    results = {}
+    final_output = None
+
+    try:
+        # Run the GraphFlow
+        # NOTE: recent autogen-agentchat releases expose run_stream as an
+        # async generator; this synchronous loop assumes a sync wrapper and
+        # must become `async for` (inside an async def) against the async API.
+        stream = graph_flow.run_stream(task=task)
+
+        for event in stream:
+            # Not every streamed item carries these attributes (the final
+            # TaskResult, for instance, has no source), so read defensively.
+            event_type = getattr(event, "type", type(event).__name__)
+            event_source = getattr(event, "source", None)
+            logger.info(f"Event: {event_type}, Agent: {event_source}")
+
+            # Store results
+            if event_source and hasattr(event, "content") and event.content:
+                results[event_source] = event.content
+
+            # Check if this is the final collation result
+            if event_source == "data_collation_agent":
+                try:
+                    final_output = json.loads(event.content)
+                except json.JSONDecodeError:
+                    final_output = event.content
+
+    except Exception as e:
+        logger.error(f"Error executing GraphFlow: {str(e)}")
+        raise
+
+    return {
+        "final_output": final_output,
+        "all_results": results,
+        "execution_status": "completed" if final_output else "incomplete",
+    }
+
+
+def generate_autogen_code(
+    agents: list[dict[str, Any]],
+    edges: list[dict[str, Any]],
+    doc_id: str | None,
+    extraction_spec: dict[str, Any],
+) -> str:
+    """Generate executable Autogen GraphFlow code using DiGraphBuilder. 
+ + Args: + agents: Agent configurations + edges: Edge configurations + doc_id: Document ID for RAG + extraction_spec: Original extraction specification + + Returns: + Executable Python code string following Autogen DiGraphBuilder format + """ + code_lines = [] + + # Imports following Autogen documentation + code_lines.extend( + [ + "from autogen_agentchat.agents import AssistantAgent", + "from autogen_agentchat.teams import DiGraphBuilder, GraphFlow", + "import json", + "", + ] + ) + + # LLM Configuration + code_lines.extend( + [ + "# LLM Configuration - configure according to your setup", + "llm_config = {", + " 'model': 'gpt-4',", + " 'temperature': 0.1,", + " 'timeout': 300,", + "}", + "", + ] + ) + + # Document context + if doc_id: + code_lines.extend( + [ + "# Document context for RAG", + f'doc_id = "{doc_id}"', + "", + ] + ) + + # Agent definitions + code_lines.append("# Define agents") + for agent in agents: + agent_code = generate_agent_code(agent) + code_lines.extend(agent_code) + + code_lines.append("") + + # Graph builder following Autogen pattern + code_lines.extend( + [ + "# Build the graph using DiGraphBuilder", + "builder = DiGraphBuilder()", + "", + "# Add nodes to the graph", + ] + ) + + # Add nodes using add_node method + for agent in agents: + code_lines.append(f"builder.add_node({agent['name']})") + + code_lines.extend( + [ + "", + "# Add edges to define workflow", + ] + ) + + # Add edges using add_edge method + for edge in edges: + edge_code = generate_edge_code(edge) + code_lines.append(edge_code) + + code_lines.extend( + [ + "", + "# Build the graph", + "graph = builder.build()", + "", + "# Create the GraphFlow", + "flow = GraphFlow(participants=builder.get_participants(), graph=graph)", + "", + "# Define the extraction task", + 'task = """Extract all the specified fields from the document accurately. 
', + "Follow the workflow defined in the graph to ensure proper data extraction.", + 'Each agent should focus on their specific role and pass results to the next agent."""', + "", + "# Run the extraction flow", + "stream = flow.run_stream(task=task)", + "", + "# Process the results", + "results = {}", + "final_output = None", + "", + "for event in stream:", + " print(f'Event: {event.type}') # Event type", + " print(f'Agent: {event.source}') # Source agent", + " print(f'Content: {event.content[:200]}...') # First 200 chars", + " print('---')", + " ", + " # Store results", + " if hasattr(event, 'content') and event.content:", + " results[event.source] = event.content", + " ", + " # Check if this is the final collation result", + " if event.source == 'data_collation_agent':", + " try:", + " final_output = json.loads(event.content)", + " except json.JSONDecodeError:", + " final_output = event.content", + "", + "# Display final results", + "print('\\n=== EXTRACTION COMPLETE ===')", + "if final_output:", + " print('Final extracted data:')", + " if isinstance(final_output, dict):", + " print(json.dumps(final_output, indent=2))", + " else:", + " print(final_output)", + "else:", + " print('No final output from collation agent')", + " print('All agent results:')", + " for agent_name, result in results.items():", + " print(f'{agent_name}: {result[:100]}...')", + ] + ) + + return "\n".join(code_lines) + + +def generate_agent_code(agent: dict[str, Any]) -> list[str]: + """Generate code lines for creating an agent.""" + name = agent["name"] + agent_type = agent["agent_type"] + system_message = agent["system_message"] + + # Escape quotes in system message + escaped_message = system_message.replace('"""', '\\"\\"\\"') + + code_lines = [ + f"{name} = {agent_type}(", + f' name="{name}",', + ' system_message="""', + f"{escaped_message}", + ' """,', + " llm_config=llm_config", + ")", + "", + ] + + return code_lines + + +def generate_edge_code(edge: dict[str, Any]) -> str: + """Generate code for creating an edge.""" + source = edge["source"] + target = edge["target"] + condition = edge.get("condition") + + if condition: + return f"builder.add_edge({source}, {target}, condition={condition})" + else: + return f"builder.add_edge({source}, {target})" + + +def create_execution_plan( + agents: list[dict[str, Any]], dependencies: dict[str, list[str]] +) -> dict[str, Any]: + """Create execution plan information.""" + # Simple analysis of the execution flow + extraction_agents = [ + a + for a in agents + if a.get("field_config", {}).get("agent_role") not in ["challenger", "collation"] + ] + has_challenger = any( + a.get("field_config", {}).get("agent_role") == "challenger" for a in agents + ) + has_collation = any( + a.get("field_config", {}).get("agent_role") == "collation" for a in agents + ) + + stages = ["Extraction"] + if has_challenger: + stages.append("Validation") + if has_collation: + stages.append("Collation") + + return { + "stages": stages, + "total_agents": len(agents), + "extraction_agents": len(extraction_agents), + "has_challenger": has_challenger, + "has_collation": has_collation, + "estimated_parallel_extraction": len(extraction_agents) > 1, + } diff --git a/unstract/prompt-service-helpers/digraph_generation/executor.py b/unstract/prompt-service-helpers/digraph_generation/executor.py new file mode 100644 index 0000000000..2a15ce0a82 --- /dev/null +++ b/unstract/prompt-service-helpers/digraph_generation/executor.py @@ -0,0 +1,111 @@ +"""Utility for executing Autogen GraphFlow instances. 
+This module provides functions to execute the generated GraphFlow and get results. +""" + +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +def execute_extraction_workflow( + graph_flow_data: dict[str, Any], + task_description: str | None = None, + doc_id: str | None = None, +) -> dict[str, Any]: + """Execute an extraction workflow using GraphFlow. + + Args: + graph_flow_data: Serialized GraphFlow data from digraph generation + task_description: Custom task description (optional) + doc_id: Document ID for context (optional) + + Returns: + Extraction results including final output and intermediate results + """ + try: + # Deserialize the GraphFlow (this would need to be implemented based on + # how Autogen handles serialization/deserialization) + # For now, this is a placeholder showing the intended structure + + # Default task if not provided + if not task_description: + task_description = f"""Extract all specified fields from the document accurately. + Document ID: {doc_id if doc_id else 'Not specified'} + Follow the workflow defined in the graph to ensure proper data extraction. + Each agent should focus on their specific role and pass results to the next agent.""" + + logger.info( + f"Starting extraction workflow with {len(graph_flow_data.get('agents', []))} agents" + ) + + # This would execute the actual GraphFlow + # graph_flow = deserialize_graph_flow(graph_flow_data) + # results = execute_graph_flow(graph_flow, task_description) + + # Placeholder results structure + results = { + "final_output": {}, + "all_results": {}, + "execution_status": "not_implemented", + "message": "GraphFlow execution needs to be implemented based on Autogen's serialization format", + } + + return results + + except Exception as e: + logger.error(f"Error executing extraction workflow: {str(e)}") + return { + "final_output": None, + "all_results": {}, + "execution_status": "error", + "error": str(e), + } + + +def validate_graph_flow_data(graph_flow_data: dict[str, Any]) -> dict[str, Any]: + """Validate GraphFlow data structure. 
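+
+    For example, a payload that has all required keys but an empty agent list
+    yields (illustrative):
+
+        {"is_valid": False,
+         "issues": ["No agents found in GraphFlow data"],
+         "warnings": ["No collation agent found - results may not be properly formatted"],
+         "agent_count": 0}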
+ + Args: + graph_flow_data: GraphFlow data to validate + + Returns: + Validation results + """ + issues = [] + warnings = [] + + # Check required fields + required_fields = ["graph_flow", "agents", "graph", "metadata"] + for field in required_fields: + if field not in graph_flow_data: + issues.append(f"Missing required field: {field}") + + # Check agents + agents = graph_flow_data.get("agents", []) + if not agents: + issues.append("No agents found in GraphFlow data") + + for agent in agents: + if not isinstance(agent, dict): + issues.append(f"Invalid agent format: {agent}") + continue + + if "name" not in agent: + issues.append("Agent missing name field") + + # Check for collation agent + has_collation = any( + "data_collation_agent" in agent.get("name", "") for agent in agents + ) + if not has_collation: + warnings.append( + "No collation agent found - results may not be properly formatted" + ) + + return { + "is_valid": len(issues) == 0, + "issues": issues, + "warnings": warnings, + "agent_count": len(agents), + } diff --git a/unstract/prompt-service-helpers/digraph_generation/requirements.txt b/unstract/prompt-service-helpers/digraph_generation/requirements.txt new file mode 100644 index 0000000000..aee02880af --- /dev/null +++ b/unstract/prompt-service-helpers/digraph_generation/requirements.txt @@ -0,0 +1,15 @@ +# Dependencies for the digraph_generation module + +# Core dependencies +celery>=5.3.0 + +# Autogen dependencies - REQUIRED for this implementation +autogen-agentchat>=0.2.0 +pyautogen>=0.2.0 + +# Optional dependencies for enhanced functionality +typing-extensions>=4.0.0 + +# Development dependencies (optional) +# pytest>=7.0.0 +# pytest-mock>=3.10.0 diff --git a/unstract/prompt-service-helpers/digraph_generation/spec_parser.py b/unstract/prompt-service-helpers/digraph_generation/spec_parser.py new file mode 100644 index 0000000000..71a5564efc --- /dev/null +++ b/unstract/prompt-service-helpers/digraph_generation/spec_parser.py @@ -0,0 +1,418 @@ +"""Parser for extraction specifications from the answer_prompt payload. +Converts the prompt service payload format into Autogen DiGraphBuilder format. +""" + +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +class ExtractionSpecParser: + """Parser for extraction specifications compatible with Autogen DiGraphBuilder. + Uses predefined agent types with specific tools as per the specification. + """ + + # Predefined agent configurations with their tools + AGENT_CONFIGS = { + "generic_data_extraction_agent": { + "tools": ["rag", "calculator"], + "system_message_template": """You are a generic data extraction agent. Extract the field '{field_name}' from the document. + +Field Description: {prompt} +Field Type: {field_type} +Required: {required} + +Instructions: +1. Use RAG to search for relevant information in the document +2. Extract the specific information requested for this field +3. Use calculator for any numerical computations if needed +4. Ensure accuracy and completeness +5. If information is not found, clearly state it's unavailable + +Tools available: RAG, Calculator""", + }, + "table_data_extraction_agent": { + "tools": ["calculator"], + "system_message_template": """You are a table data extraction agent. Extract the field '{field_name}' from tabular data in the document. + +Field Description: {prompt} +Field Type: {field_type} +Required: {required} + +Instructions: +1. Identify and extract data from tables in the document +2. Preserve table structure and relationships +3. 
Use calculator for any numerical computations or aggregations +4. Ensure data accuracy and proper formatting +5. Handle missing or incomplete table data appropriately + +Tools available: Calculator""", + }, + "omniparse_data_extraction_agent": { + "tools": ["calculator"], + "system_message_template": """You are an omniparse data extraction agent specialized in complex document formats. Extract the field '{field_name}' from the document. + +Field Description: {prompt} +Field Type: {field_type} +Required: {required} + +Instructions: +1. Handle complex document layouts and formats +2. Extract from non-standard or visual elements +3. Use calculator for any numerical computations +4. Process complex data structures and relationships +5. Maintain accuracy across different document formats + +Tools available: Calculator""", + }, + "challenger_agent": { + "tools": ["rag", "calculator"], + "system_message_template": """You are a challenger agent responsible for validating extracted data quality. + +Your role: +1. Review all extracted field values from other agents +2. Challenge incorrect, incomplete, or inconsistent extractions +3. Use RAG to verify information against the source document +4. Use calculator to verify numerical computations +5. Ensure required fields are properly extracted +6. Check for logical consistency between related fields + +Validation criteria: +- Accuracy against source material +- Completeness for required fields +- Consistency across related fields +- Proper formatting and data types + +If you find issues, specify what needs correction and why. +If extractions are accurate, approve them for final collation. + +Tools available: RAG, Calculator""", + }, + "data_collation_agent": { + "tools": ["string_concatenation"], + "system_message_template": """You are a data collation agent responsible for combining all validated field values into the final output. + +Your role: +1. Collect all validated field values from extraction agents +2. Resolve any remaining conflicts between extractions +3. Format the final output as a structured JSON object +4. Ensure all required fields are included +5. Apply final formatting and string operations + +Output format: +{{ +{field_json_structure} +}} + +Instructions: +- Use the most recent validated values for each field +- Apply string concatenation and formatting as needed +- Ensure proper JSON structure +- Include null for missing fields + +Tools available: String concatenation""", + }, + } + + def parse(self, extraction_spec: dict[str, Any]) -> dict[str, Any]: + """Parse the extraction specification from answer_prompt payload format. 
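+
+        A minimal illustrative call (hypothetical values):
+
+            parser = ExtractionSpecParser()
+            parsed = parser.parse({
+                "outputs": [{"name": "company_name",
+                             "prompt": "Extract the company name",
+                             "type": "text",
+                             "required": True}],
+                "tool_settings": {"enable_challenge": False},
+            })
+            # parsed["required_agents"] contains
+            # "generic_data_extraction_agent" and "data_collation_agent"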
+ + Args: + extraction_spec: Raw extraction specification containing: + - outputs: List of field specifications (same as "prompts" in answer_prompt) + - tool_settings: Tool configuration including enable_challenge + - Other metadata from answer_prompt payload + + Returns: + Parsed specification ready for Autogen DiGraphBuilder + """ + try: + # Extract components from answer_prompt payload structure + outputs = extraction_spec.get( + "outputs", [] + ) # Same as "prompts" in answer_prompt + tool_settings = extraction_spec.get("tool_settings", {}) + + # Parse fields from outputs + fields = self._parse_fields(outputs) + + # Determine which agents are needed based on field types + required_agents = self._determine_required_agents(fields, tool_settings) + + # Parse dependencies + dependencies = self._parse_dependencies( + extraction_spec.get("dependencies", {}), fields + ) + + # Extract metadata + metadata = self._extract_metadata(extraction_spec) + + parsed_spec = { + "fields": fields, + "required_agents": required_agents, + "tool_settings": tool_settings, + "dependencies": dependencies, + "metadata": metadata, + } + + logger.info( + f"Parsed extraction spec: {len(fields)} fields, {len(required_agents)} agents" + ) + return parsed_spec + + except Exception as e: + logger.error(f"Error parsing extraction spec: {str(e)}") + raise ValueError(f"Invalid extraction specification: {str(e)}") + + def _parse_fields(self, outputs: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Parse field specifications from outputs (prompts from answer_prompt). + + Args: + outputs: List of output/prompt specifications + + Returns: + List of parsed field specifications + """ + fields = [] + + for output in outputs: + # Extract field information following answer_prompt structure + field = { + "name": output.get("name", ""), + "prompt": output.get("prompt", ""), + "type": output.get("type", "text"), + "required": output.get("required", False), + "chunk_size": output.get("chunk_size", 1000), + "chunk_overlap": output.get("chunk_overlap", 200), + } + + # Determine agent type based on field characteristics + field["agent_type"] = self._determine_field_agent_type(field) + + # Validate required fields + if not field["name"] or not field["prompt"]: + logger.warning(f"Skipping invalid field: {field}") + continue + + fields.append(field) + + return fields + + def _determine_field_agent_type(self, field: dict[str, Any]) -> str: + """Determine which agent type should handle this field. + + Args: + field: Field specification + + Returns: + Agent type name + """ + field_type = field.get("type", "text").lower() + prompt = field.get("prompt", "").lower() + + # Check for table extraction + if field_type == "table" or "table" in prompt: + return "table_data_extraction_agent" + + # Check for complex/omniparse extraction + if field_type in ["image", "chart", "diagram", "complex"] or any( + keyword in prompt + for keyword in ["image", "chart", "diagram", "visual", "complex", "layout"] + ): + return "omniparse_data_extraction_agent" + + # Default to generic extraction + return "generic_data_extraction_agent" + + def _determine_required_agents( + self, fields: list[dict[str, Any]], tool_settings: dict[str, Any] + ) -> list[str]: + """Determine which agents are required based on fields and settings. 
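+
+        For example, one "text" field plus one "table" field with
+        enable_challenge set to true resolves (in no guaranteed order, since
+        a set is used internally) to:
+
+            ["generic_data_extraction_agent", "table_data_extraction_agent",
+             "challenger_agent", "data_collation_agent"]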
+ + Args: + fields: List of field specifications + tool_settings: Tool settings from payload + + Returns: + List of required agent type names + """ + required_agents = set() + + # Add agents based on field requirements + for field in fields: + required_agents.add(field["agent_type"]) + + # Add challenger agent if enabled + if tool_settings.get("enable_challenge", False): + required_agents.add("challenger_agent") + + # Always add collation agent + required_agents.add("data_collation_agent") + + return list(required_agents) + + def _parse_dependencies( + self, explicit_dependencies: dict[str, Any], fields: list[dict[str, Any]] + ) -> dict[str, list[str]]: + """Parse field dependencies including implicit dependencies from variable references. + + Args: + explicit_dependencies: Explicitly defined dependencies + fields: List of field specifications + + Returns: + Complete dependency mapping + """ + dependencies = {} + + # Start with explicit dependencies + for field_name, deps in explicit_dependencies.items(): + if isinstance(deps, list): + dependencies[field_name] = [dep for dep in deps if isinstance(dep, str)] + elif isinstance(deps, str): + dependencies[field_name] = [deps] + + # Analyze implicit dependencies from variable references in prompts + for field in fields: + field_name = field.get("name", "") + prompt_text = field.get("prompt", "") + + # Find variable references like {{other_field}} + import re + + variable_pattern = r"\{\{(\w+)\}\}" + referenced_fields = re.findall(variable_pattern, prompt_text) + + if field_name not in dependencies: + dependencies[field_name] = [] + + # Add referenced fields as dependencies + for ref_field in referenced_fields: + if ref_field != field_name and ref_field not in dependencies[field_name]: + dependencies[field_name].append(ref_field) + + return dependencies + + def _extract_metadata(self, extraction_spec: dict[str, Any]) -> dict[str, Any]: + """Extract metadata from the extraction specification. + + Args: + extraction_spec: Raw extraction specification + + Returns: + Extracted metadata + """ + metadata = {} + + # Common metadata fields from answer_prompt payload + metadata_fields = [ + "tool_id", + "run_id", + "execution_id", + "file_hash", + "file_path", + "file_name", + "log_events_id", + "execution_source", + "user_data", + ] + + for field in metadata_fields: + if field in extraction_spec: + metadata[field] = extraction_spec[field] + + return metadata + + def get_agent_config(self, agent_type: str) -> dict[str, Any]: + """Get the configuration for a specific agent type. + + Args: + agent_type: Type of agent + + Returns: + Agent configuration including tools and system message template + """ + return self.AGENT_CONFIGS.get(agent_type, {}) + + def create_agent_system_message( + self, + agent_type: str, + field: dict[str, Any] | None = None, + all_fields: list[dict[str, Any]] | None = None, + ) -> str: + """Create system message for an agent. 
+ + Args: + agent_type: Type of agent + field: Field specification for extraction agents + all_fields: All fields for challenger/collation agents + + Returns: + Formatted system message + """ + config = self.get_agent_config(agent_type) + template = config.get("system_message_template", "") + + if field: + # For extraction agents + return template.format( + field_name=field.get("name", ""), + prompt=field.get("prompt", ""), + field_type=field.get("type", "text"), + required=field.get("required", False), + ) + elif all_fields and agent_type == "data_collation_agent": + # For collation agent, create JSON structure + field_json_structure = ",\n".join( + [ + f' "{field.get("name", "")}": "extracted_value"' + for field in all_fields + ] + ) + return template.format(field_json_structure=field_json_structure) + + return template + + def validate_spec(self, parsed_spec: dict[str, Any]) -> dict[str, Any]: + """Validate the parsed extraction specification. + + Args: + parsed_spec: Parsed extraction specification + + Returns: + Validation results + """ + issues = [] + warnings = [] + + fields = parsed_spec.get("fields", []) + dependencies = parsed_spec.get("dependencies", {}) + + # Check if we have fields + if not fields: + issues.append("No fields specified for extraction") + + # Validate field names are unique + field_names = [field.get("name") for field in fields] + if len(field_names) != len(set(field_names)): + issues.append("Duplicate field names found") + + # Validate dependencies reference existing fields + for field_name, deps in dependencies.items(): + if field_name not in field_names: + warnings.append(f"Dependency for unknown field: {field_name}") + + for dep in deps: + if dep not in field_names: + warnings.append( + f"Field '{field_name}' depends on unknown field: {dep}" + ) + + return { + "is_valid": len(issues) == 0, + "issues": issues, + "warnings": warnings, + "field_count": len(fields), + "required_agent_count": len(parsed_spec.get("required_agents", [])), + } diff --git a/unstract/prompt-service-helpers/indexing/README.md b/unstract/prompt-service-helpers/indexing/README.md new file mode 100644 index 0000000000..6956e44924 --- /dev/null +++ b/unstract/prompt-service-helpers/indexing/README.md @@ -0,0 +1,148 @@ +# Extraction Module - Prompt Service Helpers + +This module provides celery tasks and utilities for the document extraction pipeline, focusing on text chunking and embedding generation as part of the agentic data extraction process. + +## Overview + +The extraction module handles the critical step of preparing extracted text for vector database storage and retrieval. It takes raw text extracted from documents (PDFs, etc.) and processes it into searchable chunks with embeddings. + +## Components + +### 1. **chunking_embedding_task.py** +Main celery task that processes text for chunking and embedding generation. 
+ +**Key Features:** +- Retrieves extracted text from MinIO using SDK's FileStorage +- Generates unique document IDs using adapter configurations +- Chunks text based on user-defined parameters (chunk_size, chunk_overlap) +- Creates embeddings using configured embedding models +- Stores chunks and embeddings in vector database +- Supports document reindexing +- Calculates token usage for the entire document + +**Task Name:** `chunking_embedding_task` + +**Inputs:** +- `minio_text_path`: Path to extracted text in MinIO +- `chunking_params`: User-defined chunking configuration + - `chunk_size`: Size of text chunks (default: 1000) + - `chunk_overlap`: Overlap between chunks (default: 200) + - `enable_smart_chunking`: Auto-adjust based on LLM context (optional) +- `embedding_params`: Embedding configuration + - `adapter_instance_id`: Embedding model adapter ID + - `vector_db_instance_id`: Vector database adapter ID + - `platform_key`: Authentication key +- `llm_config`: Optional LLM configuration for smart chunking + +**Outputs:** +- `doc_id`: Unique document identifier for retrieval +- `chunk_count`: Number of chunks created +- `embedding_count`: Number of embeddings generated +- `total_input_tokens`: Total tokens in the input file +- `metadata`: Processing details and statistics + +### 2. **token_helper.py** +Utility for token calculation and model context window management. + +**Key Features:** +- Fetches model pricing and context data from LiteLLM's public repository +- Caches model data locally with configurable TTL (default: 7 days) +- Counts tokens using tiktoken (with fallback approximation) +- Determines optimal chunk sizes based on model context windows +- Supports all major LLM providers (OpenAI, Anthropic, Meta, etc.) + +**Data Source:** +``` +https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json +``` + +**Main Methods:** +- `count_tokens(text, model_name)`: Count tokens in text +- `get_model_context_window(model_name)`: Get model's max context size +- `calculate_optimal_chunk_size(model_name)`: Calculate recommended chunk size + +## Usage Example + +```python +from celery import Celery +from indexing.chunking_embedding_task import process_chunking_and_embedding + +# Trigger the celery task +result = process_chunking_and_embedding.delay( + minio_text_path="bucket/documents/extracted_text.txt", + chunking_params={ + "chunk_size": 1000, + "chunk_overlap": 200, + "enable_smart_chunking": True + }, + embedding_params={ + "adapter_instance_id": "embedding-adapter-123", + "vector_db_instance_id": "vectordb-adapter-456", + "platform_key": "your-platform-key" + }, + llm_config={ + "model_name": "gpt-4", + "provider": "openai" + } +) + +# Get task result +output = result.get() +print(f"Document ID: {output['doc_id']}") +print(f"Total tokens: {output['total_input_tokens']}") +print(f"Chunks created: {output['chunk_count']}") +``` + +## Integration with Celery Chain + +This task is designed to be part of the larger extraction pipeline: + +``` +Text Extraction → Chunking & Embedding → Digraph Generation → Agent Execution +``` + +The output `doc_id` from this task can be used by downstream agents to retrieve relevant context using RAG (Retrieval-Augmented Generation). 
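+
+As an illustration, a downstream consumer can pull chunks for a given `doc_id` with the same llama-index filter pattern the task itself uses for its already-indexed check; the `embedding` and `vector_db` objects below are assumed to be initialized exactly as in `chunking_embedding_task.py`:
+
+```python
+from llama_index.core.vector_stores import (
+    FilterOperator,
+    MetadataFilter,
+    MetadataFilters,
+    VectorStoreQuery,
+)
+
+# Restrict the search to the chunks of one indexed document
+doc_filter = MetadataFilter.from_dict(
+    {"key": "doc_id", "operator": FilterOperator.EQ, "value": doc_id}
+)
+query = VectorStoreQuery(
+    query_embedding=embedding.get_query_embedding("What is the company name?"),
+    filters=MetadataFilters(filters=[doc_filter]),
+    similarity_top_k=5,
+)
+result = vector_db.query(query=query)
+for node in result.nodes:
+    print(node.get_content()[:100])
+```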
+ +## Environment Variables + +Required environment variables for MinIO access: +- `MINIO_ENDPOINT`: MinIO server endpoint +- `MINIO_ACCESS_KEY`: Access key for MinIO +- `MINIO_SECRET_KEY`: Secret key for MinIO +- `MINIO_BUCKET_NAME`: Default bucket name +- `MINIO_SECURE`: Use HTTPS (true/false) + +## SDK Dependencies + +This module heavily utilizes the Unstract SDK: +- `FileStorage`: For MinIO file operations +- `ToolAdapter`: For adapter configuration retrieval +- `ToolUtils`: For hashing and utility functions +- `VectorDB`: For vector database operations +- `Embedding`: For embedding generation + +## Design Decisions + +1. **No Redundant Storage**: Chunk metadata is not stored separately as it's already handled by the backend vector database. + +2. **SDK-First Approach**: All operations use SDK methods to ensure consistency with the rest of the platform. + +3. **Index Key Generation**: Uses the same logic as `index_v2.py` to generate unique document IDs based on file hash and configuration. + +4. **Token Awareness**: Calculates and tracks token usage for cost estimation and optimization. + +5. **Smart Chunking**: Optional feature that adjusts chunk size based on the LLM's context window to optimize retrieval and processing. + +## Performance Considerations + +- Model data is cached locally to avoid repeated API calls +- Documents are checked for existing indexing to avoid redundant processing +- Chunking and embedding happen in a single pass for efficiency +- Vector database operations are batched by the SDK + +## Future Enhancements + +- Support for different chunking strategies (semantic, paragraph-based) +- Parallel processing of large documents +- Support for incremental updates to existing documents +- Integration with document structure detection for smarter chunking diff --git a/unstract/prompt-service-helpers/indexing/__init__.py b/unstract/prompt-service-helpers/indexing/__init__.py new file mode 100644 index 0000000000..dfbead0f9a --- /dev/null +++ b/unstract/prompt-service-helpers/indexing/__init__.py @@ -0,0 +1,33 @@ +"""Extraction module for document chunking and embedding generation. + +This module provides celery tasks for the document extraction pipeline, +handling text chunking and embedding generation as part of the agentic +data extraction process. +""" + +from .chunking_embedding_task import process_chunking_and_embedding +from .token_helper import TokenCalculationHelper + +__all__ = [ + "process_chunking_and_embedding", + "TokenCalculationHelper", +] + +# Celery task registration information +CELERY_TASKS = { + "chunking_embedding_task": { + "task": process_chunking_and_embedding, + "name": "chunking_embedding_task", + "queue": "processing_queue", # Can be configured based on requirements + "routing_key": "extraction.chunking", + "priority": 5, # Medium priority + "rate_limit": "100/m", # Rate limiting if needed + "time_limit": 3600, # 1 hour timeout + "soft_time_limit": 3300, # Soft limit at 55 minutes + } +} + +# Module metadata +__version__ = "0.1.0" +__author__ = "Unstract Team" +__description__ = "Document chunking and embedding extraction tasks" diff --git a/unstract/prompt-service-helpers/indexing/chunking_embedding_task.py b/unstract/prompt-service-helpers/indexing/chunking_embedding_task.py new file mode 100644 index 0000000000..887df106b6 --- /dev/null +++ b/unstract/prompt-service-helpers/indexing/chunking_embedding_task.py @@ -0,0 +1,440 @@ +"""Celery task for chunking and embedding text extraction results. 
+This task handles text chunking based on user-defined parameters and generates +embeddings for vector database storage, following the pattern from index_v2.py. +""" + +import json +import logging +from typing import Any + +from celery import shared_task +from llama_index.core import Document +from llama_index.core.vector_stores import ( + FilterOperator, + MetadataFilter, + MetadataFilters, + VectorStoreQuery, + VectorStoreQueryResult, +) + +from unstract.sdk.adapter import ToolAdapter +from unstract.sdk.embedding import Embedding +from unstract.sdk.exceptions import IndexingError, SdkError +from unstract.sdk.file_storage.impl import FileStorage +from unstract.sdk.file_storage.provider import FileStorageProvider +from unstract.sdk.tool.base import BaseTool +from unstract.sdk.utils.tool_utils import ToolUtils +from unstract.sdk.vector_db import VectorDB + +from .token_helper import TokenCalculationHelper + +logger = logging.getLogger(__name__) + + +@shared_task(bind=True, name="chunking_embedding_task") +def process_chunking_and_embedding( + self, + minio_text_path: str, + chunking_params: dict[str, Any], + embedding_params: dict[str, Any], + llm_config: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Process text chunking and embedding generation following index_v2 pattern. + + Args: + minio_text_path: Path to the extracted text file in MinIO + chunking_params: Parameters for chunking including: + - chunk_size: Target chunk size in tokens/characters + - chunk_overlap: Overlap between chunks + - enable_smart_chunking: Enable intelligent chunking based on LLM context (optional) + embedding_params: Parameters for embedding including: + - adapter_instance_id: ID of the embedding adapter + - vector_db_instance_id: ID of the vector database adapter + - platform_key: Platform key for authentication + - x2text_instance_id: ID of the x2text adapter (optional) + - file_hash: Hash of the file (optional, will be calculated if not provided) + llm_config: Optional LLM configuration for context size determination + + Returns: + Dict containing: + - doc_id: Document ID for accessing chunks and embeddings + - minio_text_path: Original text file path + - chunk_count: Number of chunks created + - embedding_count: Number of embeddings generated + - total_input_tokens: Total tokens in the input file + - metadata: Additional processing metadata + """ + task_id = self.request.id + logger.info(f"[Task {task_id}] Starting chunking and embedding task") + + try: + # Step 1: Initialize FileStorage for MinIO access using SDK + file_storage = FileStorage( + provider=FileStorageProvider.MINIO, **get_minio_config() + ) + + # Step 2: Retrieve the extracted text from MinIO using SDK read method + logger.info(f"[Task {task_id}] Retrieving text from MinIO: {minio_text_path}") + text_content = file_storage.read(path=minio_text_path, mode="r", encoding="utf-8") + + if not text_content: + raise ValueError(f"No text content found at {minio_text_path}") + + # Step 3: Initialize token calculation helper + token_helper = TokenCalculationHelper() + + # Calculate total input tokens in the file + model_name = ( + llm_config.get("model_name", "gpt-3.5-turbo") + if llm_config + else "gpt-3.5-turbo" + ) + total_input_tokens = token_helper.count_tokens(text_content, model_name) + logger.info(f"[Task {task_id}] Total input tokens in file: {total_input_tokens}") + + # Step 4: Get chunking parameters from user input + chunk_size = chunking_params.get("chunk_size", 1000) + chunk_overlap = chunking_params.get("chunk_overlap", 200) + 
enable_smart_chunking = chunking_params.get("enable_smart_chunking", False) + + # Optional: Adjust chunk size based on LLM context if smart chunking is enabled + if enable_smart_chunking and llm_config: + provider = llm_config.get("provider") + optimal_chunk_size = token_helper.calculate_optimal_chunk_size( + model_name, provider, target_utilization=0.25 + ) + if optimal_chunk_size: + chunk_size = min(chunk_size, optimal_chunk_size) + logger.info( + f"[Task {task_id}] Adjusted chunk size to {chunk_size} " + f"based on model {model_name} context window" + ) + + # Step 5: Initialize SDK components + platform_key = embedding_params.get("platform_key", "") + tool = BaseTool(platform_key=platform_key) + + # Step 6: Generate document ID using SDK methods (similar to index key in index_v2) + doc_id = generate_index_key_with_sdk( + tool=tool, + file_hash=embedding_params.get("file_hash"), + file_path=minio_text_path, + embedding_instance_id=embedding_params.get("adapter_instance_id"), + vector_db_instance_id=embedding_params.get("vector_db_instance_id"), + x2text_instance_id=embedding_params.get("x2text_instance_id"), + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + file_storage=file_storage, + ) + + logger.info(f"[Task {task_id}] Generated doc_id: {doc_id}") + + # Step 7: Initialize embedding using SDK + embedding = Embedding( + tool=tool, + adapter_instance_id=embedding_params.get("adapter_instance_id"), + ) + + # Step 8: Initialize vector DB using SDK + vector_db = VectorDB( + tool=tool, + adapter_instance_id=embedding_params.get("vector_db_instance_id"), + embedding=embedding, + ) + + # Step 9: Check if document is already indexed using SDK methods + doc_already_indexed = is_document_indexed(doc_id, embedding, vector_db, tool) + + reindex = embedding_params.get("reindex", False) + if doc_already_indexed and not reindex: + logger.info( + f"[Task {task_id}] Document already indexed with doc_id: {doc_id}" + ) + return { + "doc_id": doc_id, + "minio_text_path": minio_text_path, + "chunk_count": 0, + "embedding_count": 0, + "total_input_tokens": total_input_tokens, + "metadata": { + "already_indexed": True, + "task_id": task_id, + }, + } + + # Step 10: Prepare document for chunking (following index_v2 pattern) + logger.info(f"[Task {task_id}] Preparing document for chunking") + + # Create document structure similar to index_v2 + full_text = [ + { + "section": "full", + "text_contents": str(text_content), + } + ] + + # Convert to LlamaIndex Document using SDK patterns + documents = prepare_documents_with_sdk(doc_id, full_text, tool) + + # Step 11: Delete existing nodes if reindexing using SDK methods + if reindex and doc_already_indexed: + logger.info(f"[Task {task_id}] Deleting existing nodes for reindexing") + try: + vector_db.delete(ref_doc_id=doc_id) + tool.stream_log(f"Deleted existing nodes for {doc_id}") + except Exception as e: + logger.error(f"[Task {task_id}] Error deleting nodes: {e}") + raise SdkError(f"Error deleting nodes for {doc_id}: {e}") from e + + # Step 12: Perform indexing with chunking using SDK methods + logger.info( + f"[Task {task_id}] Indexing with chunk_size: {chunk_size}, " + f"chunk_overlap: {chunk_overlap}" + ) + + try: + # Using SDK's vector_db.index_document method (follows index_v2._trigger_indexing) + tool.stream_log("Adding nodes to vector db...") + + # The SDK's index_document method handles chunking internally + nodes = vector_db.index_document( + documents, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + show_progress=True, + ) + + 
tool.stream_log("File has been indexed successfully") + logger.info( + f"[Task {task_id}] Successfully indexed {len(nodes) if nodes else 0} nodes" + ) + + # Count the nodes created + chunk_count = len(nodes) if nodes else 0 + + except Exception as e: + tool.stream_log( + f"Error adding nodes to vector db: {e}", + level="ERROR", + ) + raise IndexingError(str(e)) from e + + # Step 13: Count embeddings (one per chunk) + embedding_count = chunk_count # Assuming one embedding per chunk + + # Calculate average chunk size in tokens + avg_chunk_tokens = total_input_tokens // chunk_count if chunk_count > 0 else 0 + + # Prepare response + result = { + "doc_id": doc_id, + "minio_text_path": minio_text_path, + "chunk_count": chunk_count, + "embedding_count": embedding_count, + "total_input_tokens": total_input_tokens, + "metadata": { + "chunk_size": chunk_size, + "chunk_overlap": chunk_overlap, + "average_chunk_tokens": avg_chunk_tokens, + "text_length": len(text_content), + "model_name": model_name, + "task_id": task_id, + "reindexed": reindex and doc_already_indexed, + }, + } + + logger.info(f"[Task {task_id}] Chunking and embedding completed successfully") + return result + + except Exception as e: + logger.error(f"[Task {task_id}] Error in chunking and embedding: {str(e)}") + raise + + +def generate_index_key_with_sdk( + tool: BaseTool, + file_hash: str | None, + file_path: str, + embedding_instance_id: str, + vector_db_instance_id: str, + x2text_instance_id: str | None, + chunk_size: int, + chunk_overlap: int, + file_storage: FileStorage, +) -> str: + """Generate a unique index key using SDK methods. + This follows the pattern from index_v2.generate_index_key but uses SDK methods. + + Args: + tool: BaseTool instance for SDK operations + file_hash: Optional pre-computed file hash + file_path: Path to the file + embedding_instance_id: Embedding adapter instance ID + vector_db_instance_id: Vector DB adapter instance ID + x2text_instance_id: Optional x2text adapter instance ID + chunk_size: Chunk size for splitting + chunk_overlap: Chunk overlap for splitting + file_storage: FileStorage instance for hash calculation + + Returns: + Unique index key (doc_id) for the document + """ + if not file_hash: + # Use SDK's file storage method to calculate file hash + file_hash = file_storage.get_hash_from_file(path=file_path) + + # Use SDK's ToolAdapter to get adapter configurations + # This ensures we're using the same configuration as index_v2 + index_key = { + "file_hash": file_hash, + "chunk_size": str(chunk_size), # Convert to string for compatibility + "chunk_overlap": str(chunk_overlap), # Convert to string for compatibility + } + + # Get adapter configurations using SDK methods (if tool has platform connection) + try: + # Get vector DB config using SDK + vector_db_config = ToolAdapter.get_adapter_config(tool, vector_db_instance_id) + if vector_db_config: + index_key["vector_db_config"] = vector_db_config + + # Get embedding config using SDK + embedding_config = ToolAdapter.get_adapter_config(tool, embedding_instance_id) + if embedding_config: + index_key["embedding_config"] = embedding_config + + # Get x2text config if provided + if x2text_instance_id: + x2text_config = ToolAdapter.get_adapter_config(tool, x2text_instance_id) + if x2text_config: + index_key["x2text_config"] = x2text_config + + except Exception as e: + logger.warning( + f"Could not retrieve adapter configs, using instance IDs instead: {e}" + ) + # Fallback to using instance IDs directly + index_key["vector_db_instance_id"] = 
+
+
+def generate_index_key_with_sdk(
+    tool: BaseTool,
+    file_hash: str | None,
+    file_path: str,
+    embedding_instance_id: str,
+    vector_db_instance_id: str,
+    x2text_instance_id: str | None,
+    chunk_size: int,
+    chunk_overlap: int,
+    file_storage: FileStorage,
+) -> str:
+    """Generate a unique index key using SDK methods.
+    This follows the pattern from index_v2.generate_index_key but uses SDK methods.
+
+    Args:
+        tool: BaseTool instance for SDK operations
+        file_hash: Optional pre-computed file hash
+        file_path: Path to the file
+        embedding_instance_id: Embedding adapter instance ID
+        vector_db_instance_id: Vector DB adapter instance ID
+        x2text_instance_id: Optional x2text adapter instance ID
+        chunk_size: Chunk size for splitting
+        chunk_overlap: Chunk overlap for splitting
+        file_storage: FileStorage instance for hash calculation
+
+    Returns:
+        Unique index key (doc_id) for the document
+    """
+    if not file_hash:
+        # Use SDK's file storage method to calculate file hash
+        file_hash = file_storage.get_hash_from_file(path=file_path)
+
+    # Build the key material; adapter configurations are added below so the
+    # same configuration as index_v2 yields the same key
+    index_key = {
+        "file_hash": file_hash,
+        "chunk_size": str(chunk_size),  # Convert to string for compatibility
+        "chunk_overlap": str(chunk_overlap),  # Convert to string for compatibility
+    }
+
+    # Get adapter configurations using SDK methods (if tool has platform connection)
+    try:
+        # Get vector DB config using SDK's ToolAdapter
+        vector_db_config = ToolAdapter.get_adapter_config(tool, vector_db_instance_id)
+        if vector_db_config:
+            index_key["vector_db_config"] = vector_db_config
+
+        # Get embedding config using SDK
+        embedding_config = ToolAdapter.get_adapter_config(tool, embedding_instance_id)
+        if embedding_config:
+            index_key["embedding_config"] = embedding_config
+
+        # Get x2text config if provided
+        if x2text_instance_id:
+            x2text_config = ToolAdapter.get_adapter_config(tool, x2text_instance_id)
+            if x2text_config:
+                index_key["x2text_config"] = x2text_config
+
+    except Exception as e:
+        logger.warning(
+            f"Could not retrieve adapter configs, using instance IDs instead: {e}"
+        )
+        # Fall back to using instance IDs directly
+        index_key["vector_db_instance_id"] = vector_db_instance_id
+        index_key["embedding_instance_id"] = embedding_instance_id
+        if x2text_instance_id:
+            index_key["x2text_instance_id"] = x2text_instance_id
+
+    # Use SDK's ToolUtils.hash_str to generate the hash
+    # Sort keys to ensure consistent hashing
+    hashed_index_key = ToolUtils.hash_str(
+        json.dumps(index_key, sort_keys=True),
+        hash_method="sha256",  # SHA256 keeps collisions negligible
+    )
+
+    return hashed_index_key
+
+
+def is_document_indexed(
+    doc_id: str,
+    embedding: Embedding,
+    vector_db: VectorDB,
+    tool: BaseTool,
+) -> bool:
+    """Check if a document is already indexed using SDK methods.
+    This follows the pattern from index_v2.is_document_indexed.
+
+    Args:
+        doc_id: Document ID to check
+        embedding: Embedding instance
+        vector_db: Vector DB instance
+        tool: BaseTool instance for logging
+
+    Returns:
+        True if document is already indexed, False otherwise
+    """
+    try:
+        # Create filter for doc_id using SDK patterns
+        doc_id_eq_filter = MetadataFilter.from_dict(
+            {"key": "doc_id", "operator": FilterOperator.EQ, "value": doc_id}
+        )
+        filters = MetadataFilters(filters=[doc_id_eq_filter])
+
+        # Query with a minimal embedding using SDK method
+        q = VectorStoreQuery(
+            query_embedding=embedding.get_query_embedding(" "),
+            doc_ids=[doc_id],
+            filters=filters,
+        )
+
+        # Check if nodes exist using SDK's vector_db.query
+        result: VectorStoreQueryResult = vector_db.query(query=q)
+
+        if len(result.nodes) > 0:
+            tool.stream_log(f"Found {len(result.nodes)} nodes for {doc_id}")
+            return True
+        else:
+            tool.stream_log(f"No nodes found for {doc_id}")
+            return False
+
+    except Exception as e:
+        logger.warning(
+            f"Error querying vector DB: {e}, proceeding to index",
+            exc_info=True,
+        )
+        return False
+
+
+def prepare_documents_with_sdk(
+    doc_id: str, full_text: list[dict[str, Any]], tool: BaseTool
+) -> list[Document]:
+    """Prepare documents for indexing using SDK patterns.
+    This follows the pattern from index_v2._prepare_documents.
+
+    Args:
+        doc_id: Document identifier
+        full_text: List of text sections with metadata
+        tool: BaseTool instance for logging
+
+    Returns:
+        List of LlamaIndex Document objects
+    """
+    documents = []
+
+    try:
+        for item in full_text:
+            text = item["text_contents"]
+
+            # Create Document using LlamaIndex (as used by SDK)
+            document = Document(
+                text=text,
+                doc_id=doc_id,
+                metadata={"section": item["section"]},
+            )
+            document.id_ = doc_id
+            documents.append(document)
+
+        tool.stream_log(f"Number of documents: {len(documents)}")
+        return documents
+
+    except Exception as e:
+        tool.stream_log(
+            f"Error while processing documents {doc_id}: {e}",
+            level="ERROR",
+        )
+        raise SdkError(
+            f"Error while processing documents for indexing {doc_id}: {e}"
+        ) from e
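+
+
+# Note: generate_index_key_with_sdk is deterministic for a given file and
+# configuration because the key material is serialized with sort_keys=True
+# before hashing. Illustrative (key order does not matter):
+#
+#     ToolUtils.hash_str(json.dumps({"a": 1, "b": 2}, sort_keys=True), hash_method="sha256")
+#     ToolUtils.hash_str(json.dumps({"b": 2, "a": 1}, sort_keys=True), hash_method="sha256")
+#
+# both hash the same string, so they yield the same doc_id.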
+
+
+def get_minio_config() -> dict[str, Any]:
+    """Get MinIO configuration from environment or settings.
+    This uses SDK-compatible configuration format.
+
+    Returns:
+        Dict with MinIO configuration parameters for SDK FileStorage
+    """
+    import os
+
+    # Return configuration in the format expected by SDK's FileStorage
+    return {
+        "endpoint": os.getenv("MINIO_ENDPOINT", "localhost:9000"),
+        "access_key": os.getenv("MINIO_ACCESS_KEY", "minioadmin"),
+        "secret_key": os.getenv("MINIO_SECRET_KEY", "minioadmin"),
+        "secure": os.getenv("MINIO_SECURE", "false").lower() == "true",
+        "bucket_name": os.getenv("MINIO_BUCKET_NAME", "unstract-data"),
+    }
diff --git a/unstract/prompt-service-helpers/indexing/requirements.txt b/unstract/prompt-service-helpers/indexing/requirements.txt
new file mode 100644
index 0000000000..2a7ffef1a8
--- /dev/null
+++ b/unstract/prompt-service-helpers/indexing/requirements.txt
@@ -0,0 +1,20 @@
+# Dependencies for the indexing module
+
+# Core dependencies
+celery>=5.3.0
+requests>=2.31.0
+tiktoken>=0.5.0  # For token counting
+
+# Unstract SDK (should be installed from internal registry)
+# unstract-sdk>=0.77.1
+
+# LlamaIndex for document processing
+llama-index-core>=0.10.0
+
+# Optional dependencies for enhanced functionality
+python-magic>=0.4.27  # For file type detection
+
+# Development dependencies (optional)
+# pytest>=7.0.0
+# pytest-celery>=0.0.0
+# pytest-mock>=3.10.0
diff --git a/unstract/prompt-service-helpers/indexing/token_helper.py b/unstract/prompt-service-helpers/indexing/token_helper.py
new file mode 100644
index 0000000000..36b8949d05
--- /dev/null
+++ b/unstract/prompt-service-helpers/indexing/token_helper.py
@@ -0,0 +1,230 @@
+"""Helper for token calculation using LiteLLM model pricing data."""
+
+import json
+import logging
+from datetime import UTC, datetime, timedelta
+from typing import Any
+
+import requests
+import tiktoken
+
+from unstract.sdk.file_storage.impl import FileStorage
+from unstract.sdk.file_storage.provider import FileStorageProvider
+
+logger = logging.getLogger(__name__)
+
+MODEL_PRICES_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
+MODEL_PRICES_TTL_IN_DAYS = 7
+MODEL_PRICES_FILE_PATH = "/tmp/model_prices_and_context.json"
+
+
+class TokenCalculationHelper:
+    """Helper class for calculating tokens and context sizes for LLM models."""
+
+    def __init__(
+        self,
+        url: str = MODEL_PRICES_URL,
+        ttl_days: int = MODEL_PRICES_TTL_IN_DAYS,
+        file_path: str = MODEL_PRICES_FILE_PATH,
+    ):
+        self.ttl_days = ttl_days
+        self.url = url
+        self.file_path = file_path
+
+        # Use local file storage for caching
+        self.file_storage = FileStorage(provider=FileStorageProvider.LOCAL)
+        self.model_data = self._get_model_data()
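+
+    # Illustrative usage (model names are examples; exact context windows
+    # depend on the fetched pricing data):
+    #
+    #     helper = TokenCalculationHelper()
+    #     helper.count_tokens("hello world")                # small positive int
+    #     helper.get_model_context_window("gpt-4")          # e.g. 8192
+    #     helper.calculate_optimal_chunk_size("gpt-4")      # bounded chunk size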
+
+    def get_model_context_window(
+        self, model_name: str, provider: str | None = None
+    ) -> int | None:
+        """Get the context window size for a specific model.
+
+        Args:
+            model_name: Name of the model
+            provider: Optional provider name to disambiguate models
+
+        Returns:
+            Context window size in tokens, or None if not found
+        """
+        if not self.model_data:
+            return None
+
+        # Try exact match first
+        if model_name in self.model_data:
+            model_info = self.model_data[model_name]
+            return model_info.get("max_input_tokens") or model_info.get("max_tokens")
+
+        # Filter models whose keys contain the model name (case-insensitive)
+        filtered_models = {
+            k: v for k, v in self.model_data.items() if model_name.lower() in k.lower()
+        }
+
+        if not filtered_models:
+            # Try a looser partial match on hyphen-separated name parts
+            filtered_models = {
+                k: v
+                for k, v in self.model_data.items()
+                if any(part in k.lower() for part in model_name.lower().split("-"))
+            }
+
+        # If provider is specified, prefer a model from that provider
+        if provider and filtered_models:
+            for key, model_info in filtered_models.items():
+                if provider.lower() in model_info.get("litellm_provider", "").lower():
+                    return model_info.get("max_input_tokens") or model_info.get(
+                        "max_tokens"
+                    )
+
+        # Otherwise fall back to the first match
+        if filtered_models:
+            first_model = next(iter(filtered_models.values()))
+            return first_model.get("max_input_tokens") or first_model.get("max_tokens")
+
+        return None
+
+    def count_tokens(self, text: str, model_name: str = "gpt-3.5-turbo") -> int:
+        """Count tokens in the given text using the appropriate tokenizer.
+
+        Args:
+            text: Text to count tokens for
+            model_name: Model name to determine the tokenizer
+
+        Returns:
+            Number of tokens in the text
+        """
+        try:
+            # Try to get the appropriate encoding for the model
+            if "gpt-4" in model_name.lower() or "gpt-3" in model_name.lower():
+                encoding_name = "cl100k_base"
+            elif "codex" in model_name.lower():
+                encoding_name = "p50k_base"
+            else:
+                # Default to cl100k_base for newer models
+                encoding_name = "cl100k_base"
+
+            encoding = tiktoken.get_encoding(encoding_name)
+            return len(encoding.encode(text))
+
+        except Exception as e:
+            logger.warning(
+                f"Error counting tokens with tiktoken: {e}. "
+                f"Falling back to approximation."
+            )
+            # Fallback: approximate 1 token ≈ 4 characters
+            return len(text) // 4
+
+    def calculate_optimal_chunk_size(
+        self,
+        model_name: str,
+        provider: str | None = None,
+        target_utilization: float = 0.25,
+    ) -> int:
+        """Calculate optimal chunk size based on model's context window.
+
+        Args:
+            model_name: Name of the model
+            provider: Optional provider name
+            target_utilization: Fraction of context window to use per chunk (default 0.25)
+
+        Returns:
+            Optimal chunk size in tokens
+        """
+        context_window = self.get_model_context_window(model_name, provider)
+
+        if not context_window:
+            # Default chunk size if model not found
+            logger.warning(
+                f"Model {model_name} not found in pricing data. Using default chunk size."
+            )
+            return 1000
+
+        # Calculate optimal chunk size as a fraction of context window
+        optimal_size = int(context_window * target_utilization)
+
+        # Apply reasonable bounds
+        min_chunk_size = 500
+        max_chunk_size = 8000
+
+        return max(min_chunk_size, min(optimal_size, max_chunk_size))
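+
+    # Worked example: for a model with a 16,384-token context window and the
+    # default target_utilization of 0.25, the raw target is 4,096 tokens, which
+    # already lies inside the [500, 8000] bounds, so 4,096 is returned. For a
+    # 128,000-token window, 0.25 gives 32,000, which is clamped down to 8,000.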
+
+    def _get_model_data(self) -> dict[str, Any] | None:
+        """Get model pricing and context data, using cache if available.
+
+        Returns:
+            Dictionary of model data; bundled defaults are used as a last resort
+        """
+        try:
+            # Check if cached file exists and is still valid
+            if self.file_storage.exists(self.file_path):
+                file_mtime = self.file_storage.modification_time(self.file_path)
+                # The cached file's mtime is treated as UTC for the expiry check
+                file_expiry_date = file_mtime + timedelta(days=self.ttl_days)
+                file_expiry_date_utc = file_expiry_date.replace(tzinfo=UTC)
+                now_utc = datetime.now(UTC)
+
+                if now_utc < file_expiry_date_utc:
+                    logger.info(f"Reading model data from cache: {self.file_path}")
+                    file_contents = self.file_storage.read(
+                        self.file_path, mode="r", encoding="utf-8"
+                    )
+                    return json.loads(file_contents)
+
+            # Fetch fresh data from URL
+            return self._fetch_and_save_data()
+
+        except Exception as e:
+            logger.error(f"Error getting model data: {e}")
+            # Return default model data as fallback
+            return self._get_default_model_data()
+
+    def _fetch_and_save_data(self) -> dict[str, Any] | None:
+        """Fetch model data from URL and cache it.
+
+        Returns:
+            Dictionary of model data; bundled defaults are returned if the fetch fails
+        """
+        try:
+            logger.info(f"Fetching model data from {self.url}")
+            response = requests.get(self.url, timeout=10)
+            response.raise_for_status()
+            json_data = response.json()
+
+            # Cache the data
+            self.file_storage.write(
+                path=self.file_path,
+                mode="w",
+                encoding="utf-8",
+                data=json.dumps(json_data, indent=2),
+            )
+
+            logger.info(f"Model data cached successfully at {self.file_path}")
+            return json_data
+
+        except Exception as e:
+            logger.error(f"Error fetching model data: {e}")
+            return self._get_default_model_data()
+
+    def _get_default_model_data(self) -> dict[str, Any]:
+        """Get default model data as fallback.
+
+        Returns:
+            Dictionary with default model configurations
+        """
+        return {
+            "gpt-4": {"max_tokens": 8192, "litellm_provider": "openai"},
+            "gpt-4-32k": {"max_tokens": 32768, "litellm_provider": "openai"},
+            "gpt-4-turbo": {"max_tokens": 128000, "litellm_provider": "openai"},
+            "gpt-3.5-turbo": {"max_tokens": 4096, "litellm_provider": "openai"},
+            "gpt-3.5-turbo-16k": {"max_tokens": 16384, "litellm_provider": "openai"},
+            "claude-2": {"max_tokens": 100000, "litellm_provider": "anthropic"},
+            "claude-3-opus": {"max_tokens": 200000, "litellm_provider": "anthropic"},
+            "claude-3-sonnet": {"max_tokens": 200000, "litellm_provider": "anthropic"},
+            "claude-3-haiku": {"max_tokens": 200000, "litellm_provider": "anthropic"},
+            "llama-2-7b": {"max_tokens": 4096, "litellm_provider": "together_ai"},
+            "llama-2-13b": {"max_tokens": 4096, "litellm_provider": "together_ai"},
+            "llama-2-70b": {"max_tokens": 4096, "litellm_provider": "together_ai"},
+            "llama-3-8b": {"max_tokens": 8192, "litellm_provider": "together_ai"},
+            "llama-3-70b": {"max_tokens": 8192, "litellm_provider": "together_ai"},
+            "mistral-7b": {"max_tokens": 8192, "litellm_provider": "together_ai"},
+            "mixtral-8x7b": {"max_tokens": 32768, "litellm_provider": "together_ai"},
+        }
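+
+
+if __name__ == "__main__":
+    # Minimal smoke test (a sketch: network access to the LiteLLM pricing URL
+    # is assumed, with the bundled defaults used as a fallback).
+    helper = TokenCalculationHelper()
+    sample = "The quick brown fox jumps over the lazy dog."
+    print("tokens:", helper.count_tokens(sample, "gpt-3.5-turbo"))
+    print("context window:", helper.get_model_context_window("gpt-4", "openai"))
+    print("chunk size:", helper.calculate_optimal_chunk_size("gpt-4", "openai"))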