Skip to content

Commit

Permalink
reddtriver v0
Browse files Browse the repository at this point in the history
  • Loading branch information
valenradovich committed Sep 21, 2024
1 parent 8dc557b commit 33a4946
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 22 deletions.
15 changes: 6 additions & 9 deletions agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ async def generate_response(self, context):
Anything between the `context` is retrieved from Reddit and is not a part of the conversation with the user. Today's date is {datetime.now().isoformat()}
"""


response = await self.gpt_client.chat.completions.create(
model="gpt-4o-mini",
messages=[
Expand All @@ -82,24 +81,22 @@ async def generate(self, query):
if rephrased_query:
search_results = await self.search_tool.search(rephrased_query)
if isinstance(search_results, list) and len(search_results) == 1 and isinstance(search_results[0], str) and search_results[0].startswith("Error performing search"):
return "I'm sorry, but I encountered an error while searching for information. Could you please try again or rephrase your question?"
return "I'm sorry, but I encountered an error while searching for information. Could you please try again or rephrase your question?", [], rephrased_query

# Convert search results to a comprehensive string format
context = []
for i, result in enumerate(search_results, 1):
context.append(f"{i}. Content: {result.get('pageContent', 'N/A')}")
context.append(f" URL: {result.get('url', 'N/A')}")
context.append(f" Title: {result.get('title', 'N/A')}")
# Add any other fields you want to include
context.append("") # Add a blank line between results
context.append(f" URL: {result.get('metadata', {}).get('url', 'N/A')}")
context.append(f" Title: {result.get('metadata', {}).get('title', 'N/A')}")
context.append("")

context_str = "\n".join(context)

response = await self.generate_response(context_str)
self.update_chat_history(query, response)
return response
return response, search_results, rephrased_query
else:
return "I'm sorry, but I couldn't process your request. Could you please rephrase your question or ask something else?"
return "I'm sorry, but I couldn't process your request. Could you please rephrase your question or ask something else?", [], None

def update_chat_history(self, query, response):
self.chat_history.append({"role": "user", "content": query})
Expand Down
7 changes: 4 additions & 3 deletions agents/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import logging
from agent import Reddtriever
from search import SearchTool
from utils import log_interaction

async def main():
reddtriever = Reddtriever()
Expand All @@ -15,10 +15,11 @@ async def main():
if query.lower() == 'quit':
break

response = await reddtriever.generate(query)
response, docs_retrieved, rephrased_query = await reddtriever.generate(query)
print("="*100)
print(response)

await log_interaction(query, docs_retrieved, response, rephrased_query)

if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())
15 changes: 5 additions & 10 deletions agents/search.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
from tavily import AsyncTavilyClient
from config import TAVILY_API_KEY
from sklearn.feature_extraction.text import TfidfVectorizer
Expand All @@ -10,16 +9,13 @@ def __init__(self):
self.vectorizer = TfidfVectorizer()

async def initialize(self):
    """Asynchronously prepare the search tool for use.

    Currently a no-op placeholder; any async setup work (warming
    connections, loading resources) would go here.
    """
    return None

async def search(self, query: str) -> list:
    """Search Reddit (via the Tavily client) for documents matching *query*.

    Args:
        query: Free-text search query.

    Returns:
        A list of document dicts produced by ``_transform_response``.
        On any failure, a single-element list containing a string that
        starts with ``"Error performing search"`` — callers detect the
        error case by inspecting that sentinel string, not by catching
        an exception.
    """
    try:
        response = await self.client.search(query=query, include_domains=["reddit.com"])
        documents = self._transform_response(response)
        # return self.re_rank(query, documents) -> not used for now, just adds inference time
        return documents
    except Exception as e:
        # Broad catch is deliberate: the caller branches on the sentinel
        # string instead of handling exceptions.
        return [f"Error performing search: {str(e)}"]
Expand All @@ -39,19 +35,18 @@ def _transform_response(self, response: dict) -> list:
return documents

def re_rank(self, query: str, documents: list) -> list:
'''not used for now'''
'''not used for now, just adds inference time'''
if not documents:
return []

texts = [doc['pageContent'] for doc in documents]
for text in texts:
print(f"text: {text}")
print(f"query: {query}")

self.vectorizer.fit(texts)
doc_vectors = self.vectorizer.transform(texts)
query_vector = self.vectorizer.transform([query])

cosine_similarities = cosine_similarity(query_vector, doc_vectors).flatten()
# not working at all, vectors are too long for this kind of comparison. should use something else
cosine_similarities = cosine_similarity(query_vector, doc_vectors).flatten()

similarity_threshold = 0.5
filtered_documents = [
Expand Down
25 changes: 25 additions & 0 deletions agents/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import os
import json
import asyncio
from datetime import datetime

async def log_interaction(user_prompt, docs_retrieved, llm_answer, rephrased_query):
    """Persist one Q&A round-trip as a timestamped JSON file under ``logs/``.

    Args:
        user_prompt: The raw question the user typed.
        docs_retrieved: The search documents used as context (may be empty).
        llm_answer: The model's final response text.
        rephrased_query: The search query derived from the prompt, or None.
    """
    log_folder = "logs"
    os.makedirs(log_folder, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_folder, f"interaction_{timestamp}.json")

    log_data = {
        "timestamp": datetime.now().isoformat(),
        "user_prompt": user_prompt,
        "rephrased_query": rephrased_query,
        "docs_retrieved": docs_retrieved,
        "llm_answer": llm_answer
    }

    def write_log():
        # Blocking file I/O; executed off the event loop via to_thread below.
        with open(log_file, "w", encoding="utf-8") as f:
            json.dump(log_data, f, ensure_ascii=False, indent=2)

    # BUG FIX: the original called asyncio.create_task(write_log()) without
    # keeping or awaiting the returned Task, so the write could be skipped
    # entirely (the unreferenced task may be garbage-collected) or still be
    # pending when the caller's `await log_interaction(...)` returned.
    # Awaiting to_thread guarantees the log file exists once this coroutine
    # completes, and keeps the blocking write off the event loop.
    await asyncio.to_thread(write_log)

0 comments on commit 33a4946

Please sign in to comment.