Skip to content

Add ChatVLLM() #24

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### New features

* Adds vLLM support via a new `ChatVLLM` class. (#24)

### Bug fixes

* `ChatOllama` no longer fails when a `OPENAI_API_KEY` environment variable is not set.
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pip install -U git+https://github.com/posit-dev/chatlas
* Ollama local models: [`ChatOllama()`](https://posit-dev.github.io/chatlas/reference/ChatOllama.html).
* OpenAI: [`ChatOpenAI()`](https://posit-dev.github.io/chatlas/reference/ChatOpenAI.html).
* perplexity.ai: [`ChatPerplexity()`](https://posit-dev.github.io/chatlas/reference/ChatPerplexity.html).
* vLLM: [`ChatVLLM()`](https://posit-dev.github.io/chatlas/reference/ChatVLLM.html).

It also supports the following enterprise cloud providers:

Expand Down
2 changes: 2 additions & 0 deletions chatlas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ._tokens import token_usage
from ._tools import Tool
from ._turn import Turn
from ._vllm import ChatVLLM

__all__ = (
"ChatAnthropic",
Expand All @@ -24,6 +25,7 @@
"ChatOpenAI",
"ChatAzureOpenAI",
"ChatPerplexity",
"ChatVLLM",
"Chat",
"content_image_file",
"content_image_plot",
Expand Down
6 changes: 3 additions & 3 deletions chatlas/_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def _chat_perform_args(
kwargs: Optional["SubmitInputArgs"] = None,
) -> "SubmitInputArgs":
tool_schemas = [
self._anthropic_tool_schema(tool.schema) for tool in tools.values()
self._tool_schema_json(tool.schema) for tool in tools.values()
]

# If data extraction is requested, add a "mock" tool with parameters inferred from the data model
Expand All @@ -306,7 +306,7 @@ def _structured_tool_call(**kwargs: Any):
},
}

tool_schemas.append(self._anthropic_tool_schema(data_model_tool.schema))
tool_schemas.append(self._tool_schema_json(data_model_tool.schema))

if stream:
stream = False
Expand Down Expand Up @@ -430,7 +430,7 @@ def _as_content_block(content: Content) -> "ContentBlockParam":
raise ValueError(f"Unknown content type: {type(content)}")

@staticmethod
def _anthropic_tool_schema(schema: "ChatCompletionToolParam") -> "ToolParam":
def _tool_schema_json(schema: "ChatCompletionToolParam") -> "ToolParam":
fn = schema["function"]
name = fn["name"]

Expand Down
9 changes: 8 additions & 1 deletion chatlas/_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessageParam,
ChatCompletionToolParam,
)
from openai.types.chat.chat_completion_assistant_message_param import (
ContentArrayOfContentPart,
Expand Down Expand Up @@ -288,7 +289,7 @@ def _chat_perform_args(
data_model: Optional[type[BaseModel]] = None,
kwargs: Optional["SubmitInputArgs"] = None,
) -> "SubmitInputArgs":
tool_schemas = [tool.schema for tool in tools.values()]
tool_schemas = [self._tool_schema_json(tool.schema) for tool in tools.values()]

kwargs_full: "SubmitInputArgs" = {
"stream": stream,
Expand Down Expand Up @@ -454,6 +455,12 @@ def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]:

return res

@staticmethod
def _tool_schema_json(
    schema: "ChatCompletionToolParam",
) -> "ChatCompletionToolParam":
    """
    Hook for adjusting a tool schema before it is submitted to the API.

    The base implementation returns the schema unchanged; subclasses may
    override it to tweak fields (e.g., the vLLM provider overrides this to
    set `function.strict` to `False`).
    """
    return schema

def _as_turn(
self, completion: "ChatCompletion", has_data_model: bool
) -> Turn[ChatCompletion]:
Expand Down
146 changes: 146 additions & 0 deletions chatlas/_vllm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import os
from typing import TYPE_CHECKING, Optional

import requests

from ._chat import Chat
from ._openai import OpenAIProvider
from ._turn import Turn, normalize_turns

if TYPE_CHECKING:
from openai.types.chat import ChatCompletionToolParam

from .types.openai import ChatClientArgs


def ChatVLLM(
    *,
    base_url: str,
    system_prompt: Optional[str] = None,
    turns: Optional[list[Turn]] = None,
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    seed: Optional[int] = None,
    kwargs: Optional["ChatClientArgs"] = None,
) -> Chat:
    """
    Chat with a model hosted by vLLM

    [vLLM](https://docs.vllm.ai/en/latest/) is an open source library that
    provides an efficient and convenient LLM model server. You can use
    `ChatVLLM()` to connect to endpoints powered by vLLM.

    Prerequisites
    -------------

    ::: {.callout-note}
    ## vLLM runtime

    `ChatVLLM` requires a vLLM server to be running somewhere (either on your
    machine or a remote server). If you want to run a vLLM server locally, see
    the [vLLM documentation](https://docs.vllm.ai/en/v0.5.3/getting_started/quickstart.html).
    :::

    ::: {.callout-note}
    ## Python requirements

    `ChatVLLM` requires the `openai` package (e.g., `pip install openai`).
    :::


    Parameters
    ----------
    base_url
        The base URL of the vLLM server endpoint to connect to.
    system_prompt
        A system prompt to set the behavior of the assistant.
    turns
        A list of turns to start the chat with (i.e., continuing a previous
        conversation). If not provided, the conversation begins from scratch. Do
        not provide non-`None` values for both `turns` and `system_prompt`. Each
        message in the list should be a dictionary with at least `role` (usually
        `system`, `user`, or `assistant`, but `tool` is also possible). Normally
        there is also a `content` field, which is a string.
    model
        Model identifier to use. This is required; if omitted, a `ValueError`
        listing the models available at `base_url` is raised.
    api_key
        API key for authentication. If not provided, the `VLLM_API_KEY` environment
        variable will be used.
    seed
        Random seed for reproducibility.
    kwargs
        Additional arguments to pass to the LLM client.

    Returns
    -------
    Chat
        A Chat instance configured for vLLM.

    Raises
    ------
    ValueError
        If `model` is not provided. The error message lists the models
        available at `base_url`.
    """

    if api_key is None:
        api_key = get_vllm_key()

    # `model` is required: query the server so the error message can tell the
    # user which models are actually available at this endpoint.
    if model is None:
        models = get_vllm_models(base_url, api_key)
        available_models = ", ".join(models)
        raise ValueError(f"Must specify model. Available models: {available_models}")

    return Chat(
        provider=VLLMProvider(
            base_url=base_url,
            model=model,
            seed=seed,
            api_key=api_key,
            kwargs=kwargs,
        ),
        turns=normalize_turns(
            turns or [],
            system_prompt,
        ),
    )


class VLLMProvider(OpenAIProvider):
    """
    OpenAI-compatible provider for vLLM endpoints.

    Inherits the request/response handling from `OpenAIProvider` and only
    overrides the tool-schema hook (vLLM does not support OpenAI's `strict`
    tool mode).
    """

    def __init__(
        self,
        base_url: str,
        model: str,
        seed: Optional[int],
        api_key: Optional[str],
        kwargs: Optional["ChatClientArgs"],
    ):
        # NOTE(review): this does not call super().__init__(), so any setup
        # performed by OpenAIProvider.__init__ (e.g., client construction) is
        # skipped — confirm plain attribute assignment is sufficient here.
        self.base_url = base_url
        self.model = model
        self.seed = seed
        self.api_key = api_key
        self.kwargs = kwargs

    # Just like OpenAI but no strict
    @staticmethod
    def _tool_schema_json(
        schema: "ChatCompletionToolParam",
    ) -> "ChatCompletionToolParam":
        # NOTE(review): mutates the incoming schema dict in place (i.e., the
        # caller's tool.schema), rather than a copy — confirm this is intended.
        schema["function"]["strict"] = False
        return schema


def get_vllm_key() -> str:
    """
    Read the vLLM API key from the environment.

    Checks the `VLLM_API_KEY` environment variable first, then falls back to
    `VLLM_KEY`.

    Returns
    -------
    str
        The API key.

    Raises
    ------
    ValueError
        If neither environment variable is set (or both are empty).
    """
    key = os.getenv("VLLM_API_KEY", os.getenv("VLLM_KEY"))
    if not key:
        # Mention both variables: the original message only named
        # VLLM_API_KEY even though VLLM_KEY is also consulted.
        raise ValueError(
            "Neither the VLLM_API_KEY nor VLLM_KEY environment variable is set"
        )
    return key


def get_vllm_models(base_url: str, api_key: Optional[str] = None) -> list[str]:
    """
    List the model ids served by a vLLM endpoint.

    Queries the OpenAI-compatible `/v1/models` route. If `api_key` is not
    given, it is read from the environment via `get_vllm_key()`.
    """
    key = api_key if api_key is not None else get_vllm_key()

    response = requests.get(
        f"{base_url}/v1/models",
        headers={"Authorization": f"Bearer {key}"},
    )
    response.raise_for_status()

    payload = response.json()
    return [entry["id"] for entry in payload["data"]]


# def chat_vllm_test(**kwargs) -> Chat:
# """Create a test chat instance with default parameters."""
# return ChatVLLM(base_url="https://llm.nrp-nautilus.io/", model="llama3", **kwargs)
1 change: 1 addition & 0 deletions docs/_quarto.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ quartodoc:
- ChatOllama
- ChatOpenAI
- ChatPerplexity
- ChatVLLM
- title: The chat object
desc: Methods and attributes available on a chat instance
contents:
Expand Down
Loading