Skip to content

Commit a9dcbab

Browse files
integrated local db connection
1 parent 08d12b5 commit a9dcbab

File tree

4,393 files changed

+3022658
-1399
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

4,393 files changed

+3022658
-1399
lines changed

.DS_Store

2 KB
Binary file not shown.

__pycache__/agent.cpython-312.pyc

-790 Bytes
Binary file not shown.
12.8 KB
Binary file not shown.
13.3 KB
Binary file not shown.

__pycache__/tools.cpython-312.pyc

-197 Bytes
Binary file not shown.

agent.py

Lines changed: 30 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,112 +1,72 @@
1-
import streamlit as st
1+
# agent.py
22
from dataclasses import dataclass
33
from typing import Annotated, Sequence, Optional
4-
5-
from langchain.callbacks.base import BaseCallbackHandler
4+
from langchain_core.messages import BaseMessage
5+
from langgraph.graph.message import add_messages
66
from langchain_core.messages import SystemMessage
7-
from langchain_openai import ChatOpenAI
7+
from langchain_google_genai import ChatGoogleGenerativeAI
88
from langgraph.checkpoint.memory import MemorySaver
99
from langgraph.graph import START, END, StateGraph
1010
from langgraph.prebuilt import ToolNode, tools_condition
11-
from langgraph.graph.message import add_messages
12-
from langchain_core.messages import BaseMessage
1311

14-
from tools import retriever_tool
15-
from tools import search, sql_executor_tool
16-
from PIL import Image
17-
from io import BytesIO
12+
# Exported items
13+
__all__ = ["MessagesState", "create_agent"]
1814

1915
@dataclass
2016
class MessagesState:
2117
messages: Annotated[Sequence[BaseMessage], add_messages]
2218

23-
2419
memory = MemorySaver()
2520

26-
21+
# Model configuration for Google Gemini only
2722
@dataclass
2823
class ModelConfig:
2924
model_name: str
3025
api_key: str
3126
base_url: Optional[str] = None
3227

33-
3428
model_configurations = {
35-
# "o3-mini": ModelConfig(
36-
# model_name="o3-mini", api_key=st.secrets["OPENAI_API_KEY"]
37-
# ),
38-
39-
40-
"Deepseek R1": ModelConfig(
41-
model_name="deepseek-r1-distill-llama-70b",
42-
api_key=st.secrets["GROK_API_KEY"],
43-
base_url=f"https://api.groq.com/openai/v1",
44-
),
45-
46-
# "Mistral 7B": ModelConfig(
47-
# model_name="mistralai/mistral-7b-v0.1", api_key=st.secrets["REPLICATE_API_TOKEN"]
48-
# ),
49-
# "Qwen 2.5": ModelConfig(
50-
# model_name="accounts/fireworks/models/qwen2p5-coder-32b-instruct",
51-
# api_key=st.secrets["FIREWORKS_API_KEY"],
52-
# base_url="https://api.fireworks.ai/inference/v1",
53-
# ),
54-
# "Gemini Exp 1206": ModelConfig(
55-
# model_name="gemini-exp-1206",
56-
# api_key=st.secrets["GEMINI_API_KEY"],
57-
# base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
58-
# ),
29+
"Google Gemini": ModelConfig(
30+
model_name="models/gemini-2.0-flash",
31+
api_key=__import__("streamlit").secrets["GEMINI_API_KEY"],
32+
base_url=None,
33+
)
5934
}
35+
6036
sys_msg = SystemMessage(
61-
content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate. Do not ask the user for schema or database details. You have access to the following tools:
62-
ALWAYS USE THE DATABASE_SCHEMA TOOL TO GET THE SCHEMA OF THE DATABASE BEFORE GENERATING SQL CODE.
63-
ALWAYS USE THE DATABASE_SCHEMA TOOL TO GET THE SCHEMA OF THE DATABASE BEFORE GENERATING SQL CODE.
64-
- Database_Schema: This tool allows you to search for database schema details when needed to generate the SQL code.
65-
- Internet_Search: This tool allows you to search the internet for snowflake sql related information when needed to generate the SQL code.
37+
content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to be friendly and conversational (like a tutor or friend). You have access to the following tools:
38+
- Database_Schema: Search for database schema details before generating SQL code.
39+
- Internet_Search: Look up Snowflake SQL–related information when needed.
6640
"""
6741
)
68-
tools = [retriever_tool, search]
6942

70-
def create_agent(callback_handler: BaseCallbackHandler, model_name: str) -> StateGraph:
71-
config = model_configurations.get(model_name)
72-
if not config:
73-
raise ValueError(f"Unsupported model name: {model_name}")
43+
# Tools are imported from tools.py (see that file)
44+
from tools import retriever_tool, search
45+
tools = [retriever_tool, search]
7446

47+
def create_agent(callback_handler) -> StateGraph:
48+
config = model_configurations["Google Gemini"]
7549
if not config.api_key:
76-
raise ValueError(f"API key for model '{model_name}' is not set. Please check your environment variables or secrets configuration.")
77-
78-
llm = ChatOpenAI(
79-
model=config.model_name,
80-
api_key=config.api_key,
81-
callbacks=[callback_handler],
82-
streaming=True,
83-
base_url=config.base_url,
84-
# temperature=0.1,
85-
default_headers={"HTTP-Referer": "", "X-Title": "Snowchat"},
50+
raise ValueError("API key for Google Gemini is not set. Please check your secrets configuration.")
51+
llm = ChatGoogleGenerativeAI(
52+
model=config.model_name,
53+
google_api_key=config.api_key,
54+
callbacks=[callback_handler],
55+
temperature=0,
56+
base_url=config.base_url,
57+
streaming=True,
8658
)
87-
8859
llm_with_tools = llm.bind_tools(tools)
8960

9061
def llm_agent(state: MessagesState):
91-
return {"messages": [llm_with_tools.invoke([sys_msg] + state.messages)]}
62+
return {"messages": [llm_with_tools.invoke([sys_msg] + state.messages)]}
9263

9364
builder = StateGraph(MessagesState)
9465
builder.add_node("llm_agent", llm_agent)
9566
builder.add_node("tools", ToolNode(tools))
96-
9767
builder.add_edge(START, "llm_agent")
9868
builder.add_conditional_edges("llm_agent", tools_condition)
9969
builder.add_edge("tools", "llm_agent")
10070
builder.add_edge("llm_agent", END)
10171
react_graph = builder.compile(checkpointer=memory)
102-
103-
# png_data = react_graph.get_graph(xray=True).draw_mermaid_png()
104-
# with open("graph_2.png", "wb") as f:
105-
# f.write(png_data)
106-
107-
# image = Image.open(BytesIO(png_data))
108-
# st.image(image, caption="React Graph")
109-
11072
return react_graph
111-
112-

chain.py

Lines changed: 23 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,31 @@
1+
# chain.py
12
from dataclasses import dataclass, field
23
from operator import itemgetter
34
from typing import Any, Callable, Dict, Optional
4-
55
import streamlit as st
6-
from langchain.embeddings.openai import OpenAIEmbeddings
7-
from langchain.llms import OpenAI
6+
from langchain_community.embeddings import FakeEmbeddings
87
from langchain.prompts.prompt import PromptTemplate
98
from langchain.schema import format_document
109
from langchain.vectorstores import SupabaseVectorStore
11-
from langchain_anthropic import ChatAnthropic
12-
from langchain_community.chat_models import ChatOpenAI
1310
from langchain_core.messages import get_buffer_string
1411
from langchain_core.output_parsers import StrOutputParser
1512
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
16-
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
17-
13+
from langchain_google_genai import ChatGoogleGenerativeAI
1814
from supabase.client import Client, create_client
1915
from template import CONDENSE_QUESTION_PROMPT, QA_PROMPT
2016

2117
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
2218

2319
supabase_url = st.secrets["SUPABASE_URL"]
2420
supabase_key = st.secrets["SUPABASE_SERVICE_KEY"]
25-
supabase: Client = create_client(supabase_url, supabase_key)
26-
21+
client: Client = create_client(supabase_url, supabase_key)
2722

2823
@dataclass
2924
class ModelConfig:
3025
model_type: str
3126
secrets: Dict[str, Any]
3227
callback_handler: Optional[Callable] = field(default=None)
3328

34-
3529
class ModelWrapper:
3630
def __init__(self, config: ModelConfig):
3731
self.model_type = config.model_type
@@ -40,68 +34,17 @@ def __init__(self, config: ModelConfig):
4034
self.llm = self._setup_llm()
4135

4236
def _setup_llm(self):
43-
model_config = {
44-
"gpt-4o-mini": {
45-
"model_name": "gpt-4o-mini",
46-
"api_key": self.secrets["OPENAI_API_KEY"],
47-
},
48-
"gemma2-9b": {
49-
"model_name": "gemma2-9b-it",
50-
"api_key": self.secrets["GROQ_API_KEY"],
51-
"base_url": "https://api.groq.com/openai/v1",
52-
},
53-
"claude3-haiku": {
54-
"model_name": "claude-3-haiku-20240307",
55-
"api_key": self.secrets["ANTHROPIC_API_KEY"],
56-
},
57-
"mixtral-8x22b": {
58-
"model_name": "accounts/fireworks/models/mixtral-8x22b-instruct",
59-
"api_key": self.secrets["FIREWORKS_API_KEY"],
60-
"base_url": "https://api.fireworks.ai/inference/v1",
61-
},
62-
"llama-3.1-405b": {
63-
"model_name": "accounts/fireworks/models/llama-v3p1-405b-instruct",
64-
"api_key": self.secrets["FIREWORKS_API_KEY"],
65-
"base_url": "https://api.fireworks.ai/inference/v1",
66-
},
67-
}
68-
69-
config = model_config[self.model_type]
70-
71-
return (
72-
ChatOpenAI(
73-
model_name=config["model_name"],
74-
temperature=0.1,
75-
api_key=config["api_key"],
76-
max_tokens=700,
77-
callbacks=[self.callback_handler],
78-
streaming=True,
79-
base_url=config["base_url"]
80-
if config["model_name"] != "gpt-4o-mini"
81-
else None,
82-
default_headers={
83-
"HTTP-Referer": "https://snowchat.streamlit.app/",
84-
"X-Title": "Snowchat",
85-
},
86-
)
87-
if config["model_name"] != "claude-3-haiku-20240307"
88-
else (
89-
ChatAnthropic(
90-
model=config["model_name"],
91-
temperature=0.1,
92-
max_tokens=700,
93-
timeout=None,
94-
max_retries=2,
95-
callbacks=[self.callback_handler],
96-
streaming=True,
97-
)
98-
)
37+
return ChatGoogleGenerativeAI(
38+
model="models/gemini-2.0-flash",
39+
google_api_key=self.secrets["GEMINI_API_KEY"],
40+
temperature=0.1,
41+
callbacks=[self.callback_handler],
42+
max_tokens=700,
43+
streaming=True,
9944
)
10045

10146
def get_chain(self, vectorstore):
102-
def _combine_documents(
103-
docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
104-
):
47+
def _combine_documents(docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
10548
doc_strings = [format_document(doc, document_prompt) for doc in docs]
10649
return document_separator.join(doc_strings)
10750

@@ -110,8 +53,7 @@ def _combine_documents(
11053
chat_history=lambda x: get_buffer_string(x["chat_history"])
11154
)
11255
| CONDENSE_QUESTION_PROMPT
113-
| OpenAI()
114-
| StrOutputParser(),
56+
| StrOutputParser()
11557
)
11658
_context = {
11759
"context": itemgetter("standalone_question")
@@ -120,33 +62,23 @@ def _combine_documents(
12062
"question": lambda x: x["standalone_question"],
12163
}
12264
conversational_qa_chain = _inputs | _context | QA_PROMPT | self.llm
123-
12465
return conversational_qa_chain
12566

126-
127-
def load_chain(model_name="qwen", callback_handler=None):
128-
embeddings = OpenAIEmbeddings(
129-
openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002"
130-
)
67+
def load_chain(model_name="google_gemini", callback_handler=None):
68+
embeddings = FakeEmbeddings(size=768)
13169
vectorstore = SupabaseVectorStore(
13270
embedding=embeddings,
133-
client=supabase,
71+
client=client,
13472
table_name="documents",
13573
query_name="v_match_documents",
13674
)
137-
138-
model_type_mapping = {
139-
"gpt-4o-mini": "gpt-4o-mini",
140-
"gemma2-9b": "gemma2-9b",
141-
"claude3-haiku": "claude3-haiku",
142-
"mixtral-8x22b": "mixtral-8x22b",
143-
"llama-3.1-405b": "llama-3.1-405b",
144-
}
145-
146-
model_type = model_type_mapping.get(model_name.lower())
147-
if model_type is None:
148-
raise ValueError(f"Unsupported model name: {model_name}")
149-
75+
# Override the retriever with a dummy retriever to disable document retrieval.
76+
class DummyRetriever:
77+
def get_relevant_documents(self, query):
78+
return []
79+
vectorstore.as_retriever = lambda: DummyRetriever()
80+
81+
model_type = "google_gemini"
15082
config = ModelConfig(
15183
model_type=model_type, secrets=st.secrets, callback_handler=callback_handler
15284
)

data/.DS_Store

6 KB
Binary file not shown.

ingest.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# ingest.py
12
from typing import Any, Dict
23

34
import streamlit as st
@@ -9,40 +10,30 @@
910

1011
from supabase.client import Client, create_client
1112

12-
1313
class Secrets(BaseModel):
1414
SUPABASE_URL: str
1515
SUPABASE_SERVICE_KEY: str
1616
OPENAI_API_KEY: str
1717

18-
1918
class Config(BaseModel):
2019
chunk_size: int = 1000
2120
chunk_overlap: int = 0
2221
docs_dir: str = "docs/"
2322
docs_glob: str = "**/*.md"
2423

25-
2624
class DocumentProcessor:
2725
def __init__(self, secrets: Secrets, config: Config):
28-
self.client: Client = create_client(
29-
secrets.SUPABASE_URL, secrets.SUPABASE_SERVICE_KEY
30-
)
26+
self.client: Client = create_client(secrets.SUPABASE_URL, secrets.SUPABASE_SERVICE_KEY)
3127
self.loader = DirectoryLoader(config.docs_dir, glob=config.docs_glob)
32-
self.text_splitter = CharacterTextSplitter(
33-
chunk_size=config.chunk_size, chunk_overlap=config.chunk_overlap
34-
)
28+
self.text_splitter = CharacterTextSplitter(chunk_size=config.chunk_size, chunk_overlap=config.chunk_overlap)
3529
self.embeddings = OpenAIEmbeddings(openai_api_key=secrets.OPENAI_API_KEY)
3630

3731
def process(self) -> Dict[str, Any]:
3832
data = self.loader.load()
3933
texts = self.text_splitter.split_documents(data)
40-
vector_store = SupabaseVectorStore.from_documents(
41-
texts, self.embeddings, client=self.client
42-
)
34+
vector_store = SupabaseVectorStore.from_documents(texts, self.embeddings, client=self.client)
4335
return vector_store
4436

45-
4637
def run():
4738
secrets = Secrets(
4839
SUPABASE_URL=st.secrets["SUPABASE_URL"],
@@ -54,6 +45,5 @@ def run():
5445
result = doc_processor.process()
5546
return result
5647

57-
5848
if __name__ == "__main__":
5949
run()

0 commit comments

Comments
 (0)