|
1 | 1 | import { streamResponse, streamSse } from "@continuedev/fetch";
|
2 | 2 | import {
|
| 3 | + AssistantChatMessage, |
3 | 4 | ChatMessage,
|
4 | 5 | Chunk,
|
5 | 6 | CompletionOptions,
|
6 | 7 | LLMOptions,
|
| 8 | + ToolCallDelta, |
| 9 | + ToolResultChatMessage, |
7 | 10 | } from "../../index.js";
|
8 | 11 | import { BaseLLM } from "../index.js";
|
9 | 12 | import { fromChatCompletionChunk } from "../openaiTypeConverters.js";
|
@@ -89,30 +92,24 @@ class WatsonX extends BaseLLM {
|
89 | 92 | static providerName = "watsonx";
|
90 | 93 |
|
91 | 94 | protected _convertMessage(message: ChatMessage) {
|
92 |
| - if (typeof message.content === "string") { |
93 |
| - return message; |
94 |
| - } |
95 |
| - |
96 |
| - if (message.role === "tool") { |
97 |
| - return null; |
| 95 | + let message_ = message as any; |
| 96 | + if (message_.role === "tool") { |
| 97 | + message_.tool_call_id = (message as ToolResultChatMessage).toolCallId; |
| 98 | + delete message_.toolCallId; |
| 99 | + } else if (message.role === "assistant" && !!message.toolCalls) { |
| 100 | + message_.tool_calls = message.toolCalls.map((t) => ({ |
| 101 | + ...t, |
| 102 | + type: "function", |
| 103 | + })); |
| 104 | + delete message_.toolCalls; |
| 105 | + delete message_.content; |
| 106 | + } else if ( |
| 107 | + message_.role === "user" && |
| 108 | + typeof message_.content === "string" |
| 109 | + ) { |
| 110 | + message_.content = [{ type: "text", text: message_.content }]; |
98 | 111 | }
|
99 |
| - |
100 |
| - const parts = message.content.map((part) => { |
101 |
| - if (part.type === "imageUrl") { |
102 |
| - return { |
103 |
| - type: "image_url", |
104 |
| - image_url: { ...part.imageUrl, detail: "low" }, |
105 |
| - }; |
106 |
| - } |
107 |
| - return { |
108 |
| - type: "text", |
109 |
| - text: part.text, |
110 |
| - }; |
111 |
| - }); |
112 |
| - return { |
113 |
| - ...message, |
114 |
| - content: parts, |
115 |
| - }; |
| 112 | + return message_; |
116 | 113 | }
|
117 | 114 |
|
118 | 115 | protected _convertArgs(options: any, messages: ChatMessage[]) {
|
@@ -257,11 +254,11 @@ class WatsonX extends BaseLLM {
|
257 | 254 | const headers = this._getHeaders();
|
258 | 255 |
|
259 | 256 | const payload: any = {
|
260 |
| - messages: messages, |
| 257 | + messages: messages.map(this._convertMessage).filter(Boolean), |
261 | 258 | max_tokens: options.maxTokens ?? 1024,
|
262 | 259 | stop: stopSequences,
|
263 |
| - frequency_penalty: options.frequencyPenalty || 1, |
264 |
| - presence_penalty: options.presencePenalty || 1, |
| 260 | + frequency_penalty: options.frequencyPenalty ?? 0, |
| 261 | + presence_penalty: options.presencePenalty ?? 0, |
265 | 262 | };
|
266 | 263 |
|
267 | 264 | if (!this.deploymentId) {
|
@@ -291,10 +288,53 @@ class WatsonX extends BaseLLM {
|
291 | 288 | signal,
|
292 | 289 | });
|
293 | 290 |
|
| 291 | + let toolName; |
| 292 | + let toolCallId = null; |
| 293 | + let accumulatedArgs = ""; |
| 294 | + |
294 | 295 | for await (const value of streamSse(response)) {
|
295 |
| - const chunk = fromChatCompletionChunk(value); |
296 |
| - if (chunk) { |
297 |
| - yield chunk; |
| 296 | + const message = fromChatCompletionChunk(value); |
| 297 | + if (!!message) { |
| 298 | + if ( |
| 299 | + (message as AssistantChatMessage)?.toolCalls && |
| 300 | + (message as AssistantChatMessage).toolCalls?.length !== 0 |
| 301 | + ) { |
| 302 | + let chunk = message as AssistantChatMessage; |
| 303 | + if (!!chunk.toolCalls?.[0]?.id) { |
| 304 | + toolCallId = chunk.toolCalls?.[0]?.id; |
| 305 | + } |
| 306 | + if (!!chunk.toolCalls?.[0]?.function?.name) { |
| 307 | + accumulatedArgs = ""; |
| 308 | + toolName = chunk.toolCalls[0].function.name; |
| 309 | + continue; |
| 310 | + } |
| 311 | + if (!!toolName) { |
| 312 | + if (value?.choices?.[0]?.finish_reason === "tool_calls") { |
| 313 | + // If final assistant message has "tool_calls" as finish_reason |
| 314 | + let args: string | undefined; |
| 315 | + try { |
| 316 | + accumulatedArgs += chunk.toolCalls?.[0]?.function?.arguments; |
| 317 | + // Check if accumulated argument chunks are parsable |
| 318 | + args = JSON.stringify(JSON.parse(accumulatedArgs)); |
| 319 | + } catch (e) { |
| 320 | + // Otherwise use arguments from final assistant tool call message |
| 321 | + args = chunk.toolCalls?.[0]?.function?.arguments; |
| 322 | + } |
| 323 | + const toolCall = { |
| 324 | + function: { name: toolName, arguments: args }, |
| 325 | + id: toolCallId, |
| 326 | + }; |
| 327 | + chunk.toolCalls = [toolCall as ToolCallDelta]; |
| 328 | + } else { |
| 329 | + if (!!chunk.toolCalls?.[0]?.function?.arguments) |
| 330 | + accumulatedArgs += chunk.toolCalls?.[0]?.function?.arguments; |
| 331 | + continue; |
| 332 | + } |
| 333 | + } |
| 334 | + yield chunk; |
| 335 | + } else { |
| 336 | + yield message; |
| 337 | + } |
298 | 338 | }
|
299 | 339 | }
|
300 | 340 | }
|
|
0 commit comments