diff --git a/src/nodetool/nodes/llms/__init__.py b/src/nodetool/nodes/llms/__init__.py
new file mode 100644
index 0000000..7616960
--- /dev/null
+++ b/src/nodetool/nodes/llms/__init__.py
@@ -0,0 +1,3 @@
+from .synthesizer import Synthesizer
+
+__all__ = ["Synthesizer"]
diff --git a/src/nodetool/nodes/llms/synthesizer.py b/src/nodetool/nodes/llms/synthesizer.py
new file mode 100644
index 0000000..329e4a9
--- /dev/null
+++ b/src/nodetool/nodes/llms/synthesizer.py
@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+import re
+from jinja2 import Environment, BaseLoader
+from pydantic import Field
+
+from nodetool.metadata.types import (
+    Message,
+    MessageTextContent,
+    LanguageModel,
+    Provider,
+)
+from nodetool.chat.providers import Chunk
+from nodetool.workflows.base_node import BaseNode
+from nodetool.workflows.processing_context import ProcessingContext
+
+
+class Synthesizer(BaseNode):
+    """Generate text from a Jinja2 prompt rendered with dynamic properties.
+
+    The prompt template is rendered case-insensitively against the node's
+    dynamic properties, sent to the selected LLM, and the streamed response
+    is both re-posted as chunks and returned as the full text.
+    """
+
+    _is_dynamic = True
+
+    model: LanguageModel = Field(
+        default=LanguageModel(),
+        description="Model to use for generation",
+    )
+    system: str = Field(
+        default="You are a helpful assistant.",
+        description="System prompt for the LLM",
+    )
+    prompt: str = Field(
+        default="",
+        description="Prompt template rendered with dynamic properties",
+    )
+    max_tokens: int = Field(default=4096, ge=1, le=100000)
+
+    @classmethod
+    def get_title(cls) -> str:
+        return "Synthesizer"
+
+    @classmethod
+    def get_basic_fields(cls) -> list[str]:
+        return ["prompt", "model"]
+
+    async def process(self, context: ProcessingContext) -> str:
+        if self.model.provider == Provider.Empty:
+            raise ValueError("Select a model")
+
+        env = Environment(loader=BaseLoader())
+
+        # Lowercase variable names only *inside* {{ ... }} expressions so
+        # matching against dynamic property keys is case-insensitive. A
+        # plain str.replace on the captured name would also rewrite matching
+        # literal text outside the expression (e.g. "Name: {{ Name }}"
+        # would become "name: {{ name }}"), corrupting the prompt.
+        def _lower_var(match: re.Match[str]) -> str:
+            filters = match.group(2) or ""
+            return "{{ " + match.group(1).strip().lower() + filters + " }}"
+
+        template_str = re.sub(
+            r"{{\s*([^|}]+?)\s*(\|[^}]*?)?\s*}}", _lower_var, self.prompt
+        )
+
+        template = env.from_string(template_str)
+        # NOTE: keys differing only in case collapse to one entry here.
+        properties = {k.lower(): v for k, v in self._dynamic_properties.items()}
+
+        user_prompt = template.render(**properties)
+
+        messages = [
+            Message(role="system", content=self.system),
+            Message(role="user", content=[MessageTextContent(text=user_prompt)]),
+        ]
+
+        result = ""
+        async for chunk in context.generate_messages(
+            messages=messages,
+            provider=self.model.provider,
+            model=self.model.id,
+            node_id=self.id,
+            max_tokens=self.max_tokens,
+        ):
+            if isinstance(chunk, Chunk):
+                context.post_message(
+                    Chunk(
+                        node_id=self.id,
+                        content=chunk.content,
+                        content_type=chunk.content_type,
+                    )
+                )
+                result += chunk.content
+        return result
diff --git a/tests/nodetool/test_llms.py b/tests/nodetool/test_llms.py
new file mode 100644
index 0000000..0e838f6
--- /dev/null
+++ b/tests/nodetool/test_llms.py
@@ -0,0 +1,32 @@
+import pytest
+
+from nodetool.workflows.processing_context import ProcessingContext
+from nodetool.nodes.llms.synthesizer import Synthesizer
+from nodetool.metadata.types import LanguageModel, Provider
+from nodetool.chat.providers import Chunk
+
+
+@pytest.fixture
+def context():
+    return ProcessingContext(user_id="test", auth_token="test")
+
+
+@pytest.mark.asyncio
+async def test_synthesizer_process(context, monkeypatch):
+    node = Synthesizer(
+        prompt="Hello {{ name }}!",
+        model=LanguageModel(provider=Provider.OpenAI, id="gpt"),
+    )
+    node._dynamic_properties = {"name": "Alice"}
+
+    async def fake_generate_messages(**kwargs):
+        messages = kwargs.get("messages")
+        assert messages[1].content[0].text == "Hello Alice!"
+        yield Chunk(content="Hi Alice", content_type="text")
+
+    monkeypatch.setattr(context, "generate_messages", fake_generate_messages)
+
+    result = await node.process(context)
+    assert result == "Hi Alice"
+
+