diff --git a/rag_demo/free_use_manager.py b/rag_demo/free_use_manager.py new file mode 100644 index 0000000..b574d33 --- /dev/null +++ b/rag_demo/free_use_manager.py @@ -0,0 +1,33 @@ +# Manages number of free questions permitted before forcing user to supply their own OpenAI Key + +import streamlit as st + +def _setup_free_questions_count(): + if "FREE_QUESTIONS_REMAINING" not in st.session_state: + try: + st.session_state["FREE_QUESTIONS_REMAINING"] = st.secrets["FREE_QUESTIONS_PER_SESSION"] + except: + st.session_state["FREE_QUESTIONS_REMAINING"] = 3 + +def free_questions_exhausted()-> bool: + + _setup_free_questions_count() + + remaining = st.session_state["FREE_QUESTIONS_REMAINING"] + return remaining <= 0 + +def user_supplied_openai_key_unavailable()-> bool: + if "USER_OPENAI_KEY" not in st.session_state: + return True + uok = st.session_state["USER_OPENAI_KEY"] + if uok is None or uok == "": + return True + return False + +def decrement_free_questions(): + + _setup_free_questions_count() + + remaining = st.session_state["FREE_QUESTIONS_REMAINING"] + if remaining > 0: + st.session_state["FREE_QUESTIONS_REMAINING"] = remaining - 1 diff --git a/rag_demo/graph_cypher_chain.py b/rag_demo/graph_cypher_chain.py index 220feda..55af946 100644 --- a/rag_demo/graph_cypher_chain.py +++ b/rag_demo/graph_cypher_chain.py @@ -56,8 +56,11 @@ url = st.secrets["NEO4J_URI"] username = st.secrets["NEO4J_USERNAME"] password = st.secrets["NEO4J_PASSWORD"] -openai_key = st.secrets["OPENAI_API_KEY"] -llm_key = st.secrets["OPENAI_API_KEY"] + +if "USER_OPENAI_API_KEY" in st.session_state: + openai_key = st.session_state["USER_OPENAI_API_KEY"] +else: + openai_key = st.secrets["OPENAI_API_KEY"] graph = Neo4jGraph( url=url, diff --git a/rag_demo/main.py b/rag_demo/main.py index 2355d27..53364b5 100644 --- a/rag_demo/main.py +++ b/rag_demo/main.py @@ -1,4 +1,5 @@ from analytics import track +from free_use_manager import free_questions_exhausted, user_supplied_openai_key_unavailable, decrement_free_questions from langchain.globals import set_llm_cache from langchain.cache import InMemoryCache from langchain_community.callbacks import HumanApprovalCallbackHandler @@ -12,6 +13,7 @@ # Anonymous Session Analytics if "SESSION_ID" not in st.session_state: + # Track method will create and add session id to state on first run track( "rag_demo", "appStarted", @@ -44,6 +46,10 @@ st.markdown(message["content"], unsafe_allow_html=True) # User input - switch between sidebar sample quick select or actual user input. Clunky but works. +if free_questions_exhausted() and user_supplied_openai_key_unavailable(): + st.warning("Thank you for trying out the Neo4j Rag Demo. Please input your OpenAI Key in the sidebar to continue asking questions.") + st.stop() + if "sample" in st.session_state and st.session_state["sample"] is not None: user_input = st.session_state["sample"] else: @@ -86,6 +92,8 @@ new_message = {"role": "ai", "content": content} st.session_state.messages.append(new_message) + decrement_free_questions() + message_placeholder.markdown(content) # Reinsert user chat input if sample quick select was previously used. diff --git a/rag_demo/sidebar.py b/rag_demo/sidebar.py index f769607..4607da3 100644 --- a/rag_demo/sidebar.py +++ b/rag_demo/sidebar.py @@ -13,6 +13,14 @@ def ChangeButtonColour(wgt_txt, wch_hex_colour = '12px'): def sidebar(): with st.sidebar: + + with st.expander("OpenAI Key"): + new_oak = st.text_input("Your OpenAI API Key") + # if "USER_OPENAI_KEY" not in st.session_state: + # st.session_state["USER_OPENAI_KEY"] = new_oak + # else: + st.session_state["USER_OPENAI_KEY"] = new_oak + st.markdown(f"""This the schema in which the EDGAR filings are stored in Neo4j: \n """, unsafe_allow_html=True) st.markdown(f"""This is how the Chatbot flow goes: \n """, unsafe_allow_html=True) diff --git a/rag_demo/vector_chain.py b/rag_demo/vector_chain.py index 08dd0e7..b715002 100644 --- a/rag_demo/vector_chain.py +++ b/rag_demo/vector_chain.py @@ -32,7 +32,12 @@ input_variables=["input","context"], template=VECTOR_PROMPT_TEMPLATE ) -EMBEDDING_MODEL = OpenAIEmbeddings() +if "USER_OPENAI_API_KEY" in st.session_state: + openai_key = st.session_state["USER_OPENAI_API_KEY"] +else: + openai_key = st.secrets["OPENAI_API_KEY"] + +EMBEDDING_MODEL = OpenAIEmbeddings(openai_api_key=openai_key) MEMORY = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer', return_messages=True) index_name = "form_10k_chunks" @@ -78,7 +83,7 @@ vector_retriever = vector_store.as_retriever() vector_chain = RetrievalQAWithSourcesChain.from_chain_type( - ChatOpenAI(temperature=0), + ChatOpenAI(temperature=0, openai_api_key=openai_key), chain_type="stuff", retriever=vector_retriever, memory=MEMORY, @@ -119,7 +124,7 @@ def get_results(question)-> str: return result -# Using the vector store directly. But this will blow out the token count +# Using the vector store directly. But this could blow out the token count # @retry(tries=5, delay=5) # def get_results(question)-> str: # """Generate response using Neo4jVector using vector index only diff --git a/rag_demo/vector_graph_chain.py b/rag_demo/vector_graph_chain.py index 714ac25..0bba76d 100644 --- a/rag_demo/vector_graph_chain.py +++ b/rag_demo/vector_graph_chain.py @@ -1,4 +1,3 @@ -from json import loads, dumps from langchain.prompts.prompt import PromptTemplate from langchain.vectorstores.neo4j_vector import Neo4jVector from langchain.chains import RetrievalQAWithSourcesChain @@ -22,7 +21,12 @@ input_variables=["question"], template=VECTOR_GRAPH_PROMPT_TEMPLATE ) -EMBEDDING_MODEL = OpenAIEmbeddings() +if "USER_OPENAI_API_KEY" in st.session_state: + openai_key = st.session_state["USER_OPENAI_API_KEY"] +else: + openai_key = st.secrets["OPENAI_API_KEY"] + +EMBEDDING_MODEL = OpenAIEmbeddings(openai_api_key=openai_key) MEMORY = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer', return_messages=True) index_name = "form_10k_chunks" @@ -93,7 +97,7 @@ vector_graph_retriever = vector_store.as_retriever() vector_graph_chain = RetrievalQAWithSourcesChain.from_chain_type( - ChatOpenAI(temperature=0), + ChatOpenAI(temperature=0, openai_api_key=openai_key), chain_type="stuff", retriever=vector_graph_retriever, memory=MEMORY,