diff --git a/backend/app.py b/backend/app.py index 0639fc5..6ec5911 100644 --- a/backend/app.py +++ b/backend/app.py @@ -1,12 +1,15 @@ import streamlit as st -from src.file_search import ChatAgent +from src.file_search import FileAgent +from src.webpage_search import WebpageSearch from pathlib import Path -def init_chat_agent(): - if 'chat_agent' not in st.session_state: - st.session_state.chat_agent = ChatAgent() - -init_chat_agent() +if 'file_agent' not in st.session_state: + st.session_state.file_agent = FileAgent() + st.session_state.webpage_search = WebpageSearch() + st.session_state.file_agent.reset() + st.session_state.webpage_search.reset() + + # st.session_state.webpage_search.feed([""]) st.set_page_config( page_title="Chat Application", @@ -16,12 +19,14 @@ def init_chat_agent(): st.title("Chat Application :speech_balloon:") +col1, col2 = st.columns([1, 3]) + # File Upload Section -st.sidebar.header("Upload Files") -st.sidebar.write("Upload your files here to make them searchable by the chat agent.") +with col1: + st.header("Upload Files") + st.write("Upload your files here to make them searchable by the chat agent.") -def upload_files(): - uploaded_files = st.sidebar.file_uploader("Choose files", accept_multiple_files=True, type=["txt", "pdf", "docx"]) + uploaded_files = st.file_uploader("Choose files", accept_multiple_files=True, type=["txt", "pdf", "docx"]) if uploaded_files: Path("./data/files").mkdir(parents=True, exist_ok=True) for file in uploaded_files: @@ -29,26 +34,21 @@ def upload_files(): file_path = Path(f"./data/files/{file.name}") with open(file_path, "wb") as f: f.write(file_contents) - st.sidebar.success(f"Uploaded {file.name}") - -upload_files() - -# Chat Section -st.header("Ask the Chat Agent") - -def get_response(question: str): - if not question: - return "No question provided" - response = st.session_state.chat_agent.query(question) - return response - -user_input = st.text_input("Ask a question:") - -if st.button("Submit"): - if 
user_input: - with st.spinner('Getting response...'): - response = get_response(user_input) - st.success("Response:") - st.write(response) - else: - st.error("Please enter a question.") + st.success(f"Uploaded {file.name}") + st.session_state.file_agent.feed_files() + + +# Chat Agent Section +with col2: + st.header("Ask the Chat Agent") + + user_input = st.text_input("Ask a question:") + + if st.button("Submit"): + if user_input: + with st.spinner('Getting response...'): + response = st.session_state.file_agent.query(user_input) + st.success("Response:") + st.write(response) + else: + st.error("Please enter a question.") diff --git a/backend/src/file_search.py b/backend/src/file_search.py index 89b42f2..34cffa5 100644 --- a/backend/src/file_search.py +++ b/backend/src/file_search.py @@ -7,17 +7,20 @@ from dotenv import load_dotenv import os -class ChatAgent: - def __init__(self, use_gemini=False, gemini_model="models/gemini-1.0-pro"): +class FileAgent: + def __init__(self, use_gemini:bool=False, gemini_model:str="models/gemini-1.0-pro"): load_dotenv() + self.use_gemini: bool = use_gemini + self.gemini_model: str = gemini_model os.makedirs("./data", exist_ok=True) os.makedirs("./data/files", exist_ok=True) + def feed_files(self): self.all_tools = get_all_tools(folder_path="./data/files") - if use_gemini: - self.llm = self._init_gemini_llm(gemini_model) + if self.use_gemini: + self.llm = self._init_gemini_llm(self.gemini_model) else: self.llm = OpenAI(model="gpt-3.5-turbo") #TODO: Change to "gpt-4o" for the final version @@ -50,14 +53,16 @@ def query(self, question): return str(response) def reset(self): - os.rmdir("./data/files", ignore_errors=True) - self.agent_worker = self._create_agent_worker() - self.agent = AgentRunner(self.agent_worker) + for file in os.listdir("./data/files"): + os.remove(f"./data/files/{file}") + # if self.agent_worker: + # self.agent_worker = self._create_agent_worker() + # if self.agent: + # self.agent = AgentRunner(self.agent_worker) if 
__name__ == "__main__": - agent = ChatAgent(use_gemini=False) #? Set to True if using Gemini - response = agent.query("Whose repo is this? And what tools has he worked on?") - print(response) - agent.reset() + agent = FileAgent(use_gemini=False) #? Set to True if using Gemini + agent.feed_files() response = agent.query("What is the primary focus of his work?") print(response) + # agent.reset() \ No newline at end of file diff --git a/backend/src/webpage_search.py b/backend/src/webpage_search.py index 5b85153..3e17055 100644 --- a/backend/src/webpage_search.py +++ b/backend/src/webpage_search.py @@ -10,7 +10,7 @@ from llama_index.core.agent import FunctionCallingAgentWorker, AgentRunner -class WebSearch: +class WebpageSearch: def __init__(self, urls=[]): load_dotenv() @@ -20,7 +20,6 @@ def __init__(self, urls=[]): self.llm = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) self.search_results = [] self.all_tools = get_all_tools(folder_path="./data/webpages") - self.feed(urls=urls) def feed_urls(self, urls: List[str]): self.urls = urls @@ -77,7 +76,7 @@ def feed(self, urls: List[str]): if __name__ == "__main__": - ws = WebSearch() + ws = WebpageSearch() ws.reset() ws.feed(urls=['https://docs.llamaindex.ai/en/stable/examples/embeddings/jinaai_embeddings/']) res = ws.query("How to implement JinaAI's embedding in python?")