@@ -1,10 +1,12 @@
+import asyncio
 import json
 import os
 from collections import deque
 from copy import deepcopy
 from datetime import datetime
 from typing import override
 
+import litellm
 from httpx import request
 
 import openhands.agenthub.codeact_agent.function_calling as codeact_function_calling
@@ -22,6 +24,7 @@
     AgentFinishAction,
 )
 from openhands.events.event import Event
+from openhands.llm.health_check import perform_health_check
 from openhands.llm.llm import LLM
 from openhands.memory.condenser import Condenser
 from openhands.memory.condenser.condenser import Condensation, View
@@ -317,11 +320,11 @@ def step(self, state: State) -> Action:
             complexity_score = res['outputs'][0]['data'][0]
             logger.debug(f'Complexity score: {complexity_score}')
             if complexity_score > 0.3:
-                response = self.llm.completion(**params)
+                response = self._completion_with_failover(**params)
             else:
                 response = self.routing_llms['simple'].completion(**params)
         else:
-            response = self.llm.completion(**params)
+            response = self._completion_with_failover(**params)
 
         logger.debug(f'Response from LLM: {response}')
 
@@ -449,3 +452,197 @@ def _enhance_messages(self, messages: list[Message]) -> list[Message]:
             prev_role = msg.role
 
         return results
+
+    def _completion_with_failover(self, **params):
+        # First try the default LLM
+        try:
+            return self.llm.completion(**params)
+        except (litellm.RateLimitError, litellm.InternalServerError) as e:
+            logger.error(
+                f'Error completing with default LLM: {e}. Trying routing LLMs.'
+            )
+            last_exception = e
+        except Exception as e:
+            logger.error(
+                f'Error completing with default LLM: {e}. Not a rate limit or server error, so re-raising.'
+            )
+            raise
+
+        # If no routing LLMs are available, raise the original exception
+        if not self.routing_llms:
+            error_msg = (
+                'Default LLM failed and no routing LLMs are available. Try again later.'
+            )
+            logger.error(error_msg)
+            raise Exception(error_msg) from last_exception
+
+        # Sort routing LLMs by weight in descending order
+        sorted_llms = sorted(
+            self.routing_llms.items(), key=lambda x: x[1].config.weight, reverse=True
+        )
+
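+        # Walk the alternates in weight order; 'simple' is deferred as the final fallback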
+        has_simple_llm = False
+        for name, llm in sorted_llms:
+            # Skip the current default LLM (compared by config), since it has already been tried
+            if (
+                llm.config.model == self.llm.config.model
+                and llm.config.api_key == self.llm.config.api_key
+                and llm.config.base_url == self.llm.config.base_url
+            ):
+                continue
+            # Defer the 'simple' LLM; it is only tried as the last resort below
+            if name == 'simple':
+                has_simple_llm = True
+                continue
+            try:
+                resp = llm.completion(**params)
+                # If successful, promote this LLM to be the new default
+                self.llm = llm
+                return resp
+            except Exception as e:
+                logger.error(f'Error completing with {name}: {e}. Trying next LLM.')
+                last_exception = e
+
+        if has_simple_llm:
+            try:
+                # Do not promote the 'simple' LLM to default, since its model quality is lower
+                return self.routing_llms['simple'].completion(**params)
+            except Exception as e:
+                logger.error(f"Error completing with 'simple' LLM: {e}.")
+                last_exception = e
+        # If we get here, all LLMs failed
+        error_msg = 'No LLM is currently available to process this prompt. Try again later.'
+        logger.error(error_msg)
+        raise Exception(error_msg) from last_exception
+
+    @override
+    async def select_llm_from_weight_and_availability(self):
+        try:
+            self.llm = await self._select_llm_from_weight_and_availability()
+            logger.info(f'Selected LLM: {self.llm.config.model}')
+        except Exception as e:
+            logger.warning(
+                f'Error selecting LLM from weight and availability: {e}. Using the default LLM.'
+            )
+
+    async def _select_llm_from_weight_and_availability(
+        self, perform_health_check_fn=None, now_fn=None
+    ) -> LLM:
+        """
+        Select an LLM from the routing LLMs based on weight and availability, using timestamp-based weighted selection.
+
+        Args:
+            perform_health_check_fn (callable, optional): Function to perform health checks (injectable for testing)
+            now_fn (callable, optional): Function returning the current datetime (injectable for testing)
+        Returns:
+            LLM: The selected LLM instance
+        Raises:
+            ValueError: If no available LLMs are found
+        """
+
+        if not self.routing_llms:
+            raise ValueError('No LLMs available for routing')
+
+        # Get available LLMs from health check
+        models_rate_limit = await self._get_available_llms_from_health_check(
+            perform_health_check_fn
+        )
+        if not models_rate_limit:
+            raise ValueError('No available LLMs found')
+
+        # Select LLM based on weights
+        selected_name = self._select_llm_from_weights(models_rate_limit, now_fn)
+        return self.routing_llms[selected_name]
+
+    async def _get_available_llms_from_health_check(
+        self, perform_health_check_fn=None
+    ) -> dict[str, tuple[int, int, float]]:
+        """
+        Get available LLMs by performing health checks concurrently.
+
+        Args:
+            perform_health_check_fn: Function to perform health checks (defaults to perform_health_check)
+
+        Returns:
+            dict[str, tuple[int, int, float]]: Mapping of LLM name to (remaining requests, remaining tokens, weight)
+
+        Raises:
+            ValueError: If no routing LLMs are configured
+        """
+        if not self.routing_llms:
+            raise ValueError('No LLMs available for routing')
+        if perform_health_check_fn is None:
+            perform_health_check_fn = perform_health_check
+        models_rate_limit: dict[str, tuple[int, int, float]] = {}
+
+        async def check_llm(
+            name: str, llm: LLM
+        ) -> tuple[str, tuple[int, int, float]] | None:
+            (remaining_requests, remaining_tokens) = await perform_health_check_fn(
+                {
+                    'model': llm.config.model,
+                    'api_key': llm.config.api_key,
+                    'base_url': llm.config.base_url,
+                }
+            )
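+            # Models whose health check reports no remaining-quota values are treated as unavailable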
+            if remaining_requests is not None and remaining_tokens is not None:
+                return name, (remaining_requests, remaining_tokens, llm.config.weight)
+            return None
+
+        tasks = [check_llm(name, llm) for name, llm in self.routing_llms.items()]
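+        # return_exceptions=True turns failed probes into exception objects instead of cancelling the batch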
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+
+        for result in results:
+            # Skip failed probes (returned as exception objects) and unavailable models
+            if result is None or isinstance(result, BaseException):
+                continue
+            name, data = result
+            models_rate_limit[name] = data
+        return models_rate_limit
+
+    def _select_llm_from_weights(
+        self, models_rate_limit: dict[str, tuple[int, int, float]], now_fn=None
+    ) -> str:
+        """
+        Select an LLM based on weights, using a timestamp-based index into a weighted pool.
+
+        Args:
+            models_rate_limit: Mapping of LLM name to (remaining requests, remaining tokens, weight)
+            now_fn: Function returning the current datetime
+
+        Returns:
+            str: Name of the selected LLM
+
+        Raises:
+            ValueError: If no available LLMs are found or the total weight is 0
+        """
+        if now_fn is None:
+            now_fn = datetime.now
+        # Calculate the total weight in a single pass
+        total_weight = 0.0
+        weights_by_name = {}
+        for name, (_, _, weight) in models_rate_limit.items():
+            total_weight += weight
+            weights_by_name[name] = weight
+
+        if total_weight <= 0:
+            raise ValueError('No available LLMs found')
+
+        # Build the weighted selection pool
+        selection_pool = []
+        for name, weight in weights_by_name.items():
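+            # Quantize to integer slot counts; weights under 1% of the total round down to zero slots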
+            count = int((weight / total_weight) * 100)
+            if count > 0:
+                selection_pool.extend([name] * count)
+
+        if not selection_pool:
+            # Fall back to equal weights if quantization produced no slots
+            selection_pool = list(models_rate_limit.keys())
+
+        # Use the current timestamp for a deterministic but time-varying selection
+        current_time = int(now_fn().timestamp())
+        # Select the LLM using the timestamp-based index
+        selected_index = current_time % len(selection_pool)
+        return selection_pool[selected_index]
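
The selection above quantizes each weight into a 100-slot pool and indexes it with the current Unix timestamp, so the choice is deterministic within any given second but rotates across the pool over time. A minimal standalone sketch of the same idea (the names and weights are illustrative, not part of this patch):

    from datetime import datetime

    def select_by_weight(weights: dict[str, float], now_fn=datetime.now) -> str:
        total = sum(weights.values())
        if total <= 0:
            raise ValueError('No available LLMs found')
        # Each name receives floor(weight / total * 100) slots in the pool
        pool: list[str] = []
        for name, weight in weights.items():
            pool.extend([name] * int(weight / total * 100))
        if not pool:
            pool = list(weights)  # fall back to equal weighting
        return pool[int(now_fn().timestamp()) % len(pool)]

    # With weights 0.7 and 0.3, 'primary' occupies 70 of the 100 slots
    print(select_by_weight({'primary': 0.7, 'secondary': 0.3}))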
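Because perform_health_check_fn and now_fn are injectable, the selection logic can be exercised without network access. The health-check fan-out relies on asyncio.gather(..., return_exceptions=True), which hands back failed probes as exception objects instead of aborting the whole batch. A self-contained sketch of that pattern (probe is a stand-in, not the real perform_health_check):

    import asyncio

    async def probe(name: str) -> tuple[str, int]:
        if name == 'flaky':
            raise RuntimeError('health check failed')
        return name, 100  # pretend this provider has quota remaining

    async def main() -> None:
        results = await asyncio.gather(
            *(probe(n) for n in ('primary', 'flaky', 'simple')),
            return_exceptions=True,
        )
        # Failed probes arrive as exception objects and are filtered out
        healthy = dict(r for r in results if not isinstance(r, BaseException))
        print(healthy)  # {'primary': 100, 'simple': 100}

    asyncio.run(main())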