Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Personas memories fix #1917

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/routers/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from utils.other import endpoints as auth
from models.app import App
from utils.other.storage import upload_plugin_logo, delete_plugin_logo, upload_app_thumbnail, get_app_thumbnail_url
from utils.social import get_twitter_profile, get_twitter_timeline, verify_latest_tweet, \
from utils.social import get_twitter_profile, verify_latest_tweet, \
upsert_persona_from_twitter_profile, add_twitter_to_persona

router = APIRouter()
Expand Down
10 changes: 6 additions & 4 deletions backend/utils/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def get_available_apps(uid: str, include_reviews: bool = False) -> List[App]:
if cachedApps := get_generic_cache('get_public_approved_apps_data'):
print('get_public_approved_plugins_data from cache----------------------------')
public_approved_data = cachedApps
public_approved_data = get_public_approved_apps_db()
public_unapproved_data = get_public_unapproved_apps(uid)
private_data = get_private_apps(uid)
pass
Expand Down Expand Up @@ -392,7 +391,7 @@ async def generate_persona_prompt(uid: str, persona: dict):

tweets = None
if "twitter" in persona['connected_accounts']:
print("twitter in connected accounts---------------------------")
print("twitter is in connected accounts")
# Get latest tweets
tweets = await get_twitter_timeline(persona['twitter']['username'])
tweets = [{'tweet': tweet['text'], 'posted_at': tweet['created_at']} for tweet in tweets['timeline']]
Expand Down Expand Up @@ -472,13 +471,15 @@ def sync_update_persona_prompt(persona: dict):
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(update_persona_prompt(persona))
except Exception as e:
print(f"Error in update_persona_prompt for persona {persona.get('id', 'unknown')}: {str(e)}")
return None
finally:
loop.close()


async def update_persona_prompt(persona: dict):
"""Update a persona's chat prompt with latest facts and memories."""

# Get latest facts and user info
facts = get_facts(persona['uid'], limit=250)
user_name = get_user_name(persona['uid'])
Expand Down Expand Up @@ -531,7 +532,7 @@ async def update_persona_prompt(persona: dict):

# Add a guideline about tweets if they exist
if condensed_tweets:
persona_prompt += "7. Utilize condensed tweets to enhance authenticity, incorporating common expressions, opinions, and phrasing from {user_name}s social media presence.\n"
persona_prompt += "7. Utilize condensed tweets to enhance authenticity, incorporating common expressions, opinions, and phrasing from {user_name}'s social media presence.\n"

persona_prompt += f"""
**Rules:**
Expand All @@ -556,6 +557,7 @@ async def update_persona_prompt(persona: dict):

persona['persona_prompt'] = persona_prompt
persona['updated_at'] = datetime.now(timezone.utc)

update_persona_in_db(persona)
delete_app_cache_by_id(persona['id'])

Expand Down
24 changes: 14 additions & 10 deletions backend/utils/memories/process_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from models.task import Task, TaskStatus, TaskAction, TaskActionProvider
from models.trend import Trend
from models.notification_message import NotificationMessage
from utils.apps import get_available_apps, update_persona_prompt, sync_update_persona_prompt
from utils.apps import get_available_apps, sync_update_persona_prompt
from utils.llm import obtain_emotional_message, retrieve_metadata_fields_from_transcript
from utils.llm import summarize_open_glass, get_transcript_structure, generate_embedding, \
get_plugin_result, should_discard_memory, summarize_experience_text, new_facts_extractor, \
Expand Down Expand Up @@ -174,6 +174,18 @@ def save_structured_vector(uid: str, memory: Memory, update_only: bool = False):
update_vector_metadata(uid, memory.id, metadata)


def _update_personas_async(uid: str):
print(f"[PERSONAS] Starting persona updates in background thread for uid={uid}")
personas = get_omi_personas_by_uid_db(uid)
if personas:
threads = []
for persona in personas:
threads.append(threading.Thread(target=sync_update_persona_prompt, args=(persona,)))

[t.start() for t in threads]
[t.join() for t in threads]


def process_memory(
uid: str, language_code: str, memory: Union[Memory, CreateMemory, WorkflowCreateMemory],
force_process: bool = False, is_reprocess: bool = False
Expand All @@ -193,15 +205,7 @@ def process_memory(
threading.Thread(target=memory_created_webhook, args=(uid, memory,)).start()
# TODO: Bad code, cause the websocket was drop, need to check it carefully before enabling.
# Update persona prompts with new memory
# personas = get_omi_personas_by_uid_db(uid)
# if personas:
# threads = []
# print('updating personas after memory creation')
# for persona in personas:
# threads.append(threading.Thread(target=sync_update_persona_prompt, args=(persona,)))
#
# [t.start() for t in threads]
# [t.join() for t in threads]
threading.Thread(target=_update_personas_async, args=(uid,)).start()

# TODO: trigger external integrations here too

Expand Down
6 changes: 3 additions & 3 deletions backend/utils/social.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async def get_twitter_profile(handle: str) -> Dict[str, Any]:
"X-RapidAPI-Host": rapid_api_host
}

async with httpx.AsyncClient() as client:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url, headers=headers)
response.raise_for_status()
data = response.json()
Expand All @@ -36,7 +36,7 @@ async def get_twitter_timeline(handle: str) -> Dict[str, Any]:
"X-RapidAPI-Host": rapid_api_host
}

async with httpx.AsyncClient() as client:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url, headers=headers)
response.raise_for_status()
data = response.json()
Expand All @@ -52,7 +52,7 @@ async def verify_latest_tweet(username: str, handle: str) -> Dict[str, Any]:
"X-RapidAPI-Host": rapid_api_host
}

async with httpx.AsyncClient() as client:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url, headers=headers)
response.raise_for_status()
data = response.json()
Expand Down