diff --git a/rag_demo/free_use_manager.py b/rag_demo/free_use_manager.py
new file mode 100644
index 0000000..b574d33
--- /dev/null
+++ b/rag_demo/free_use_manager.py
@@ -0,0 +1,33 @@
+# Manages the number of free questions permitted before requiring the user to supply their own OpenAI API key
+
+import streamlit as st
+
+def _setup_free_questions_count():
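+    """Initialize the free-question counter in session state if it is not already set."""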
+ if "FREE_QUESTIONS_REMAINING" not in st.session_state:
+ try:
+ st.session_state["FREE_QUESTIONS_REMAINING"] = st.secrets["FREE_QUESTIONS_PER_SESSION"]
+ except:
+ st.session_state["FREE_QUESTIONS_REMAINING"] = 3
+
+def free_questions_exhausted() -> bool:
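+    """Return True once the session has no free questions left."""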
+    _setup_free_questions_count()
+    remaining = st.session_state["FREE_QUESTIONS_REMAINING"]
+    return remaining <= 0
+
+def user_supplied_openai_key_unavailable() -> bool:
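+    """Return True when the user has not supplied a non-empty OpenAI key."""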
+ if "USER_OPENAI_KEY" not in st.session_state:
+ return True
+ uok = st.session_state["USER_OPENAI_KEY"]
+ if uok is None or uok == "":
+ return True
+ return False
+
+def decrement_free_questions():
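+    """Consume one free question, never going below zero."""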
+    _setup_free_questions_count()
+    remaining = st.session_state["FREE_QUESTIONS_REMAINING"]
+    if remaining > 0:
+        st.session_state["FREE_QUESTIONS_REMAINING"] = remaining - 1
diff --git a/rag_demo/graph_cypher_chain.py b/rag_demo/graph_cypher_chain.py
index 220feda..55af946 100644
--- a/rag_demo/graph_cypher_chain.py
+++ b/rag_demo/graph_cypher_chain.py
@@ -56,8 +56,11 @@
url = st.secrets["NEO4J_URI"]
username = st.secrets["NEO4J_USERNAME"]
password = st.secrets["NEO4J_PASSWORD"]
-openai_key = st.secrets["OPENAI_API_KEY"]
-llm_key = st.secrets["OPENAI_API_KEY"]
+
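+# Prefer the key the user supplied in the sidebar; otherwise fall back to the app's secret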
+if "USER_OPENAI_API_KEY" in st.session_state:
+ openai_key = st.session_state["USER_OPENAI_API_KEY"]
+else:
+ openai_key = st.secrets["OPENAI_API_KEY"]
graph = Neo4jGraph(
url=url,
diff --git a/rag_demo/main.py b/rag_demo/main.py
index 2355d27..53364b5 100644
--- a/rag_demo/main.py
+++ b/rag_demo/main.py
@@ -1,4 +1,5 @@
from analytics import track
+from free_use_manager import free_questions_exhausted, user_supplied_openai_key_unavailable, decrement_free_questions
from langchain.globals import set_llm_cache
from langchain.cache import InMemoryCache
from langchain_community.callbacks import HumanApprovalCallbackHandler
@@ -12,6 +13,7 @@
# Anonymous Session Analytics
if "SESSION_ID" not in st.session_state:
+    # track() creates a session id and adds it to session state on the first run
track(
"rag_demo",
"appStarted",
@@ -44,6 +46,10 @@
st.markdown(message["content"], unsafe_allow_html=True)
# User input - switch between sidebar sample quick select or actual user input. Clunky but works.
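+# Gate the chat once the free questions are used up and no user-supplied key is available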
+if free_questions_exhausted() and user_supplied_openai_key_unavailable():
+    st.warning("Thank you for trying out the Neo4j RAG Demo. Please enter your OpenAI API key in the sidebar to continue asking questions.")
+    st.stop()
+
if "sample" in st.session_state and st.session_state["sample"] is not None:
user_input = st.session_state["sample"]
else:
@@ -86,6 +92,8 @@
new_message = {"role": "ai", "content": content}
st.session_state.messages.append(new_message)
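+ # Count this response against the session's free-question allotment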
+ decrement_free_questions()
+
message_placeholder.markdown(content)
# Reinsert user chat input if sample quick select was previously used.
diff --git a/rag_demo/sidebar.py b/rag_demo/sidebar.py
index f769607..4607da3 100644
--- a/rag_demo/sidebar.py
+++ b/rag_demo/sidebar.py
@@ -13,6 +13,14 @@ def ChangeButtonColour(wgt_txt, wch_hex_colour = '12px'):
def sidebar():
with st.sidebar:
+
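+        # Let the user supply their own OpenAI key once the free questions run out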
+        with st.expander("OpenAI Key"):
+            new_oak = st.text_input("Your OpenAI API Key", type="password")
+            st.session_state["USER_OPENAI_KEY"] = new_oak
+
st.markdown(f"""This the schema in which the EDGAR filings are stored in Neo4j: \n
""", unsafe_allow_html=True)
st.markdown(f"""This is how the Chatbot flow goes: \n
""", unsafe_allow_html=True)
diff --git a/rag_demo/vector_chain.py b/rag_demo/vector_chain.py
index 08dd0e7..b715002 100644
--- a/rag_demo/vector_chain.py
+++ b/rag_demo/vector_chain.py
@@ -32,7 +32,12 @@
input_variables=["input","context"], template=VECTOR_PROMPT_TEMPLATE
)
-EMBEDDING_MODEL = OpenAIEmbeddings()
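+# Prefer the user-supplied key from the sidebar; otherwise fall back to the app's secret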
+if "USER_OPENAI_API_KEY" in st.session_state:
+ openai_key = st.session_state["USER_OPENAI_API_KEY"]
+else:
+ openai_key = st.secrets["OPENAI_API_KEY"]
+
+EMBEDDING_MODEL = OpenAIEmbeddings(openai_api_key=openai_key)
MEMORY = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer', return_messages=True)
index_name = "form_10k_chunks"
@@ -78,7 +83,7 @@
vector_retriever = vector_store.as_retriever()
vector_chain = RetrievalQAWithSourcesChain.from_chain_type(
- ChatOpenAI(temperature=0),
+ ChatOpenAI(temperature=0, openai_api_key=openai_key),
chain_type="stuff",
retriever=vector_retriever,
memory=MEMORY,
@@ -119,7 +124,7 @@ def get_results(question)-> str:
return result
-# Using the vector store directly. But this will blow out the token count
+# Using the vector store directly. But this could blow out the token count
# @retry(tries=5, delay=5)
# def get_results(question)-> str:
# """Generate response using Neo4jVector using vector index only
diff --git a/rag_demo/vector_graph_chain.py b/rag_demo/vector_graph_chain.py
index 714ac25..0bba76d 100644
--- a/rag_demo/vector_graph_chain.py
+++ b/rag_demo/vector_graph_chain.py
@@ -1,4 +1,3 @@
-from json import loads, dumps
from langchain.prompts.prompt import PromptTemplate
from langchain.vectorstores.neo4j_vector import Neo4jVector
from langchain.chains import RetrievalQAWithSourcesChain
@@ -22,7 +21,12 @@
input_variables=["question"], template=VECTOR_GRAPH_PROMPT_TEMPLATE
)
-EMBEDDING_MODEL = OpenAIEmbeddings()
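+# Prefer the user-supplied key from the sidebar; otherwise fall back to the app's secret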
+if "USER_OPENAI_API_KEY" in st.session_state:
+ openai_key = st.session_state["USER_OPENAI_API_KEY"]
+else:
+ openai_key = st.secrets["OPENAI_API_KEY"]
+
+EMBEDDING_MODEL = OpenAIEmbeddings(openai_api_key=openai_key)
MEMORY = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer', return_messages=True)
index_name = "form_10k_chunks"
@@ -93,7 +97,7 @@
vector_graph_retriever = vector_store.as_retriever()
vector_graph_chain = RetrievalQAWithSourcesChain.from_chain_type(
- ChatOpenAI(temperature=0),
+ ChatOpenAI(temperature=0, openai_api_key=openai_key),
chain_type="stuff",
retriever=vector_graph_retriever,
memory=MEMORY,