|  | 
|  | 1 | +import re | 
|  | 2 | +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union | 
|  | 3 | + | 
|  | 4 | +import json | 
|  | 5 | + | 
|  | 6 | +from .base import BaseAgentTemplate | 
|  | 7 | + | 
|  | 8 | +if TYPE_CHECKING: | 
|  | 9 | +    from swift.llm.infer import Function | 
|  | 10 | +    from swift.llm.template import Prompt | 
|  | 11 | + | 
|  | 12 | + | 
|  | 13 | +class SeedAgentTemplate(BaseAgentTemplate): | 
|  | 14 | +    TOOL_CALL_START = '<seed:tool_call>' | 
|  | 15 | +    TOOL_CALL_END = '</seed:tool_call>' | 
|  | 16 | +    FUNCTION_TAG = 'function' | 
|  | 17 | +    PARAMETER_TAG = 'parameter' | 
|  | 18 | + | 
|  | 19 | +    _PY_TYPE_MAPPING = { | 
|  | 20 | +        'string': 'str', | 
|  | 21 | +        'number': 'int', | 
|  | 22 | +        'integer': 'int', | 
|  | 23 | +        'boolean': 'bool', | 
|  | 24 | +        'array': 'list', | 
|  | 25 | +    } | 
|  | 26 | + | 
|  | 27 | +    @staticmethod | 
|  | 28 | +    def _py_type(t: str) -> str: | 
|  | 29 | +        return SeedAgentTemplate._PY_TYPE_MAPPING.get(t, 'Any') | 
|  | 30 | + | 
|  | 31 | +    def get_toolcall(self, response: str) -> List['Function']: | 
|  | 32 | +        from swift.llm.infer import Function | 
|  | 33 | + | 
|  | 34 | +        res_list = re.findall(rf'{self.TOOL_CALL_START}(.+?){self.TOOL_CALL_END}', response, re.DOTALL) | 
|  | 35 | +        if not res_list: | 
|  | 36 | +            return super().get_toolcall(response) | 
|  | 37 | + | 
|  | 38 | +        functions = [] | 
|  | 39 | +        for res in res_list: | 
|  | 40 | +            func_name_match = re.search(rf'<{self.FUNCTION_TAG}=([^>]+)>', res) | 
|  | 41 | +            if not func_name_match: | 
|  | 42 | +                continue | 
|  | 43 | + | 
|  | 44 | +            func_name = func_name_match.group(1) | 
|  | 45 | +            param_matches = re.findall(rf'<{self.PARAMETER_TAG}=([^>]+)>(.*?)</{self.PARAMETER_TAG}>', res, re.DOTALL) | 
|  | 46 | +            arguments = {name: value for name, value in param_matches} | 
|  | 47 | +            functions.append(Function(name=func_name, arguments=arguments)) | 
|  | 48 | + | 
|  | 49 | +        return functions | 
|  | 50 | + | 
|  | 51 | +    def _get_tool_responses(self, tool_messages: List[dict]) -> str: | 
|  | 52 | +        responses = [f"<seed:bos>tool\n{tool_message['content']}<seed:eos>" for tool_message in tool_messages] | 
|  | 53 | +        return ''.join(responses) + '<seed:bos>assistant\n' | 
|  | 54 | + | 
|  | 55 | +    def _format_tool_responses( | 
|  | 56 | +        self, | 
|  | 57 | +        assistant_content: str, | 
|  | 58 | +        tool_messages: List[dict], | 
|  | 59 | +    ) -> Tuple[str, 'Prompt']: | 
|  | 60 | +        with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content | 
|  | 61 | +        if with_action: | 
|  | 62 | +            return super()._format_tool_responses(assistant_content, tool_messages) | 
|  | 63 | + | 
|  | 64 | +        formatted_tool_responses = self._get_tool_responses(tool_messages) | 
|  | 65 | +        return assistant_content, ['<seed:eos>', formatted_tool_responses] | 
|  | 66 | + | 
|  | 67 | +    def _build_tool_def_string(self, tool: dict) -> str: | 
|  | 68 | +        """Helper to build a single tool definition string.""" | 
|  | 69 | +        func = tool.get('function', {}) | 
|  | 70 | +        func_name = func.get('name') | 
|  | 71 | + | 
|  | 72 | +        if not func_name: | 
|  | 73 | +            return '' | 
|  | 74 | + | 
|  | 75 | +        parameters = func.get('parameters', {}) | 
|  | 76 | +        properties = parameters.get('properties', {}) | 
|  | 77 | +        params = [ | 
|  | 78 | +            f"{name}: {self._py_type(spec.get('type', 'any'))}" for name, spec in properties.items() | 
|  | 79 | +            if isinstance(spec, dict) | 
|  | 80 | +        ] | 
|  | 81 | +        param_str = ','.join(params) | 
|  | 82 | + | 
|  | 83 | +        docstring_parts = ['    """', f'    {func.get("description", "").strip()}'] | 
|  | 84 | + | 
|  | 85 | +        if properties: | 
|  | 86 | +            docstring_parts.append('\n    Args:') | 
|  | 87 | +            required_params = parameters.get('required', []) | 
|  | 88 | +            for name, spec in properties.items(): | 
|  | 89 | +                if isinstance(spec, dict): | 
|  | 90 | +                    req_tag = '[必填]' if name in required_params else '[选填]' | 
|  | 91 | +                    desc = spec.get('description', '') | 
|  | 92 | +                    type_str = self._py_type(spec.get('type', 'any')) | 
|  | 93 | +                    docstring_parts.append(f'    - {name} ({type_str}) {req_tag}: {desc}') | 
|  | 94 | + | 
|  | 95 | +        returns_props = func.get('returns', {}).get('properties', {}) | 
|  | 96 | +        if returns_props: | 
|  | 97 | +            docstring_parts.append('\n    Returns:') | 
|  | 98 | +            for name, spec in returns_props.items(): | 
|  | 99 | +                desc = spec.get('description', '') | 
|  | 100 | +                type_str = self._py_type(spec.get('type', 'any')) | 
|  | 101 | +                docstring_parts.append(f'    - {name} ({type_str}): {desc}') | 
|  | 102 | + | 
|  | 103 | +        docstring_parts.append('\n    """') | 
|  | 104 | +        docstring = '\n'.join(docstring_parts) | 
|  | 105 | + | 
|  | 106 | +        return f'Function:\ndef {func_name}({param_str}):\n{docstring}' | 
|  | 107 | + | 
|  | 108 | +    def _format_tools(self, tools: List[Union[str, dict]], system: Optional[str] = None, user_message=None) -> str: | 
|  | 109 | +        if not tools: | 
|  | 110 | +            return system or '' | 
|  | 111 | + | 
|  | 112 | +        tool_defs = [ | 
|  | 113 | +            tool_def for tool in tools if (wrapped_tool := self.wrap_tool(tool)).get('type') == 'function' and | 
|  | 114 | +            (tool_def := self._build_tool_def_string(wrapped_tool)) != '' | 
|  | 115 | +        ] | 
|  | 116 | +        tool_defs_joined = '\n\n'.join(tool_defs) | 
|  | 117 | + | 
|  | 118 | +        tool_call_format_instruction = ( | 
|  | 119 | +            '工具调用请遵循如下格式:\n' | 
|  | 120 | +            f'{self.TOOL_CALL_START}\n' | 
|  | 121 | +            f'<{self.FUNCTION_TAG}=example_function_name>\n' | 
|  | 122 | +            f'<{self.PARAMETER_TAG}=example_parameter_1>value_1</{self.PARAMETER_TAG}>\n' | 
|  | 123 | +            f'<{self.PARAMETER_TAG}=example_parameter_2>This is the value for the second parameter\n' | 
|  | 124 | +            'that can span\n' | 
|  | 125 | +            f'multiple lines</{self.PARAMETER_TAG}>\n' | 
|  | 126 | +            f'</{self.FUNCTION_TAG}>\n' | 
|  | 127 | +            f'{self.TOOL_CALL_END}') | 
|  | 128 | + | 
|  | 129 | +        split_token = '<seed:eos><seed:bos>system' | 
|  | 130 | + | 
|  | 131 | +        if system and split_token in system: | 
|  | 132 | +            parts = system.split(split_token, 1) | 
|  | 133 | +            return f'{parts[0]}\n\n{tool_defs_joined}\n{tool_call_format_instruction}\n{split_token}{parts[1]}' | 
|  | 134 | +        else: | 
|  | 135 | +            doubao_prompt = ('You are Doubao, a helpful AI assistant. ' | 
|  | 136 | +                             'You may call one or more functions to assist with the user query.') | 
|  | 137 | +            return (f'{doubao_prompt}\n\n{tool_defs_joined}\n{tool_call_format_instruction}\n' | 
|  | 138 | +                    f'{split_token}\n{system or ""}') | 
|  | 139 | + | 
|  | 140 | +    def _format_tool_calls(self, tool_call_messages: List[dict]) -> str: | 
|  | 141 | +        formatted_calls = [] | 
|  | 142 | +        for message in tool_call_messages: | 
|  | 143 | +            tool_call = self._parse_tool_call(message['content']) | 
|  | 144 | +            func_name = tool_call['name'] | 
|  | 145 | +            arguments = tool_call.get('arguments', {}) | 
|  | 146 | + | 
|  | 147 | +            call_parts = [f'<{self.FUNCTION_TAG}={func_name}>'] | 
|  | 148 | +            for arg_name, arg_value in arguments.items(): | 
|  | 149 | +                arg_value_str = arg_value if isinstance(arg_value, str) else json.dumps(arg_value, ensure_ascii=False) | 
|  | 150 | +                call_parts.append(f'<{self.PARAMETER_TAG}={arg_name}>{arg_value_str}</{self.PARAMETER_TAG}>') | 
|  | 151 | + | 
|  | 152 | +            call_parts.append(f'</{self.FUNCTION_TAG}>') | 
|  | 153 | +            call_parts_joined = '\n'.join(call_parts) | 
|  | 154 | + | 
|  | 155 | +            full_call = f'{self.TOOL_CALL_START}\n{call_parts_joined}\n{self.TOOL_CALL_END}' | 
|  | 156 | +            formatted_calls.append(full_call) | 
|  | 157 | +        return '\n'.join(formatted_calls) | 
0 commit comments