|
2 | 2 | Copyright (c) 2024, 2025, Oracle and/or its affiliates.
|
3 | 3 | Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
|
4 | 4 | """
|
5 |
| - |
6 | 5 | from typing import List
|
7 | 6 | from mcp.server.fastmcp import FastMCP
|
8 | 7 | import os
|
9 | 8 | from dotenv import load_dotenv
|
10 |
| - |
11 |
| -# from sentence_transformers import CrossEncoder |
12 |
| -# from langchain_community.embeddings import HuggingFaceEmbeddings |
13 | 9 | from langchain_core.prompts import PromptTemplate
|
14 | 10 | from langchain_core.runnables import RunnablePassthrough
|
15 | 11 | from langchain_core.output_parsers import StrOutputParser
|
16 | 12 | import json
|
17 | 13 | import logging
|
18 |
| - |
19 | 14 | logging.basicConfig(level=logging.DEBUG)
|
20 | 15 |
|
21 | 16 | from optimizer_utils import config
|
22 | 17 |
|
# Path to the optimizer settings JSON file consumed by rag_tool_base().
# Empty until set_optimizer_settings_path() is called by the host application.
_optimizer_settings_path = ""
|
def set_optimizer_settings_path(path: str):
    """Record where the optimizer settings JSON file lives.

    Args:
        path: Filesystem path later opened by rag_tool_base().
    """
    global _optimizer_settings_path
    _optimizer_settings_path = path
|
29 | 23 |
|
30 |
| - |
def rag_tool_base(question: str) -> str:
    """
    Use this tool to answer any question that may benefit from up-to-date or domain-specific information.

    Args:
        question: the question for which are you looking for an answer

    Returns:
        JSON string with answer
    """
    # Load the optimizer settings (embedding/LLM config, prompts, vector-search params).
    # Raises if the path was never set or the file is missing — deliberately outside
    # the try below so configuration errors are not masked as "Connection failed!".
    with open(_optimizer_settings_path, "r", encoding="utf-8") as file:
        data = json.load(file)
    logging.info("Json loaded!")

    # Default fallback: an empty answer is returned on any pipeline failure.
    answer = ""
    try:
        embeddings = config.get_embeddings(data)

        logging.info("Embedding successful!")
        knowledge_base = config.get_vectorstore(data, embeddings)
        logging.info("DB Connection successful!")

        logging.info("knowledge_base successful!")

        logging.info("start looking for prompts")
        # Resolve the configured system prompt by name. Initialized to "" so an
        # unmatched name degrades to a prompt-less template instead of raising
        # NameError (which the original code silently swallowed below).
        rag_prompt = ""
        for entry in data["prompts_config"]:
            if entry["name"] == data["user_settings"]["prompts"]["sys"]:
                rag_prompt = entry["prompt"]
                break

        logging.info("rag_prompt:")
        logging.info(rag_prompt)
        template = """DOCUMENTS: {context} \n""" + rag_prompt + """\nQuestion: {question} """
        logging.info(template)
        prompt = PromptTemplate.from_template(template)

        logging.info("before retriever")
        logging.info(data["user_settings"]["vector_search"]["top_k"])
        retriever = knowledge_base.as_retriever(
            search_kwargs={"k": data["user_settings"]["vector_search"]["top_k"]}
        )
        logging.info("after retriever")

        # Initialize the LLM
        llm = config.get_llm(data)

        # retriever fills {context}; the raw question passes through to {question}.
        chain = (
            {"context": retriever, "question": RunnablePassthrough()}
            | prompt
            | llm
            | StrOutputParser()
        )
        logging.info("pre-chain successful!")
        answer = chain.invoke(question)

    except Exception:
        # Best-effort tool: log the full traceback (logging.info(e) lost it)
        # and fall through to the empty answer rather than crash the server.
        logging.exception("Connection failed!")

    return f"{answer}"
|
0 commit comments