Skip to content

Commit

Permalink
up (#1603)
Browse files Browse the repository at this point in the history
* up

* up
  • Loading branch information
emrgnt-cmplxty authored Nov 18, 2024
1 parent 2a3f06a commit b963121
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 13 deletions.
2 changes: 1 addition & 1 deletion py/core/main/services/management_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ async def get_conversation(
conversation_id: str,
branch_id: Optional[str] = None,
auth_user=None,
) -> Tuple[str, list[Message]]:
) -> Tuple[str, list[Message], list[dict]]:
return await self.logging_connection.get_conversation(
conversation_id, branch_id
)
Expand Down
19 changes: 13 additions & 6 deletions py/core/providers/logger/r2r_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,25 +655,32 @@ async def get_conversation(
# Get all messages for this branch
async with self.conn.execute(
"""
WITH RECURSIVE branch_messages(id, content, parent_id, depth, created_at) AS (
SELECT m.id, m.content, m.parent_id, 0, m.created_at
WITH RECURSIVE branch_messages(id, content, parent_id, depth, created_at, metadata) AS (
SELECT m.id, m.content, m.parent_id, 0, m.created_at, m.metadata
FROM messages m
JOIN message_branches mb ON m.id = mb.message_id
WHERE mb.branch_id = ? AND m.parent_id IS NULL
UNION
SELECT m.id, m.content, m.parent_id, bm.depth + 1, m.created_at
SELECT m.id, m.content, m.parent_id, bm.depth + 1, m.created_at, m.metadata
FROM messages m
JOIN message_branches mb ON m.id = mb.message_id
JOIN branch_messages bm ON m.parent_id = bm.id
WHERE mb.branch_id = ?
)
SELECT id, content, parent_id FROM branch_messages
SELECT id, content, parent_id, metadata FROM branch_messages
ORDER BY created_at ASC
""",
""",
(branch_id, branch_id),
) as cursor:
rows = await cursor.fetchall()
return [(row[0], Message.parse_raw(row[1])) for row in rows]
return [
(
row[0], # id
Message.parse_raw(row[1]), # message content
json.loads(row[3]) if row[3] else {}, # metadata
)
for row in rows
]

async def get_branches_overview(self, conversation_id: str) -> list[dict]:
if not self.conn:
Expand Down
15 changes: 10 additions & 5 deletions py/sdk/mixins/management.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ async def create_conversation(self) -> dict:
async def add_message(
self,
conversation_id: Union[str, UUID],
message: Message,
message: dict,
parent_id: Optional[str] = None,
metadata: Optional[dict[str, Any]] = None,
) -> dict:
Expand All @@ -716,9 +716,14 @@ async def add_message(
data["parent_id"] = parent_id
if metadata is not None:
data["metadata"] = metadata
return await self._make_request( # type: ignore
"POST", f"add_message/{str(conversation_id)}", data=data
)
if len(data) == 1:
return await self._make_request( # type: ignore
"POST", f"add_message/{str(conversation_id)}", json=data
)
else:
return await self._make_request( # type: ignore
"POST", f"add_message/{str(conversation_id)}", data=data
)

async def update_message(
self,
Expand Down Expand Up @@ -755,7 +760,7 @@ async def update_message_metadata(
dict: The response from the server.
"""
return await self._make_request( # type: ignore
"PATCH", f"messages/{message_id}/metadata", data=metadata
"PATCH", f"messages/{message_id}/metadata", json=metadata
)

async def branches_overview(
Expand Down
2 changes: 1 addition & 1 deletion py/shared/api/models/management/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ class AddUserResponse(BaseModel):
WrappedUserOverviewResponse = PaginatedResultsWrapper[
list[UserOverviewResponse]
]
WrappedConversationResponse = ResultsWrapper[list[Tuple[str, Message]]]
WrappedConversationResponse = ResultsWrapper[list[Tuple[str, Message, dict]]]
WrappedDocumentOverviewResponse = PaginatedResultsWrapper[
list[DocumentOverviewResponse]
]
Expand Down

0 comments on commit b963121

Please sign in to comment.