Skip to content

Commit 10c75b0

Browse files
committed
test2
1 parent b9a54ea commit 10c75b0

File tree

1 file changed

+11
-8
lines changed
  • integrations/aimlapi/src/haystack_integrations/components/generators/aimlapi/chat

1 file changed

+11
-8
lines changed

integrations/aimlapi/src/haystack_integrations/components/generators/aimlapi/chat/chat_generator.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from typing import Any, Dict, List, Optional, Union
5+
from typing import Any, Dict, List, Optional, Union, cast
66

77
from haystack import component, default_to_dict, logging
88
from haystack.components.generators.chat import OpenAIChatGenerator
@@ -157,7 +157,7 @@ def _prepare_api_call(
157157
messages: List[ChatMessage],
158158
streaming_callback: Optional[StreamingCallbackT] = None,
159159
generation_kwargs: Optional[Dict[str, Any]] = None,
160-
tools: Optional[Union[List[Tool], Toolset]] = None,
160+
tools: Optional[Union[list[Union[Tool, Toolset]], Toolset]] = None,
161161
tools_strict: Optional[bool] = None,
162162
) -> Dict[str, Any]:
163163
# update generation kwargs by merging with the generation kwargs passed to the run method
@@ -167,17 +167,20 @@ def _prepare_api_call(
167167
# adapt ChatMessage(s) to the format expected by the OpenAI API (AIMLAPI uses the same format)
168168
aimlapi_formatted_messages: List[Dict[str, Any]] = [message.to_openai_dict_format() for message in messages]
169169

170-
tools = tools or self.tools
171-
if isinstance(tools, Toolset):
172-
tools = list(tools)
170+
tools_in = tools or self.tools
171+
if isinstance(tools_in, Toolset):
172+
tools_list: List[Tool] = list(tools_in)
173+
else:
174+
tools_list: List[Tool] = cast(List[Tool], tools_in or [])
175+
173176
tools_strict = tools_strict if tools_strict is not None else self.tools_strict
174-
_check_duplicate_tool_names(list(tools or []))
177+
_check_duplicate_tool_names(tools_list)
175178

176179
aimlapi_tools = {}
177-
if tools:
180+
if tools_list:
178181
tool_definitions = [
179182
{"type": "function", "function": {**t.tool_spec, **({"strict": tools_strict} if tools_strict else {})}}
180-
for t in tools
183+
for t in tools_list
181184
]
182185
aimlapi_tools = {"tools": tool_definitions}
183186

0 commit comments

Comments
 (0)