diff --git a/neon_llm_core/llm.py b/neon_llm_core/llm.py index e170999..ddc7bb2 100644 --- a/neon_llm_core/llm.py +++ b/neon_llm_core/llm.py @@ -23,8 +23,14 @@ # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + from abc import ABC, abstractmethod -from typing import List +from typing import List, Optional, Tuple, Union + +from neon_data_models.models.api import LLMRequest, LLMResponse, \ + LLMProposeRequest, LLMProposeResponse, LLMDiscussRequest, \ + LLMDiscussResponse, LLMVoteRequest, LLMVoteResponse +from ovos_utils.log import LOG, log_deprecation class NeonLLM(ABC): @@ -36,8 +42,6 @@ def __init__(self, config: dict): @param config: Dict LLM configuration for this specific LLM """ self._llm_config = config - self._tokenizer = None - self._model = None @property def llm_config(self): @@ -48,75 +52,208 @@ def llm_config(self): @property @abstractmethod - def tokenizer(self): + def tokenizer(self) -> Optional[object]: + """ + Get a Tokenizer object for the loaded model, if available. + :return: optional transformers.PreTrainedTokenizerBase object + """ pass @property @abstractmethod def tokenizer_model_name(self) -> str: + """ + Get a string tokenizer model name (i.e. a Huggingface `model id`) + associated with `self.tokenizer`. + """ pass @property @abstractmethod - def model(self): + def model(self) -> object: + """ + Get an OpenAI client object to send requests to. + """ pass @property @abstractmethod def llm_model_name(self) -> str: + """ + Get a string model name for the configured `model` + """ pass @property @abstractmethod def _system_prompt(self) -> str: + """ + Get a default string system prompt to use when not included in requests + """ pass def ask(self, message: str, chat_history: List[List[str]], persona: dict) -> str: - """ Generates llm response based on user message and (user, llm) chat history """ + """ + Generates llm response based on user message and (user, llm) chat history + """ + log_deprecation("This method is replaced by `query_model` which " + "accepts a single `LLMRequest` arg", "1.0.0") prompt = self._assemble_prompt(message, chat_history, persona) llm_text_output = self._call_model(prompt) return llm_text_output + def query_model(self, request: LLMRequest) -> LLMResponse: + """ + Calls `self._assemble_prompt` to allow subclass to modify the input + query and then passes the updated query to `self._call_model` + :param request: LLMRequest object to generate a response to + :return: + """ + if request.model != self.llm_model_name: + raise ValueError(f"Requested model ({request.model}) is not this " + f"model ({self.llm_model_name}") + request.query = self._assemble_prompt(request.query, request.history, + request.persona.model_dump()) + response = self._call_model(request.query, request) + history = request.history + [("llm", response)] + return LLMResponse(response=response, history=history) + + + def ask_proposer(self, request: LLMProposeRequest) -> LLMProposeResponse: + """ + Override this method to implement CBF-specific logic + """ + return LLMProposeResponse(**self.query_model(request).model_dump(), + message_id=request.message_id, + routing_key=request.routing_key) + + def ask_discusser(self, request: LLMDiscussRequest, + compose_prompt_method: Optional[callable] = None) -> \ + LLMDiscussResponse: + """ + Override this method to implement CBF-specific logic + """ + if not request.options: + opinion = 
"Sorry, but I got no options to choose from." + else: + # Default opinion if the model fails to respond + opinion = "Sorry, but I experienced an issue trying to form "\ + "an opinion on this topic" + try: + sorted_answer_indexes = self.get_sorted_answer_indexes( + question=request.query, + answers=list(request.options.values()), + persona=request.persona.model_dump()) + best_respondent_nick, best_response = \ + list(request.options.items())[sorted_answer_indexes[0]] + opinion = self._ask_model_for_opinion( + respondent_nick=best_respondent_nick, + llm_request=request, answer=best_response, + compose_opinion_prompt=compose_prompt_method) + except ValueError as err: + LOG.error(f'ValueError={err}') + except IndexError as err: + # Failed response will return an empty list + LOG.error(f'IndexError={err}') + except Exception as e: + LOG.exception(e) + + return LLMDiscussResponse(message_id=request.message_id, + routing_key=request.routing_key, + opinion=opinion) + + def ask_appraiser(self, request: LLMVoteRequest) -> LLMVoteResponse: + """ + Override this method to implement CBF-specific logic + """ + if not request.responses: + sorted_answer_indexes = [] + else: + # Default opinion if the model fails to respond + sorted_answer_indexes = [] + try: + sorted_answer_indexes = self.get_sorted_answer_indexes( + question=request.query, + answers=request.responses, + persona=request.persona.model_dump()) + except ValueError as err: + LOG.error(f'ValueError={err}') + except IndexError as err: + # Failed response will return an empty list + LOG.error(f'IndexError={err}') + except Exception as e: + LOG.exception(e) + + return LLMVoteResponse(message_id=request.message_id, + routing_key=request.routing_key, + sorted_answer_indexes=sorted_answer_indexes) + + def _ask_model_for_opinion(self, llm_request: LLMRequest, + respondent_nick: str, + answer: str, + compose_opinion_prompt: callable) -> str: + llm_request.query = compose_opinion_prompt( + respondent_nick=respondent_nick, question=llm_request.query, + answer=answer) + opinion = self.model.query_model(llm_request) + LOG.info(f'Received LLM opinion={opinion}, prompt={llm_request.query}') + return opinion.response + @abstractmethod def get_sorted_answer_indexes(self, question: str, answers: List[str], persona: dict) -> List[int]: """ - Creates sorted list of answer indexes with respect to order provided in :param answers - Results should be sorted from best to worst - :param question: incoming question - :param answers: list of answers to rank - :returns list of indexes + Creates sorted list of answer indexes with respect to order provided in + `answers`. Results should be sorted from best to worst + :param question: incoming question + :param answers: list of answers to rank + :param persona: dict representation of Persona to use for sorting + :returns list of indexes """ pass @abstractmethod - def _call_model(self, prompt: str) -> str: + def _call_model(self, prompt: str, + request: Optional[LLMRequest] = None) -> str: """ Wrapper for Model generation logic. This method may be called asynchronously, so it is up to the extending class to use locks or queue inputs as necessary. 
:param prompt: Input text sequence + :param request: Optional LLMRequest object containing parameters to + include in model requests :returns: Output text sequence generated by model """ pass @abstractmethod - def _assemble_prompt(self, message: str, chat_history: List[List[str]], persona: dict): + def _assemble_prompt(self, message: str, + chat_history: List[Union[List[str], Tuple[str, str]]], + persona: dict) -> str: """ - Assembles prompt engineering logic - - :param message: Incoming prompt - :param chat_history: History of preceding conversation - :returns: assembled prompt + Assemble the prompt to send to the LLM + :param message: Input prompt to optionally modify + :param chat_history: History of preceding conversation + :param persona: dict representation of Persona that is requested + :returns: assembled prompt string """ pass @abstractmethod def _tokenize(self, prompt: str) -> List[str]: + """ + Tokenize the input prompt into a list of strings + :param prompt: Input to tokenize + :return: Tokenized representation of input prompt + """ pass @classmethod def convert_role(cls, role: str) -> str: - """ Maps MQ role to LLM's internal domain """ + """ + Maps MQ role to LLM's internal domain + :param role: Role in Neon LLM format + :return: Role in LLM internal format + """ matching_llm_role = cls.mq_to_llm_role.get(role) if not matching_llm_role: raise ValueError(f"role={role} is undefined, supported are: " diff --git a/neon_llm_core/rmq.py b/neon_llm_core/rmq.py index 1b54047..0851f9e 100644 --- a/neon_llm_core/rmq.py +++ b/neon_llm_core/rmq.py @@ -29,15 +29,15 @@ from time import time from typing import Optional +from neon_data_models.models.api import LLMRequest from neon_mq_connector.connector import MQConnector from neon_mq_connector.utils.rabbit_utils import create_mq_callback from neon_utils.logger import LOG +from neon_data_models.models.api.llm import LLMPersona from neon_data_models.models.api.mq import ( - LLMProposeResponse, - LLMDiscussResponse, - LLMVoteResponse, -) + LLMProposeRequest, LLMProposeResponse, LLMDiscussRequest, + LLMDiscussResponse, LLMVoteRequest, LLMVoteResponse) from neon_llm_core.utils.config import load_config from neon_llm_core.llm import NeonLLM @@ -47,7 +47,8 @@ class NeonLLMMQConnector(MQConnector, ABC): """ - Module for processing MQ requests to Fast Chat LLM + Module to handle LLM requests from the MQ bus and respond with the attached + model's output """ async_consumers_enabled = True @@ -67,6 +68,10 @@ def __init__(self, config: Optional[dict] = None): self._last_persona_update = time() self._personas_provider = PersonasProvider(service_name=self.name, ovos_config=self.ovos_config) + + self._default_persona = self._personas_provider.personas[0] if \ + self._personas_provider.personas else \ + LLMPersona(persona_name="vanilla", enabled=True) def register_consumers(self): for idx in range(self.model_config.get("num_parallel_processes", 1)): @@ -195,26 +200,26 @@ def handle_persona_delete(self, body: dict): def _handle_request_async(self, request: dict): message_id = request["message_id"] routing_key = request["routing_key"] - query = request["query"] - history = request["history"] - persona = request.get("persona", {}) - LOG.debug(f"Request persona={persona}|key={routing_key}") - # Default response if the model fails to respond - response = 'Sorry, but I cannot respond to your message at the '\ - 'moment; please, try again later' + request['persona'] = request.get('persona') or self._default_persona + request['model'] = request.get('model') or 
self.model.llm_model_name try: - response = self.model.ask(message=query, chat_history=history, - persona=persona) + if request.get('prompt_data'): + # This indicates a CBF prompt + response = self.model.ask_proposer(LLMProposeRequest(**request)) + else: + response = self.model.query_model(LLMRequest(**request)) + response_kwargs = response.model_dump() + response_kwargs['message_id'] = message_id + response_kwargs['routing_key'] = routing_key + response = LLMProposeResponse(**response_kwargs) except ValueError as err: LOG.error(f'ValueError={err}') except Exception as e: LOG.exception(e) - api_response = LLMProposeResponse(message_id=message_id, - response=response, - routing_key=routing_key) - LOG.debug(f"Sending response: {response}") - self.send_message(request_data=api_response.model_dump(), + + LOG.info(f"Sending response: {response}") + self.send_message(request_data=response.model_dump(), queue=routing_key) LOG.info(f"Handled ask request for query={query}") @@ -223,84 +228,39 @@ def _handle_score_async(self, body: dict): Handles score requests (vote) from MQ to LLM :param body: request body (dict) """ - message_id = body["message_id"] - routing_key = body["routing_key"] - - query = body["query"] - responses = body["responses"] - persona = body.get("persona", {}) - - if not responses: - sorted_answer_idx = [] - else: - try: - sorted_answer_idx = self.model.get_sorted_answer_indexes( - question=query, answers=responses, persona=persona) - except ValueError as err: - LOG.error(f'ValueError={err}') - sorted_answer_idx = [] - except Exception as e: - LOG.exception(e) - sorted_answer_idx = [] + body['persona'] = body.get('persona') or self._default_persona + body['model'] = body.get('model') or self.model.llm_model_name + request = LLMVoteRequest(**body) - api_response = LLMVoteResponse(message_id=message_id, - routing_key=routing_key, - sorted_answer_indexes=sorted_answer_idx) + api_response = self.model.ask_appraiser(request) self.send_message(request_data=api_response.model_dump(), - queue=routing_key) - LOG.info(f"Handled score request for query={query}") + queue=request.routing_key) + LOG.info(f"Handled score request for message_id={request.message_id}") def _handle_opinion_async(self, body: dict): """ Handles opinion requests (discuss) from MQ to LLM :param body: request body (dict) """ - message_id = body["message_id"] - routing_key = body["routing_key"] - - query = body["query"] - options = body["options"] - persona = body.get("persona", {}) - responses = list(options.values()) - - if not responses: - opinion = "Sorry, but I got no options to choose from." 
- else: - # Default opinion if the model fails to respond - opinion = "Sorry, but I experienced an issue trying to form "\ - "an opinion on this topic" - try: - sorted_answer_indexes = self.model.get_sorted_answer_indexes( - question=query, answers=responses, persona=persona) - best_respondent_nick, best_response = list(options.items())[ - sorted_answer_indexes[0]] - opinion = self._ask_model_for_opinion( - respondent_nick=best_respondent_nick, - question=query, answer=best_response, persona=persona) - except ValueError as err: - LOG.error(f'ValueError={err}') - except IndexError as err: - # Failed response will return an empty list - LOG.error(f'IndexError={err}') - except Exception as e: - LOG.exception(e) + body['persona'] = body.get('persona') or self._default_persona + body['model'] = body.get('model') or self.model.llm_model_name + request = LLMDiscussRequest(**body) - api_response = LLMDiscussResponse(message_id=message_id, - routing_key=routing_key, - opinion=opinion) + api_response = self.model.ask_discusser(request, + self.compose_opinion_prompt) self.send_message(request_data=api_response.model_dump(), - queue=routing_key) - LOG.info(f"Handled discuss request for query={query}") + queue=request.routing_key) + LOG.info(f"Handled ask request for message_id={request.message_id}") - def _ask_model_for_opinion(self, respondent_nick: str, question: str, - answer: str, persona: dict) -> str: - prompt = self.compose_opinion_prompt(respondent_nick=respondent_nick, - question=question, - answer=answer) - opinion = self.model.ask(message=prompt, chat_history=[], - persona=persona) - LOG.info(f'Received LLM opinion={opinion}, prompt={prompt}') - return opinion + def _ask_model_for_opinion(self, llm_request: LLMRequest, + respondent_nick: str, + answer: str) -> str: + llm_request.query = self.compose_opinion_prompt( + respondent_nick=respondent_nick, question=llm_request.query, + answer=answer) + opinion = self.model.query_model(llm_request) + LOG.info(f'Received LLM opinion={opinion}, prompt={llm_request.query}') + return opinion.response @staticmethod @abstractmethod diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 06f7957..b219bf8 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -2,5 +2,7 @@ neon-mq-connector~=0.9 neon_utils[sentry]~=1.12 ovos-config~=0.0,>=0.0.10 +ovos-utils~=0.0 pydantic~=2.6 -neon-data-models~=0.0 +#neon-data-models~=0.0,>=0.0.2a3 +neon-data-models@git+https://github.com/neongeckocom/neon-data-models@FEAT_ChatbotPromptData \ No newline at end of file diff --git a/tests/test_llm.py b/tests/test_llm.py index bf0e3e3..ffc9728 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -25,8 +25,289 @@ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
from unittest import TestCase +from unittest.mock import Mock +from neon_data_models.models.api import LLMResponse +from neon_llm_core.llm import NeonLLM + +class MockLLM(NeonLLM): + + mq_to_llm_role = {"user": "user", + "llm": "assistant"} + + def __init__(self, *args, **kwargs): + NeonLLM.__init__(self, *args, **kwargs) + self._assemble_prompt = Mock(return_value=lambda *args: args[0]) + self._tokenize = Mock(return_value=lambda *args: args[0]) + self.get_sorted_answer_indexes = Mock(side_effect=lambda question, answers, persona: [i for i in range(len(answers))]) + self._call_model = Mock(return_value="mock model response") + self._model = Mock() + + @property + def tokenizer(self): + return None + + @property + def tokenizer_model_name(self) -> str: + return "mock_tokenizer" + + @property + def model(self): + return self._model + + @property + def llm_model_name(self) -> str: + return "mock_model" + + @property + def _system_prompt(self) -> str: + return "mock system prompt" + + class TestNeonLLM(TestCase): - # TODO - pass + MockLLM.__abstractmethods__ = set() + config = {"test_config": True} + + def setUp(self): + # Create a new instance for each test to avoid state leaking between tests + self.test_llm = MockLLM(self.config) + self.test_llm._assemble_prompt.reset_mock() + self.test_llm._tokenize.reset_mock() + self.test_llm.get_sorted_answer_indexes.reset_mock() + self.test_llm._call_model.reset_mock() + + def test_init(self): + self.assertEqual(self.test_llm.llm_config, self.config) + self.assertIsNone(self.test_llm.tokenizer) + self.assertIsInstance(self.test_llm.tokenizer_model_name, str) + self.assertIsNotNone(self.test_llm.model) + self.assertIsInstance(self.test_llm.llm_model_name, str) + self.assertIsInstance(self.test_llm._system_prompt, str) + + def test_ask(self): + from neon_data_models.models.api import LLMPersona + message = "Test input" + history = [["user", "hello"], ["llm", "Hello. How can I help?"]] + persona = LLMPersona(name="test_persona", description="test persona") + + # Valid request + response = self.test_llm.ask(message, history, persona.model_dump()) + self.assertEqual(response, self.test_llm._call_model.return_value) + self.test_llm._assemble_prompt.assert_called_once_with(message, history, + persona.model_dump()) + self.test_llm._call_model.assert_called_once_with(self.test_llm._assemble_prompt.return_value) + + def test_query_model(self): + from neon_data_models.models.api import LLMPersona, LLMRequest + message = "Test input" + history = [["user", "hello"], ["llm", "Hello. 
How can I help?"]] + persona = LLMPersona(name="test_persona", description="test persona") + valid_request = LLMRequest(query=message, history=history, + persona=persona, + model=self.test_llm.llm_model_name) + response = self.test_llm.query_model(valid_request) + self.assertIsInstance(response, LLMResponse) + self.assertEqual(response.response, self.test_llm._call_model.return_value) + self.assertEqual(response.history[-1], + ("llm", self.test_llm._call_model.return_value)) + self.assertEqual(len(response.history), 3) + self.test_llm._assemble_prompt.assert_called_once_with( + message, valid_request.history, persona.model_dump()) + self.test_llm._call_model.assert_called_once_with( + self.test_llm._assemble_prompt.return_value, valid_request) + + # Request for a different model will raise an exception + invalid_request = LLMRequest(query=message, history=history, + persona=persona, model="invalid_model") + with self.assertRaises(ValueError): + self.test_llm.query_model(invalid_request) + + def test_convert_role(self): + self.assertEqual(self.test_llm.convert_role("user"), "user") + self.assertEqual(self.test_llm.convert_role("llm"), "assistant") + with self.assertRaises(ValueError): + self.test_llm.convert_role("assistant") + + def test_ask_proposer(self): + """Test the ask_proposer method handles requests correctly""" + from neon_data_models.models.api import LLMPersona, LLMProposeRequest, LLMProposeResponse + + message = "Test proposal" + history = [["user", "hello"], ["llm", "Hello. How can I help?"]] + persona = LLMPersona(name="test_persona", description="test persona") + + request = LLMProposeRequest( + query=message, + history=history, + persona=persona, + model=self.test_llm.llm_model_name, + message_id="test_message_id", + routing_key="test_routing_key" + ) + + self.test_llm.query_model = Mock(return_value=LLMResponse( + response="mock response", + history=history + [("llm", "mock response")] + )) + + response = self.test_llm.ask_proposer(request) + + self.assertIsInstance(response, LLMProposeResponse) + self.assertEqual(response.message_id, "test_message_id") + self.assertEqual(response.routing_key, "test_routing_key") + self.assertEqual(response.response, "mock response") + self.test_llm.query_model.assert_called_once_with(request) + + def test_ask_discusser(self): + """Test the ask_discusser method with various scenarios""" + from neon_data_models.models.api import LLMPersona, LLMDiscussRequest, LLMDiscussResponse + + message = "Test discussion" + history = [["user", "hello"], ["llm", "Hello. 
How can I help?"]] + persona = LLMPersona(name="test_persona", description="test persona") + + empty_request = LLMDiscussRequest( + query=message, + history=history, + persona=persona, + model=self.test_llm.llm_model_name, + message_id="test_message_id", + routing_key="test_routing_key", + options={} + ) + + response = self.test_llm.ask_discusser(empty_request) + self.assertIsInstance(response, LLMDiscussResponse) + self.assertEqual(response.message_id, "test_message_id") + self.assertEqual(response.routing_key, "test_routing_key") + self.assertIsInstance(response.opinion, str) + # self.assertNotEqual(response.opinion, + # self.test_llm._ask_model_for_opinion.return_value) + + options = {"user1": "First option", "user2": "Second option"} + valid_request = LLMDiscussRequest( + query=message, + history=history, + persona=persona, + model=self.test_llm.llm_model_name, + message_id="test_message_id", + routing_key="test_routing_key", + options=options + ) + + self.test_llm._ask_model_for_opinion = Mock(return_value="mock opinion") + + response = self.test_llm.ask_discusser(valid_request) + self.assertIsInstance(response, LLMDiscussResponse) + self.assertEqual(response.message_id, "test_message_id") + self.assertEqual(response.routing_key, "test_routing_key") + self.assertEqual(response.opinion, "mock opinion") + self.test_llm._ask_model_for_opinion.assert_called_once() + + custom_prompt_method = Mock(return_value="Custom prompt") + self.test_llm._ask_model_for_opinion.reset_mock() + + response = self.test_llm.ask_discusser(valid_request, custom_prompt_method) + self.assertEqual(response.opinion, "mock opinion") + self.test_llm._ask_model_for_opinion.assert_called_once_with( + respondent_nick="user1", + llm_request=valid_request, + answer="First option", + compose_opinion_prompt=custom_prompt_method + ) + + self.test_llm.get_sorted_answer_indexes.side_effect = ValueError("Test error") + response = self.test_llm.ask_discusser(valid_request) + self.assertEqual(response.opinion, "Sorry, but I experienced an issue trying to form an opinion on this topic") + self.test_llm.get_sorted_answer_indexes.side_effect = None + + def test_ask_appraiser(self): + """Test the ask_appraiser method with various scenarios""" + from neon_data_models.models.api import LLMPersona, LLMVoteRequest, LLMVoteResponse + + message = "Test voting" + history = [["user", "hello"], ["llm", "Hello. 
How can I help?"]] + persona = LLMPersona(name="test_persona", description="test persona") + + empty_request = LLMVoteRequest( + query=message, + history=history, + persona=persona, + model=self.test_llm.llm_model_name, + message_id="test_message_id", + routing_key="test_routing_key", + responses=[] + ) + + response = self.test_llm.ask_appraiser(empty_request) + self.assertIsInstance(response, LLMVoteResponse) + self.assertEqual(response.message_id, "test_message_id") + self.assertEqual(response.routing_key, "test_routing_key") + self.assertEqual(response.sorted_answer_indexes, []) + + valid_request = LLMVoteRequest( + query=message, + history=history, + persona=persona, + model=self.test_llm.llm_model_name, + message_id="test_message_id", + routing_key="test_routing_key", + responses=["Response 1", "Response 2", "Response 3"] + ) + + # self.test_llm.get_sorted_answer_indexes.return_value = [2, 0, 1] + response = self.test_llm.ask_appraiser(valid_request) + self.assertIsInstance(response, LLMVoteResponse) + + self.test_llm.get_sorted_answer_indexes.assert_called_once_with( + question=message, + answers=["Response 1", "Response 2", "Response 3"], + persona=persona.model_dump() + ) + self.assertEqual(response.sorted_answer_indexes, + self.test_llm.get_sorted_answer_indexes("", [1,2,3], {})) + + # self.test_llm.get_sorted_answer_indexes = Mock(side_effect=ValueError("Test error")) + # response = self.test_llm.ask_appraiser(valid_request) + # self.assertEqual(response.sorted_answer_indexes, []) + + def test_ask_model_for_opinion(self): + """Test the _ask_model_for_opinion method""" + from neon_data_models.models.api import LLMPersona, LLMDiscussRequest, LLMResponse + + message = "Test opinion" + history = [["user", "hello"], ["llm", "Hello. How can I help?"]] + persona = LLMPersona(name="test_persona", description="test persona") + request = LLMDiscussRequest( + query=message, + history=history, + persona=persona, + model=self.test_llm.llm_model_name, + message_id="test_message_id", + routing_key="test_routing_key", + options={"user1": "Option 1"} + ) + + compose_prompt = Mock(return_value="Composed prompt") + + # self.test_llm.model = Mock() + self.test_llm.model.query_model = Mock(return_value=LLMResponse( + response="Generated opinion", + history=history + [("llm", "Generated opinion")] + )) + + opinion = self.test_llm._ask_model_for_opinion( + llm_request=request, + respondent_nick="user1", + answer="Option 1", + compose_opinion_prompt=compose_prompt + ) + + compose_prompt.assert_called_once_with( + respondent_nick="user1", + question=message, + answer="Option 1" + ) + self.test_llm.model.query_model.assert_called_once() + self.assertEqual(opinion, "Generated opinion") diff --git a/tests/test_rmq.py b/tests/test_rmq.py index a7bb172..c02955f 100644 --- a/tests/test_rmq.py +++ b/tests/test_rmq.py @@ -34,6 +34,8 @@ from neon_mq_connector.utils.network_utils import dict_to_b64 from pytest_rabbitmq.factories.executor import RabbitMqExecutor from neon_minerva.integration.rabbit_mq import rmq_instance +from neon_data_models.models.api.mq import LLMProposeResponse, LLMDiscussResponse +from neon_data_models.models.api.llm import LLMPersona, LLMRequest from neon_llm_core.llm import NeonLLM from neon_llm_core.rmq import NeonLLMMQConnector @@ -48,8 +50,13 @@ def __init__(self, rmq_port: int): "neon_llm_mock_mq": {"user": "test_llm_user", "password": "test_llm_password"}}}} NeonLLMMQConnector.__init__(self, config=config) - self._model = Mock() + self._model = Mock(NeonLLM) + 
self._model.llm_model_name = "mock_llm@test" self._model.ask.return_value = "Mock response" + self._model.ask_discusser.return_value = LLMDiscussResponse( + opinion="Mock opinion") + self._model.query_model.return_value = LLMProposeResponse( + response="Mock response") self._model.get_sorted_answer_indexes.return_value = [0, 1] self.send_message = Mock() self._compose_opinion_prompt = Mock(return_value="Mock opinion prompt") @@ -102,12 +109,15 @@ def test_handle_request(self): # Valid Request request = LLMProposeRequest(message_id="mock_message_id", routing_key="mock_routing_key", + persona=LLMPersona(persona_name="vanilla", + enabled=True), + model=self.mq_llm.model.llm_model_name, query="Mock Query", history=[]) self.mq_llm.handle_request(None, None, None, dict_to_b64(request.model_dump())).join() - self.mq_llm.model.ask.assert_called_with(message=request.query, - chat_history=request.history, - persona=request.persona) + self.mq_llm.model.query_model.assert_called_with( + LLMRequest(**request.model_dump())) + response = self.mq_llm.send_message.call_args.kwargs self.assertEqual(response['queue'], request.routing_key) response = LLMProposeResponse(**response['request_data']) @@ -115,7 +125,9 @@ def test_handle_request(self): self.assertEqual(request.routing_key, response.routing_key) self.assertEqual(request.message_id, response.message_id) - self.assertEqual(response.response, self.mq_llm.model.ask()) + self.assertEqual(response.response, + self.mq_llm.model.query_model(LLMProposeRequest( + **request.model_dump())).response) def test_handle_opinion_request(self): from neon_data_models.models.api.mq import (LLMDiscussRequest, @@ -126,12 +138,17 @@ def test_handle_opinion_request(self): query="Mock Discuss", history=[], options={"bot 1": "resp 1", "bot 2": "resp 2"}) + # Mock the ask_discusser method to return a known response + discuss_response = LLMDiscussResponse(message_id=request.message_id, + routing_key=request.routing_key, + opinion="Mock opinion") + self.mq_llm.model.ask_discusser.return_value = discuss_response + self.mq_llm.handle_opinion_request(None, None, None, dict_to_b64(request.model_dump())).join() - self.mq_llm._compose_opinion_prompt.assert_called_with( - list(request.options.keys())[0], request.query, - list(request.options.values())[0]) + # Verify ask_discusser was called with the right parameters + self.mq_llm.model.ask_discusser.assert_called_once() response = self.mq_llm.send_message.call_args.kwargs self.assertEqual(response['queue'], request.routing_key) @@ -139,25 +156,29 @@ def test_handle_opinion_request(self): self.assertIsInstance(response, LLMDiscussResponse) self.assertEqual(request.routing_key, response.routing_key) self.assertEqual(request.message_id, response.message_id) - - self.assertEqual(response.opinion, self.mq_llm.model.ask()) + self.assertEqual(response.opinion, "Mock opinion") # No input options request = LLMDiscussRequest(message_id="mock_message_id1", routing_key="mock_routing_key1", query="Mock Discuss 1", history=[], options={}) + # Mock a different response for the empty options case + empty_discuss_response = LLMDiscussResponse(message_id=request.message_id, + routing_key=request.routing_key, + opinion="Sorry, but I got no options to choose from.") + self.mq_llm.model.ask_discusser.return_value = empty_discuss_response + self.mq_llm.handle_opinion_request(None, None, None, dict_to_b64(request.model_dump())).join() + response = self.mq_llm.send_message.call_args.kwargs self.assertEqual(response['queue'], request.routing_key) response = 
LLMDiscussResponse(**response['request_data']) self.assertIsInstance(response, LLMDiscussResponse) self.assertEqual(request.routing_key, response.routing_key) self.assertEqual(request.message_id, response.message_id) - self.assertNotEqual(response.opinion, self.mq_llm.model.ask()) - - # TODO: Test with invalid sorted answer indexes + self.assertEqual(response.opinion, "Sorry, but I got no options to choose from.") def test_handle_score_request(self): from neon_data_models.models.api.mq import (LLMVoteRequest, @@ -168,31 +189,23 @@ def test_handle_score_request(self): routing_key="mock_routing_key", query="Mock Score", history=[], responses=["one", "two"]) - self.mq_llm.handle_score_request(None, None, None, - dict_to_b64(request.model_dump())).join() - response = self.mq_llm.send_message.call_args.kwargs - self.assertEqual(response['queue'], request.routing_key) - response = LLMVoteResponse(**response['request_data']) - self.assertIsInstance(response, LLMVoteResponse) - self.assertEqual(request.routing_key, response.routing_key) - self.assertEqual(request.message_id, response.message_id) - - self.assertEqual(response.sorted_answer_indexes, - self.mq_llm.model.get_sorted_answer_indexes()) + # Mock the ask_appraiser method to return a known response + vote_response = LLMVoteResponse(message_id=request.message_id, + routing_key=request.routing_key, + sorted_answer_indexes=[0, 1]) + self.mq_llm.model.ask_appraiser.return_value = vote_response - # No response options - request = LLMVoteRequest(message_id="mock_message_id", - routing_key="mock_routing_key", - query="Mock Score", history=[], responses=[]) self.mq_llm.handle_score_request(None, None, None, dict_to_b64(request.model_dump())).join() + # Verify ask_appraiser was called with the right parameters + self.mq_llm.model.ask_appraiser.assert_called_once() + response = self.mq_llm.send_message.call_args.kwargs self.assertEqual(response['queue'], request.routing_key) response = LLMVoteResponse(**response['request_data']) self.assertIsInstance(response, LLMVoteResponse) self.assertEqual(request.routing_key, response.routing_key) self.assertEqual(request.message_id, response.message_id) - - self.assertEqual(response.sorted_answer_indexes, []) + self.assertEqual(response.sorted_answer_indexes, [0, 1])
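
For reference, a minimal sketch of how a concrete subclass would be exercised through the new `query_model` flow introduced in this diff. The `EchoLLM` class and its trivial implementations are hypothetical illustrations only; the `LLMRequest`/`LLMPersona` field names follow the usage shown in the tests above, and this is a sketch under those assumptions rather than part of the change set.

# Illustrative sketch only (not part of the patch above). `EchoLLM` and its
# trivial method bodies are hypothetical; the request/response handling
# mirrors NeonLLM.query_model and the test fixtures shown in tests/test_llm.py.
from typing import List, Optional

from neon_data_models.models.api import LLMRequest, LLMPersona
from neon_llm_core.llm import NeonLLM


class EchoLLM(NeonLLM):
    """Toy subclass that echoes the assembled prompt back as its response."""
    mq_to_llm_role = {"user": "user", "llm": "assistant"}

    @property
    def tokenizer(self) -> Optional[object]:
        return None

    @property
    def tokenizer_model_name(self) -> str:
        return "echo-tokenizer"

    @property
    def model(self) -> object:
        return self  # no real client behind this toy example

    @property
    def llm_model_name(self) -> str:
        return "echo-model"

    @property
    def _system_prompt(self) -> str:
        return "You are a helpful assistant."

    def get_sorted_answer_indexes(self, question: str, answers: List[str],
                                  persona: dict) -> List[int]:
        # Keep the incoming order; a real model would rank best-to-worst
        return list(range(len(answers)))

    def _call_model(self, prompt: str,
                    request: Optional[LLMRequest] = None) -> str:
        return f"echo: {prompt}"

    def _assemble_prompt(self, message: str, chat_history, persona: dict) -> str:
        return message

    def _tokenize(self, prompt: str) -> List[str]:
        return prompt.split()


if __name__ == "__main__":
    llm = EchoLLM(config={})
    request = LLMRequest(query="Hello?",
                         history=[("user", "hi"),
                                  ("llm", "Hello. How can I help?")],
                         persona=LLMPersona(name="vanilla",
                                            description="a default persona"),
                         model=llm.llm_model_name)
    # query_model validates the requested model name, assembles the prompt,
    # calls the model, and appends the reply to the returned history
    response = llm.query_model(request)
    print(response.response)     # "echo: Hello?"
    print(response.history[-1])  # ("llm", "echo: Hello?")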