Commit d0ff379

Merge branch 'feat/code-llm-round-robin' into staging
2 parents: 892a447 + 811b8b4 · commit d0ff379

File tree: 17 files changed, +886 −13 lines


.env.example

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@ JWT_SECRET=your-secret-key
 # Thesis Auth Server URL
 THESIS_AUTH_SERVER_URL=
 
-# Run Mode (PROD)
+# Run Mode (PROD)
 RUN_MODE='PROD'
 
 

config.template.toml

Lines changed: 3 additions & 0 deletions

@@ -211,10 +211,13 @@ model = "gpt-4o"
 # https://github.com/All-Hands-AI/OpenHands/pull/4711
 #native_tool_calling = None
 
+# weight of the LLM. This is used to select an LLM when multiple LLMs are available.
+#weight = 0.6
 
 [llm.gpt4o-mini]
 api_key = ""
 model = "gpt-4o"
+weight = 0.4
 
 
 #################################### Agent ###################################
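
To see how the new weight keys relate to each other, here is a minimal sketch (not part of this commit) that loads a config shaped like the template above and normalizes the weights into selection shares. The section names and weight values are illustrative, and the commented-out default weight is uncommented for the demo.

import tomllib  # Python 3.11+

# Config shaped like the template above; values are illustrative.
doc = tomllib.loads('''
[llm]
model = "gpt-4o"
weight = 0.6

[llm.gpt4o-mini]
model = "gpt-4o"
weight = 0.4
''')

# Collect the default group's weight plus every named subgroup's weight.
weights = {'llm': doc['llm']['weight']}
for name, sub in doc['llm'].items():
    if isinstance(sub, dict) and 'weight' in sub:
        weights[name] = sub['weight']

total = sum(weights.values())
print({name: w / total for name, w in weights.items()})
# {'llm': 0.6, 'gpt4o-mini': 0.4}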

openhands/agenthub/codeact_agent/codeact_agent.py

Lines changed: 199 additions & 2 deletions

@@ -1,10 +1,12 @@
+import asyncio
 import json
 import os
 from collections import deque
 from copy import deepcopy
 from datetime import datetime
 from typing import override
 
+import litellm
 from httpx import request
 
 import openhands.agenthub.codeact_agent.function_calling as codeact_function_calling
@@ -22,6 +24,7 @@
     AgentFinishAction,
 )
 from openhands.events.event import Event
+from openhands.llm.health_check import perform_health_check
 from openhands.llm.llm import LLM
 from openhands.memory.condenser import Condenser
 from openhands.memory.condenser.condenser import Condensation, View
@@ -317,11 +320,11 @@ def step(self, state: State) -> Action:
             complexity_score = res['outputs'][0]['data'][0]
             logger.debug(f'Complexity score: {complexity_score}')
             if complexity_score > 0.3:
-                response = self.llm.completion(**params)
+                response = self._completion_with_failover(**params)
             else:
                 response = self.routing_llms['simple'].completion(**params)
         else:
-            response = self.llm.completion(**params)
+            response = self._completion_with_failover(**params)
 
         logger.debug(f'Response from LLM: {response}')
 
@@ -449,3 +452,197 @@ def _enhance_messages(self, messages: list[Message]) -> list[Message]:
             prev_role = msg.role
 
         return results
+
+    def _completion_with_failover(self, **params):
+        # First try the default LLM
+        try:
+            return self.llm.completion(**params)
+        except (litellm.RateLimitError, litellm.InternalServerError) as e:
+            logger.error(
+                f'Error completing with default LLM: {e}. Trying routing LLMs.'
+            )
+            last_exception = e
+        except Exception as e:
+            logger.error(
+                f'Error completing with default LLM: {e}. Not rate-limited or server error -> raise'
+            )
+            raise e
+
+        # If no routing LLMs are available, raise the original exception
+        if not self.routing_llms:
+            error_msg = (
+                'Default LLM failed and no routing LLMs available. Try again later'
+            )
+            logger.error(error_msg)
+            raise Exception(error_msg) from last_exception
+
+        # Sort LLMs by weight in descending order
+        sorted_llms = sorted(
+            self.routing_llms.items(), key=lambda x: x[1].config.weight, reverse=True
+        )
+
+        has_simple_llm = False
+        for name, llm in sorted_llms:
+            # Skip the current default LLM (compared by config) since it was already tried
+            if (
+                llm.config.model == self.llm.config.model
+                and llm.config.api_key == self.llm.config.api_key
+                and llm.config.base_url == self.llm.config.base_url
+            ):
+                continue
+            # Skip the 'simple' LLM for now, in case all weights are 0
+            if name == 'simple':
+                has_simple_llm = True
+                continue
+            try:
+                resp = llm.completion(**params)
+                # If successful, assign this LLM as the new default
+                self.llm = llm
+                return resp
+            except Exception as e:
+                logger.error(f'Error completing with {name}: {e}. Trying next LLM.')
+                last_exception = e
+
+        if has_simple_llm:
+            try:
+                # For 'simple' routing, don't re-assign self.llm since the model quality is lower
+                return self.routing_llms['simple'].completion(**params)
+            except Exception as e:
+                logger.error(f"Error completing with 'simple' LLM: {e}.")
+                last_exception = e
+        # If we get here, all LLMs failed
+        error_msg = 'No LLM is available to process this prompt. Try again later'
+        logger.error(error_msg)
+        raise Exception(error_msg) from last_exception
+
+    @override
+    async def select_llm_from_weight_and_availability(self):
+        try:
+            self.llm = await self._select_llm_from_weight_and_availability()
+            logger.info(f'Selected LLM: {self.llm.config.model}')
+        except Exception as e:
+            logger.warning(
+                f'Error selecting LLM from weight and availability: {e}. Using the default LLM.'
+            )
+
+    async def _select_llm_from_weight_and_availability(
+        self, perform_health_check_fn=None, now_fn=None
+    ) -> LLM:
+        """
+        Select an LLM from self.routing_llms based on weight and availability, using round-robin selection.
+
+        Args:
+            perform_health_check_fn (callable, optional): Function to perform the health check (for testing)
+            now_fn (callable, optional): Function to get the current datetime (for testing)
+
+        Returns:
+            LLM: The selected LLM instance
+
+        Raises:
+            ValueError: If no available LLMs are found
+        """
+        if not self.routing_llms:
+            raise ValueError('No LLMs available for routing')
+
+        # Get available LLMs from the health check
+        models_rate_limit = await self._get_available_llms_from_health_check(
+            perform_health_check_fn
+        )
+        if not models_rate_limit:
+            raise ValueError('No available LLMs found')
+
+        # Select an LLM based on weights
+        selected_name = self._select_llm_from_weights(models_rate_limit, now_fn)
+        return self.routing_llms[selected_name]
+
+    async def _get_available_llms_from_health_check(
+        self, perform_health_check_fn=None
+    ) -> dict[str, tuple[int, int, float]]:
+        """
+        Get available LLMs by performing health checks.
+
+        Args:
+            perform_health_check_fn: Function to perform the health check
+
+        Returns:
+            dict[str, tuple[int, int, float]]: Dictionary mapping LLM names to their rate limits and weights
+        """
+        if not self.routing_llms:
+            raise ValueError('No LLMs available for routing')
+        if perform_health_check_fn is None:
+            perform_health_check_fn = perform_health_check
+        models_rate_limit: dict[str, tuple[int, int, float]] = {}
+
+        async def check_llm(
+            name: str, llm: LLM
+        ) -> tuple[str, tuple[int, int, float]] | None:
+            (remaining_requests, remaining_tokens) = await perform_health_check_fn(
+                {
+                    'model': llm.config.model,
+                    'api_key': llm.config.api_key,
+                    'base_url': llm.config.base_url,
+                }
+            )
+            if remaining_requests is not None and remaining_tokens is not None:
+                return name, (remaining_requests, remaining_tokens, llm.config.weight)
+            return None
+
+        tasks = [check_llm(name, llm) for name, llm in self.routing_llms.items()]
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+
+        models_rate_limit = {
+            name: data
+            for result in results
+            if result is not None and not isinstance(result, BaseException)
+            for name, data in [result]
+        }
+        return models_rate_limit
+
+    def _select_llm_from_weights(
+        self, models_rate_limit: dict[str, tuple[int, int, float]], now_fn=None
+    ) -> str:
+        """
+        Select an LLM based on weights, using round-robin selection.
+
+        Args:
+            models_rate_limit: Dictionary mapping LLM names to their rate limits and weights
+            now_fn: Function to get the current datetime
+
+        Returns:
+            str: Name of the selected LLM
+
+        Raises:
+            ValueError: If no available LLMs are found or the total weight is 0
+        """
+        if now_fn is None:
+            from datetime import datetime as dt
+
+            now_fn = dt.now
+        # Calculate the total weight in a single pass
+        total_weight = 0.0
+        normalized_weights = {}
+        for name, (_, _, weight) in models_rate_limit.items():
+            total_weight += weight
+            normalized_weights[name] = weight
+
+        if total_weight <= 0:
+            raise ValueError('No available LLMs found')
+
+        # Normalize weights and build the selection pool in one pass
+        selection_pool = []
+        for name, weight in normalized_weights.items():
+            count = int((weight / total_weight) * 100)
+            if count > 0:
+                selection_pool.extend([name] * count)
+
+        if not selection_pool:
+            # Fall back to equal weights if no weights are specified
+            selection_pool = list(models_rate_limit.keys())
+
+        # Use the current timestamp for a deterministic but changing selection
+        current_time = int(now_fn().timestamp())
+        # Select the LLM using a timestamp-based index
+        selected_index = current_time % len(selection_pool)
+        return selection_pool[selected_index]
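
The selection logic above can be hard to picture from the diff alone. Below is a self-contained sketch of the same idea, simplified from _select_llm_from_weights (the LLM names and weights are made up): weights expand into a pool of roughly 100 slots, and the current Unix timestamp indexes into it, so heavier models are picked proportionally more often while the choice stays deterministic within a given second.

from datetime import datetime


def pick_llm(weights: dict[str, float], now_fn=datetime.now) -> str:
    total = sum(weights.values())
    if total <= 0:
        raise ValueError('No available LLMs found')
    # Expand normalized weights into a ~100-slot selection pool.
    pool: list[str] = []
    for name, weight in weights.items():
        pool.extend([name] * int((weight / total) * 100))
    if not pool:
        # Fall back to equal weights if every slot count truncated to 0.
        pool = list(weights)
    # Timestamp-based index: deterministic within a second, changes over time.
    return pool[int(now_fn().timestamp()) % len(pool)]


print(pick_llm({'gpt4o': 0.6, 'gpt4o-mini': 0.4}))
# ~60% of seconds map to 'gpt4o', ~40% to 'gpt4o-mini'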

openhands/controller/agent.py

Lines changed: 6 additions & 0 deletions

@@ -171,3 +171,9 @@ def update_agent_knowledge_base(
         for k in knowledge_base:
             if k.get('chunkId', None):
                 self.knowledge_base[k['chunkId']] = k
+
+    async def select_llm_from_weight_and_availability(self):
+        """
+        Select an LLM from a list of LLMs based on weight and availability, using round-robin selection.
+        """
+        raise NotImplementedError('This method should be implemented by the subclass')

openhands/core/config/app_config.py

Lines changed: 10 additions & 2 deletions

@@ -117,13 +117,21 @@ class AppConfig(BaseModel):
     def get_llm_config(self, name: str = 'llm') -> LLMConfig:
         """'llm' is the name for default config (for backward compatibility prior to 0.8)."""
         if name in self.llms:
-            return self.llms[name]
+            llm = self.llms[name]
+            if llm.model is not None and (
+                llm.api_key is not None or llm.base_url is not None
+            ):
+                return llm
         if name is not None and name != 'llm':
             logger.openhands_logger.warning(
                 f'llm config group {name} not found, using default config'
             )
-        if 'llm' not in self.llms:
+        if len(self.llms) == 0:
             self.llms['llm'] = LLMConfig()
+        else:
+            # Get the LLM config with the highest weight
+            highest_weight_llm = max(self.llms.items(), key=lambda x: x[1].weight)
+            self.llms['llm'] = highest_weight_llm[1]
         return self.llms['llm']
 
     def set_llm_config(self, value: LLMConfig, name: str = 'llm') -> None:
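
A quick sketch of the new fallback: when the requested group is unusable and other groups exist, the group with the highest weight becomes the default. Plain floats stand in for LLMConfig.weight here, and the group names are hypothetical.

# Stand-ins for self.llms: group name -> weight (hypothetical values).
llms = {'primary': 0.6, 'backup': 0.4, 'simple': 0.0}

# Same selection as the diff: max() over items, keyed on the weight.
highest_weight_llm = max(llms.items(), key=lambda x: x[1])
print(highest_weight_llm[0])  # 'primary'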

openhands/core/config/llm_config.py

Lines changed: 1 addition & 0 deletions

@@ -84,6 +84,7 @@ class LLMConfig(BaseModel):
     native_tool_calling: bool | None = Field(default=None)
     reasoning_effort: str | None = Field(default='high')
     seed: int | None = Field(default=None)
+    weight: float = Field(default=0.0)
 
     model_config = {'extra': 'forbid'}
 
openhands/core/config/utils.py

Lines changed: 1 addition & 1 deletion

@@ -62,7 +62,7 @@ def get_optional_type(union_type: UnionType | type | None) -> type | None:
 # helper function to set attributes based on env vars
 def set_attr_from_env(sub_config: BaseModel, prefix: str = '') -> None:
     """Set attributes of a config model based on environment variables."""
-    for field_name, field_info in sub_config.model_fields.items():
+    for field_name, field_info in sub_config.__class__.model_fields.items():
         field_value = getattr(sub_config, field_name)
         field_type = field_info.annotation
 
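
This one-line change appears to track newer Pydantic releases, which deprecate reading model_fields from an instance rather than from the class (Pydantic 2.11). A minimal sketch of the preferred access pattern, with a made-up model:

from pydantic import BaseModel


class ExampleConfig(BaseModel):
    name: str = 'default'
    weight: float = 0.0


cfg = ExampleConfig()
# Class-level access avoids the instance-access deprecation warning.
for field_name, field_info in cfg.__class__.model_fields.items():
    print(field_name, field_info.annotation, getattr(cfg, field_name))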

openhands/core/main.py

Lines changed: 1 addition & 0 deletions

@@ -99,6 +99,7 @@ async def run_controller(
     if agent is None:
         agent = create_agent(config)
     mcp_tools = await fetch_mcp_tools_from_config(config.dict_mcp_config, sid=sid)
+    await agent.select_llm_from_weight_and_availability()
     logger.info(f'MCP tools: {mcp_tools}')
     agent.set_mcp_tools(mcp_tools)
 

openhands/llm/health_check.py

Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+# This file runs a health check for the LLM, used on litellm/proxy
+import random
+
+import litellm
+from pydantic import SecretStr
+
+from openhands.core.logger import openhands_logger as logger
+
+
+def _get_random_llm_message():
+    """
+    Get a random message to send to the LLM.
+    """
+    messages = ["Hey how's it going?", "What's 1 + 1?"]
+
+    return [{'role': 'user', 'content': random.choice(messages)}]
+
+
+# NOTE: are default values sufficient?
+async def perform_health_check(
+    model_params: dict,
+    min_remaining_requests: int = 20,
+    min_remaining_tokens: int = 20000,
+):
+    """
+    Perform a health check for the given model.
+    model_params must have the following keys:
+    - model
+    - api_key
+    - base_url (optional)
+    """
+    model_params['messages'] = _get_random_llm_message()
+    api_key: SecretStr = model_params.get('api_key', None)
+    if api_key is None:
+        raise ValueError('api_key is required')
+    if model_params.get('model', None) is None:
+        raise ValueError('model is required')
+    api_key_str = api_key.get_secret_value()
+    model_params['api_key'] = api_key_str
+    try:
+        result = await litellm.ahealth_check(
+            model_params=model_params,
+            mode='chat',
+        )
+        remaining_requests = int(result.get('x-ratelimit-remaining-requests', 0))
+        remaining_tokens = int(result.get('x-ratelimit-remaining-tokens', 0))
+        if (
+            remaining_requests > min_remaining_requests
+            and remaining_tokens > min_remaining_tokens
+        ):
+            return remaining_requests, remaining_tokens
+        else:
+            return None, None
+    except Exception as e:
+        logger.error(f'Error performing health check: {e}')
+        return None, None
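
For reference, a minimal sketch of calling the new health check directly. The model name and API key are placeholders, and the thresholds shown are the defaults; the function returns the remaining request/token budget when the provider is healthy and above both floors, and (None, None) otherwise.

import asyncio

from pydantic import SecretStr

from openhands.llm.health_check import perform_health_check


async def main():
    remaining_requests, remaining_tokens = await perform_health_check(
        {
            'model': 'gpt-4o',                       # placeholder model
            'api_key': SecretStr('sk-placeholder'),  # placeholder key
        },
        min_remaining_requests=20,
        min_remaining_tokens=20000,
    )
    if remaining_requests is None:
        print('Model is unavailable or rate-limited')
    else:
        print(f'{remaining_requests} requests, {remaining_tokens} tokens left')


asyncio.run(main())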

openhands/llm/llm.py

Lines changed: 1 addition & 1 deletion

@@ -223,7 +223,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
 
         messages: list[dict[str, Any]] | dict[str, Any] = []
         mock_function_calling = not self.is_function_calling_active()
-        logger.info(f'Mock function calling: {mock_function_calling}')
+        logger.debug(f'Mock function calling: {mock_function_calling}')
         # Add session_id and user_id as span attributes if they exist
         try:
             span = trace.get_current_span()
