Skip to content

Commit 82e3b02

Browse files
Merge pull request #6438 from mq200/feat/watsonx-structured-tool-calling
watsonx tool support
2 parents 4e3ccde + 189f649 commit 82e3b02

File tree

1 file changed

+69
-29
lines changed

1 file changed

+69
-29
lines changed

core/llm/llms/WatsonX.ts

Lines changed: 69 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import { streamResponse, streamSse } from "@continuedev/fetch";
22
import {
3+
AssistantChatMessage,
34
ChatMessage,
45
Chunk,
56
CompletionOptions,
67
LLMOptions,
8+
ToolCallDelta,
9+
ToolResultChatMessage,
710
} from "../../index.js";
811
import { BaseLLM } from "../index.js";
912
import { fromChatCompletionChunk } from "../openaiTypeConverters.js";
@@ -89,30 +92,24 @@ class WatsonX extends BaseLLM {
8992
static providerName = "watsonx";
9093

9194
protected _convertMessage(message: ChatMessage) {
92-
if (typeof message.content === "string") {
93-
return message;
94-
}
95-
96-
if (message.role === "tool") {
97-
return null;
95+
let message_ = message as any;
96+
if (message_.role === "tool") {
97+
message_.tool_call_id = (message as ToolResultChatMessage).toolCallId;
98+
delete message_.toolCallId;
99+
} else if (message.role === "assistant" && !!message.toolCalls) {
100+
message_.tool_calls = message.toolCalls.map((t) => ({
101+
...t,
102+
type: "function",
103+
}));
104+
delete message_.toolCalls;
105+
delete message_.content;
106+
} else if (
107+
message_.role === "user" &&
108+
typeof message_.content === "string"
109+
) {
110+
message_.content = [{ type: "text", text: message_.content }];
98111
}
99-
100-
const parts = message.content.map((part) => {
101-
if (part.type === "imageUrl") {
102-
return {
103-
type: "image_url",
104-
image_url: { ...part.imageUrl, detail: "low" },
105-
};
106-
}
107-
return {
108-
type: "text",
109-
text: part.text,
110-
};
111-
});
112-
return {
113-
...message,
114-
content: parts,
115-
};
112+
return message_;
116113
}
117114

118115
protected _convertArgs(options: any, messages: ChatMessage[]) {
@@ -257,11 +254,11 @@ class WatsonX extends BaseLLM {
257254
const headers = this._getHeaders();
258255

259256
const payload: any = {
260-
messages: messages,
257+
messages: messages.map(this._convertMessage).filter(Boolean),
261258
max_tokens: options.maxTokens ?? 1024,
262259
stop: stopSequences,
263-
frequency_penalty: options.frequencyPenalty || 1,
264-
presence_penalty: options.presencePenalty || 1,
260+
frequency_penalty: options.frequencyPenalty ?? 0,
261+
presence_penalty: options.presencePenalty ?? 0,
265262
};
266263

267264
if (!this.deploymentId) {
@@ -291,10 +288,53 @@ class WatsonX extends BaseLLM {
291288
signal,
292289
});
293290

291+
let toolName;
292+
let toolCallId = null;
293+
let accumulatedArgs = "";
294+
294295
for await (const value of streamSse(response)) {
295-
const chunk = fromChatCompletionChunk(value);
296-
if (chunk) {
297-
yield chunk;
296+
const message = fromChatCompletionChunk(value);
297+
if (!!message) {
298+
if (
299+
(message as AssistantChatMessage)?.toolCalls &&
300+
(message as AssistantChatMessage).toolCalls?.length !== 0
301+
) {
302+
let chunk = message as AssistantChatMessage;
303+
if (!!chunk.toolCalls?.[0]?.id) {
304+
toolCallId = chunk.toolCalls?.[0]?.id;
305+
}
306+
if (!!chunk.toolCalls?.[0]?.function?.name) {
307+
accumulatedArgs = "";
308+
toolName = chunk.toolCalls[0].function.name;
309+
continue;
310+
}
311+
if (!!toolName) {
312+
if (value?.choices?.[0]?.finish_reason === "tool_calls") {
313+
// If final assistant message has "tool_calls" as finish_reason
314+
let args: string | undefined;
315+
try {
316+
accumulatedArgs += chunk.toolCalls?.[0]?.function?.arguments;
317+
// Check if accumulated argument chunks are parsable
318+
args = JSON.stringify(JSON.parse(accumulatedArgs));
319+
} catch (e) {
320+
// Otherwise use arguments from final assistant tool call message
321+
args = chunk.toolCalls?.[0]?.function?.arguments;
322+
}
323+
const toolCall = {
324+
function: { name: toolName, arguments: args },
325+
id: toolCallId,
326+
};
327+
chunk.toolCalls = [toolCall as ToolCallDelta];
328+
} else {
329+
if (!!chunk.toolCalls?.[0]?.function?.arguments)
330+
accumulatedArgs += chunk.toolCalls?.[0]?.function?.arguments;
331+
continue;
332+
}
333+
}
334+
yield chunk;
335+
} else {
336+
yield message;
337+
}
298338
}
299339
}
300340
}

0 commit comments

Comments
 (0)