From 1e515aae64287226d891f05db8617b70aa1f73d6 Mon Sep 17 00:00:00 2001
From: Carson
Date: Thu, 19 Dec 2024 17:46:02 -0600
Subject: [PATCH] Add .set_token_limits() method to automatically drop old
 turns when specified limits are reached

---
 chatlas/_chat.py | 129 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 129 insertions(+)

diff --git a/chatlas/_chat.py b/chatlas/_chat.py
index 8dfd109a..5ca2cbf9 100644
--- a/chatlas/_chat.py
+++ b/chatlas/_chat.py
@@ -89,6 +89,7 @@ def __init__(
         self.provider = provider
         self._turns: list[Turn] = list(turns or [])
         self._tools: dict[str, Tool] = {}
+        self.token_limits: Optional[tuple[int, int]] = None
         self._echo_options: EchoOptions = {
             "rich_markdown": {},
             "rich_console": {},
@@ -381,6 +382,121 @@ async def token_count_async(
             data_model=data_model,
         )
 
+    def set_token_limits(
+        self,
+        context_window: int,
+        max_tokens: int,
+    ):
+        """
+        Set a limit on the number of tokens that can be sent to the model.
+
+        By default, the size of the chat history is unbounded -- it keeps
+        growing as you submit more input. This can be wasteful if you don't
+        need to keep the entire chat history around, and can also lead to
+        errors if the chat history grows too large for the model to handle.
+
+        This method sets a limit on the number of tokens that can be sent to
+        the model. If the limit is exceeded, the chat history is truncated to
+        fit within the limit (i.e., the oldest turns are dropped).
+
+        Note that many models publish both a context window and a maximum
+        output token limit; see your model provider's documentation for these
+        values.
+
+        Also, since the context window is the maximum number of input + output
+        tokens, the maximum number of tokens that can be sent to the model in
+        a single request is `context_window - max_tokens`.
+
+        Parameters
+        ----------
+        context_window
+            The maximum number of tokens that can be sent to the model.
+        max_tokens
+            The maximum number of tokens that the model is allowed to generate
+            in a single response.
+
+        Note
+        ----
+        This method uses `.token_count()` to estimate the token count for new
+        input before truncating the chat history. This is an estimate, so it
+        may not be perfect. Moreover, chat models based on `ChatOpenAI()`
+        currently do not take the tool loop into account when estimating token
+        counts. This means that if your input triggers many tool calls, and/or
+        the tool results are large, it's recommended to set a conservative
+        `context_window` limit.
+
+        Examples
+        --------
+        ```python
+        from chatlas import ChatAnthropic
+
+        chat = ChatAnthropic(model="claude-3-5-sonnet-20241022")
+        chat.set_token_limits(200000, 8192)
+        ```
+        """
+        if max_tokens >= context_window:
+            raise ValueError("`max_tokens` must be less than the `context_window`.")
+        self.token_limits = (context_window, max_tokens)
+
+    def _maybe_drop_turns(
+        self,
+        *args: Content | str,
+        data_model: Optional[type[BaseModel]] = None,
+    ):
+        """
+        Drop turns from the chat history if they exceed the token limits.
+        """
+
+        # Do nothing if token limits are not set
+        if self.token_limits is None:
+            return None
+
+        turns = self.get_turns(include_system_prompt=False)
+
+        # Do nothing if this is the first turn
+        if len(turns) == 0:
+            return None
+
+        last_turn = turns[-1]
+
+        # Sanity checks (i.e., when about to submit new input, the last turn
+        # should be from the assistant and should contain token counts)
+        if last_turn.role != "assistant":
+            raise ValueError(
+                "Expected the last turn to be from the assistant. Please report this issue."
+            )
+
+        if last_turn.tokens is None:
+            raise ValueError(
+                "Can't impose token limits since the last assistant turn doesn't contain token counts. "
+                "Please report this issue and consider setting `.token_limits` to `None`."
+            )
+
+        context_window, max_tokens = self.token_limits
+        max_input_size = context_window - max_tokens
+
+        # Estimate the token count for the (new) user turn
+        input_tokens = self.token_count(*args, data_model=data_model)
+
+        # Do nothing if the current history size plus input size is within the limit
+        remaining_tokens = max_input_size - input_tokens
+        if sum(last_turn.tokens) < remaining_tokens:
+            return None
+
+        tokens = self.tokens(values="discrete")
+
+        # Drop the oldest turns until the rest (plus the new input) fit within the limits
+        # TODO: we also need to account for the fact that dropping part of a tool loop is problematic
+        while turns and sum(tokens) >= remaining_tokens:
+            del turns[:2]
+            del tokens[:2]
+
+        self.set_turns(turns)
+
+        return None
+
     def app(
         self,
         *,
@@ -531,6 +647,8 @@ def chat(
             A (consumed) response from the chat. Apply `str()` to this object
             to get the text content of the response.
         """
+        self._maybe_drop_turns(*args)
+
         turn = user_turn(*args)
 
         display = self._markdown_display(echo=echo)
@@ -581,6 +699,9 @@ async def chat_async(
             A (consumed) response from the chat. Apply `str()` to this object
             to get the text content of the response.
         """
+        # TODO: async version?
+        self._maybe_drop_turns(*args)
+
         turn = user_turn(*args)
 
         display = self._markdown_display(echo=echo)
@@ -627,6 +748,8 @@ def stream(
             An (unconsumed) response from the chat. Iterate over this object
             to consume the response.
         """
+        self._maybe_drop_turns(*args)
+
         turn = user_turn(*args)
 
         display = self._markdown_display(echo=echo)
@@ -672,6 +795,9 @@ async def stream_async(
             An (unconsumed) response from the chat. Iterate over this object
             to consume the response.
         """
+        # TODO: async version?
+        self._maybe_drop_turns(*args)
+
         turn = user_turn(*args)
 
         display = self._markdown_display(echo=echo)
@@ -715,6 +841,7 @@ def extract_data(
         dict[str, Any]
             The extracted data.
         """
+        self._maybe_drop_turns(*args, data_model=data_model)
 
         display = self._markdown_display(echo=echo)
@@ -775,6 +902,8 @@ async def extract_data_async(
         dict[str, Any]
             The extracted data.
         """
+        # TODO: async version?
+        self._maybe_drop_turns(*args, data_model=data_model)
 
         display = self._markdown_display(echo=echo)
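For anyone trying this branch locally, here is a minimal usage sketch (not part of the patch). It assumes `ChatAnthropic` with an Anthropic API key in the environment, and uses the published limits for Claude 3.5 Sonnet as illustrative values; the exact number of turns dropped depends on the provider's token counts.

```python
# Illustrative usage sketch for this branch (assumes ANTHROPIC_API_KEY is set).
# The limits below correspond to Claude 3.5 Sonnet's published context window
# (200k tokens) and maximum output tokens (8192).
from chatlas import ChatAnthropic

chat = ChatAnthropic(model="claude-3-5-sonnet-20241022")
chat.set_token_limits(200000, 8192)

# Each call to .chat() now runs _maybe_drop_turns() first, which estimates the
# token count of the new input and drops the oldest user/assistant turn pairs
# until the remaining history fits within context_window - max_tokens.
chat.chat("Summarize the plot of Hamlet.")
chat.chat("Now compare it to Macbeth.")

# Inspect how much history is still being sent to the model.
print(chat.get_turns(include_system_prompt=False))

# Setting .token_limits back to None restores the default unbounded history.
chat.token_limits = None
```

Note that truncation operates on the turns returned by `get_turns(include_system_prompt=False)`, so a system prompt set on the chat is not dropped.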