Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 69 additions & 29 deletions core/llm/llms/WatsonX.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import { streamResponse, streamSse } from "@continuedev/fetch";
import {
AssistantChatMessage,
ChatMessage,
Chunk,
CompletionOptions,
LLMOptions,
ToolCallDelta,
ToolResultChatMessage,
} from "../../index.js";
import { BaseLLM } from "../index.js";
import { fromChatCompletionChunk } from "../openaiTypeConverters.js";
Expand Down Expand Up @@ -89,30 +92,24 @@ class WatsonX extends BaseLLM {
static providerName = "watsonx";

protected _convertMessage(message: ChatMessage) {
if (typeof message.content === "string") {
return message;
}

if (message.role === "tool") {
return null;
let message_ = message as any;
if (message_.role === "tool") {
message_.tool_call_id = (message as ToolResultChatMessage).toolCallId;
delete message_.toolCallId;
} else if (message.role === "assistant" && !!message.toolCalls) {
message_.tool_calls = message.toolCalls.map((t) => ({
...t,
type: "function",
}));
delete message_.toolCalls;
delete message_.content;
} else if (
message_.role === "user" &&
typeof message_.content === "string"
) {
message_.content = [{ type: "text", text: message_.content }];
}

const parts = message.content.map((part) => {
if (part.type === "imageUrl") {
return {
type: "image_url",
image_url: { ...part.imageUrl, detail: "low" },
};
}
return {
type: "text",
text: part.text,
};
});
return {
...message,
content: parts,
};
return message_;
}

protected _convertArgs(options: any, messages: ChatMessage[]) {
Expand Down Expand Up @@ -257,11 +254,11 @@ class WatsonX extends BaseLLM {
const headers = this._getHeaders();

const payload: any = {
messages: messages,
messages: messages.map(this._convertMessage).filter(Boolean),
max_tokens: options.maxTokens ?? 1024,
stop: stopSequences,
frequency_penalty: options.frequencyPenalty || 1,
presence_penalty: options.presencePenalty || 1,
frequency_penalty: options.frequencyPenalty ?? 0,
presence_penalty: options.presencePenalty ?? 0,
};

if (!this.deploymentId) {
Expand Down Expand Up @@ -291,10 +288,53 @@ class WatsonX extends BaseLLM {
signal,
});

let toolName;
let toolCallId = null;
let accumulatedArgs = "";

for await (const value of streamSse(response)) {
const chunk = fromChatCompletionChunk(value);
if (chunk) {
yield chunk;
const message = fromChatCompletionChunk(value);
if (!!message) {
if (
(message as AssistantChatMessage)?.toolCalls &&
(message as AssistantChatMessage).toolCalls?.length !== 0
) {
let chunk = message as AssistantChatMessage;
if (!!chunk.toolCalls?.[0]?.id) {
toolCallId = chunk.toolCalls?.[0]?.id;
}
if (!!chunk.toolCalls?.[0]?.function?.name) {
accumulatedArgs = "";
toolName = chunk.toolCalls[0].function.name;
continue;
}
if (!!toolName) {
if (value?.choices?.[0]?.finish_reason === "tool_calls") {
// If final assistant message has "tool_calls" as finish_reason
let args: string | undefined;
try {
accumulatedArgs += chunk.toolCalls?.[0]?.function?.arguments;
// Check if accumulated argument chunks are parsable
args = JSON.stringify(JSON.parse(accumulatedArgs));
} catch (e) {
// Otherwise use arguments from final assistant tool call message
args = chunk.toolCalls?.[0]?.function?.arguments;
}
const toolCall = {
function: { name: toolName, arguments: args },
id: toolCallId,
};
chunk.toolCalls = [toolCall as ToolCallDelta];
} else {
if (!!chunk.toolCalls?.[0]?.function?.arguments)
accumulatedArgs += chunk.toolCalls?.[0]?.function?.arguments;
continue;
}
}
yield chunk;
} else {
yield message;
}
}
}
}
Expand Down
Loading