Skip to content

Commit 2f8c7e9

Browse files
authored
[template] support SeedAgentTemplate (#6270)
1 parent eb8fda5 commit 2f8c7e9

File tree

6 files changed

+235
-5
lines changed

6 files changed

+235
-5
lines changed

swift/llm/template/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -998,6 +998,8 @@ def _jinja_encode(self, inputs: StdTemplateInputs):
998998
kwargs = {}
999999
if inputs.tools:
10001000
kwargs['tools'] = inputs.tools
1001+
if 'thinking_budget' in inputs.extra_kwargs:
1002+
kwargs['thinking_budget'] = inputs.extra_kwargs.get('thinking_budget', 0)
10011003
text = self.tokenizer.apply_chat_template(
10021004
messages, tokenize=False, add_generation_prompt=add_generation_prompt, **kwargs)
10031005
answer_len = 1 if self.is_training else 0

swift/llm/template/template/seed.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def insert_budget_markers(text: str, tokenizer, interval: int, total_budget: int
9494
return '\n'.join(result)
9595
else:
9696
return ('<seed:cot_budget_reflect>The current thinking budget is 0, so I will '
97-
'directly start answering the question.</seed:cot_budget_reflect>\n\n')
97+
'directly start answering the question.</seed:cot_budget_reflect>\n')
9898

9999
def _prepare_system(self, inputs):
100100
budget = self.get_thinking_budget(inputs)
@@ -143,7 +143,7 @@ def _swift_prepare_inputs(self, inputs: StdTemplateInputs):
143143
message['content'] = (
144144
'<seed:think><seed:cot_budget_reflect>The current thinking budget is 0, '
145145
'so I will directly start answering the question.'
146-
'</seed:cot_budget_reflect>\n\n</seed:think>') + message['content']
146+
'</seed:cot_budget_reflect>\n</seed:think>') + message['content']
147147

148148
def _simplify_context_list(self, context_list, loss_scale_list, inputs):
149149
res, res_loss_scale = super()._simplify_context_list(context_list, loss_scale_list, inputs)
@@ -154,7 +154,6 @@ def _simplify_context_list(self, context_list, loss_scale_list, inputs):
154154
return res, res_loss_scale
155155

156156
def _jinja_encode(self, inputs: StdTemplateInputs):
157-
self._prepare_system(inputs)
158157
return super()._jinja_encode(inputs)
159158

160159

swift/plugin/agent_template/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .qwen import QwenEnAgentTemplate, QwenEnParallelAgentTemplate, QwenZhAgentTemplate, QwenZhParallelAgentTemplate
1010
from .qwen3_coder import Qwen3CoderAgentTemplate
1111
from .react import ReactEnAgentTemplate, ReactZnAgentTemplate
12+
from .seed_oss import SeedAgentTemplate
1213
from .toolbench import ToolBenchAgentTemplate
1314

1415
agent_templates = {
@@ -31,6 +32,7 @@
3132
'llama4': Llama4AgentTemplate,
3233
# ref: https://huggingface.co/deepseek-ai/DeepSeek-V3.1
3334
'deepseek_v3_1': DeepSeekV31AgentTemplate,
35+
'seed_oss': SeedAgentTemplate,
3436
# extra
3537
'react_grpo': ReactGRPOAgentTemplate,
3638
'mistral': MistralAgentTemplate
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
import re
2+
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
3+
4+
import json
5+
6+
from .base import BaseAgentTemplate
7+
8+
if TYPE_CHECKING:
9+
from swift.llm.infer import Function
10+
from swift.llm.template import Prompt
11+
12+
13+
class SeedAgentTemplate(BaseAgentTemplate):
14+
TOOL_CALL_START = '<seed:tool_call>'
15+
TOOL_CALL_END = '</seed:tool_call>'
16+
FUNCTION_TAG = 'function'
17+
PARAMETER_TAG = 'parameter'
18+
19+
_PY_TYPE_MAPPING = {
20+
'string': 'str',
21+
'number': 'int',
22+
'integer': 'int',
23+
'boolean': 'bool',
24+
'array': 'list',
25+
}
26+
27+
@staticmethod
28+
def _py_type(t: str) -> str:
29+
return SeedAgentTemplate._PY_TYPE_MAPPING.get(t, 'Any')
30+
31+
def get_toolcall(self, response: str) -> List['Function']:
32+
from swift.llm.infer import Function
33+
34+
res_list = re.findall(rf'{self.TOOL_CALL_START}(.+?){self.TOOL_CALL_END}', response, re.DOTALL)
35+
if not res_list:
36+
return super().get_toolcall(response)
37+
38+
functions = []
39+
for res in res_list:
40+
func_name_match = re.search(rf'<{self.FUNCTION_TAG}=([^>]+)>', res)
41+
if not func_name_match:
42+
continue
43+
44+
func_name = func_name_match.group(1)
45+
param_matches = re.findall(rf'<{self.PARAMETER_TAG}=([^>]+)>(.*?)</{self.PARAMETER_TAG}>', res, re.DOTALL)
46+
arguments = {name: value for name, value in param_matches}
47+
functions.append(Function(name=func_name, arguments=arguments))
48+
49+
return functions
50+
51+
def _get_tool_responses(self, tool_messages: List[dict]) -> str:
52+
responses = [f"<seed:bos>tool\n{tool_message['content']}<seed:eos>" for tool_message in tool_messages]
53+
return ''.join(responses) + '<seed:bos>assistant\n'
54+
55+
def _format_tool_responses(
56+
self,
57+
assistant_content: str,
58+
tool_messages: List[dict],
59+
) -> Tuple[str, 'Prompt']:
60+
with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content
61+
if with_action:
62+
return super()._format_tool_responses(assistant_content, tool_messages)
63+
64+
formatted_tool_responses = self._get_tool_responses(tool_messages)
65+
return assistant_content, ['<seed:eos>', formatted_tool_responses]
66+
67+
def _build_tool_def_string(self, tool: dict) -> str:
68+
"""Helper to build a single tool definition string."""
69+
func = tool.get('function', {})
70+
func_name = func.get('name')
71+
72+
if not func_name:
73+
return ''
74+
75+
parameters = func.get('parameters', {})
76+
properties = parameters.get('properties', {})
77+
params = [
78+
f"{name}: {self._py_type(spec.get('type', 'any'))}" for name, spec in properties.items()
79+
if isinstance(spec, dict)
80+
]
81+
param_str = ','.join(params)
82+
83+
docstring_parts = [' """', f' {func.get("description", "").strip()}']
84+
85+
if properties:
86+
docstring_parts.append('\n Args:')
87+
required_params = parameters.get('required', [])
88+
for name, spec in properties.items():
89+
if isinstance(spec, dict):
90+
req_tag = '[必填]' if name in required_params else '[选填]'
91+
desc = spec.get('description', '')
92+
type_str = self._py_type(spec.get('type', 'any'))
93+
docstring_parts.append(f' - {name} ({type_str}) {req_tag}: {desc}')
94+
95+
returns_props = func.get('returns', {}).get('properties', {})
96+
if returns_props:
97+
docstring_parts.append('\n Returns:')
98+
for name, spec in returns_props.items():
99+
desc = spec.get('description', '')
100+
type_str = self._py_type(spec.get('type', 'any'))
101+
docstring_parts.append(f' - {name} ({type_str}): {desc}')
102+
103+
docstring_parts.append('\n """')
104+
docstring = '\n'.join(docstring_parts)
105+
106+
return f'Function:\ndef {func_name}({param_str}):\n{docstring}'
107+
108+
def _format_tools(self, tools: List[Union[str, dict]], system: Optional[str] = None, user_message=None) -> str:
109+
if not tools:
110+
return system or ''
111+
112+
tool_defs = [
113+
tool_def for tool in tools if (wrapped_tool := self.wrap_tool(tool)).get('type') == 'function' and
114+
(tool_def := self._build_tool_def_string(wrapped_tool)) != ''
115+
]
116+
tool_defs_joined = '\n\n'.join(tool_defs)
117+
118+
tool_call_format_instruction = (
119+
'工具调用请遵循如下格式:\n'
120+
f'{self.TOOL_CALL_START}\n'
121+
f'<{self.FUNCTION_TAG}=example_function_name>\n'
122+
f'<{self.PARAMETER_TAG}=example_parameter_1>value_1</{self.PARAMETER_TAG}>\n'
123+
f'<{self.PARAMETER_TAG}=example_parameter_2>This is the value for the second parameter\n'
124+
'that can span\n'
125+
f'multiple lines</{self.PARAMETER_TAG}>\n'
126+
f'</{self.FUNCTION_TAG}>\n'
127+
f'{self.TOOL_CALL_END}')
128+
129+
split_token = '<seed:eos><seed:bos>system'
130+
131+
if system and split_token in system:
132+
parts = system.split(split_token, 1)
133+
return f'{parts[0]}\n\n{tool_defs_joined}\n{tool_call_format_instruction}\n{split_token}{parts[1]}'
134+
else:
135+
doubao_prompt = ('You are Doubao, a helpful AI assistant. '
136+
'You may call one or more functions to assist with the user query.')
137+
return (f'{doubao_prompt}\n\n{tool_defs_joined}\n{tool_call_format_instruction}\n'
138+
f'{split_token}\n{system or ""}')
139+
140+
def _format_tool_calls(self, tool_call_messages: List[dict]) -> str:
141+
formatted_calls = []
142+
for message in tool_call_messages:
143+
tool_call = self._parse_tool_call(message['content'])
144+
func_name = tool_call['name']
145+
arguments = tool_call.get('arguments', {})
146+
147+
call_parts = [f'<{self.FUNCTION_TAG}={func_name}>']
148+
for arg_name, arg_value in arguments.items():
149+
arg_value_str = arg_value if isinstance(arg_value, str) else json.dumps(arg_value, ensure_ascii=False)
150+
call_parts.append(f'<{self.PARAMETER_TAG}={arg_name}>{arg_value_str}</{self.PARAMETER_TAG}>')
151+
152+
call_parts.append(f'</{self.FUNCTION_TAG}>')
153+
call_parts_joined = '\n'.join(call_parts)
154+
155+
full_call = f'{self.TOOL_CALL_START}\n{call_parts_joined}\n{self.TOOL_CALL_END}'
156+
formatted_calls.append(full_call)
157+
return '\n'.join(formatted_calls)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
{
22
"<think>\\s*</think>\\s*": [0.0],
3-
"<seed:think><seed:cot_budget_reflect>The current thinking budget is 0, so I will directly start answering the question.</seed:cot_budget_reflect>\n\n</seed:think>\\s*": [0.0]
3+
"<seed:think><seed:cot_budget_reflect>The current thinking budget is 0, so I will directly start answering the question.</seed:cot_budget_reflect>\n</seed:think>\\s*": [0.0]
44
}

tests/test_align/test_template/test_agent.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,75 @@ def test_deepseek_v3_1():
442442
assert encoded['input_ids'][-122:] == encoded2['input_ids'][1:]
443443

444444

445+
def test_seed_oss():
446+
agent_template = agent_templates['seed_oss']()
447+
448+
engine = PtEngine('ByteDance-Seed/Seed-OSS-36B-Instruct', load_model=False, download_model=False)
449+
450+
template = engine.default_template
451+
template.agent_template = agent_template
452+
453+
dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0]
454+
data = dataset[6]
455+
# To test multiple tool calls and responses, we duplicate some messages.
456+
data['messages'].insert(1, data['messages'][1])
457+
data['messages'].insert(3, data['messages'][3])
458+
459+
# Incomplete tool function will cause seed template to throw an error.
460+
data['tools'] = [('{\n'
461+
' "name": "convert_temperature",\n'
462+
' "description": "Convert temperature from one unit to another",\n'
463+
' "parameters": {\n'
464+
' "type": "object",\n'
465+
' "properties": {\n'
466+
' "temperature": {\n'
467+
' "type": "number",\n'
468+
' "description": "The temperature value"\n'
469+
' },\n'
470+
' "from_unit": {\n'
471+
' "type": "string",\n'
472+
' "description": "The unit to convert from"\n'
473+
' },\n'
474+
' "to_unit": {\n'
475+
' "type": "string",\n'
476+
' "description": "The unit to convert to"\n'
477+
' }\n'
478+
' },\n'
479+
' "required": [\n'
480+
' "temperature",\n'
481+
' "from_unit",\n'
482+
' "to_unit"\n'
483+
' ]\n'
484+
' }\n'
485+
'}'),
486+
('{\n'
487+
' "name": "get_current_date",\n'
488+
' "description": "Get the current date",\n'
489+
' "parameters": {\n'
490+
' "type": "object",\n'
491+
' "properties": {\n'
492+
' "date": {\n'
493+
' "type": "number",\n'
494+
' "description": "The date value"}}}\n'
495+
'}')]
496+
497+
data['thinking_budget'] = 0
498+
499+
template.template_backend = 'swift'
500+
template.set_mode('train')
501+
encoded = template.encode(data)
502+
print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
503+
print(f'labels: {template.safe_decode(encoded["labels"])}')
504+
import re
505+
expected_input_ids = re.sub(
506+
r'<seed:think>.*?</seed:think>', '', template.safe_decode(encoded['input_ids']), flags=re.DOTALL)
507+
template.template_backend = 'jinja'
508+
encoded2 = template.encode(data)
509+
print(f'input_ids: {template.safe_decode(encoded2["input_ids"])}')
510+
print(f'labels: {template.safe_decode(encoded2["labels"])}')
511+
assert template.safe_decode(encoded2['input_ids']) == expected_input_ids
512+
513+
445514
if __name__ == '__main__':
446515
from swift.plugin import agent_templates
447516
from swift.llm import PtEngine, InferRequest, RequestConfig, load_dataset
@@ -460,4 +529,5 @@ def test_deepseek_v3_1():
460529
# test_hunyuan()
461530
# test_glm4_5()
462531
# test_qwen3_coder()
463-
test_deepseek_v3_1()
532+
# test_deepseek_v3_1()
533+
test_seed_oss()

0 commit comments

Comments
 (0)