From a00fab88e7a036e51f64d1c2b608dec4e49f9d91 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Fri, 29 Nov 2024 21:43:00 -0600 Subject: [PATCH] Update community tests --- .../GraphsIntegrationSuperUser.test.ts | 34 ++++++ .../r2rV2ClientIntegrationSuperUser.test.ts | 2 +- py/core/main/api/v3/graph_router.py | 53 +++++---- py/core/main/services/kg_service.py | 23 ++-- py/core/providers/database/graph.py | 106 ++++++++++-------- 5 files changed, 138 insertions(+), 80 deletions(-) diff --git a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts index 3341ca818..68a510337 100644 --- a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts @@ -386,6 +386,40 @@ describe("r2rClient V3 Graphs Integration Tests", () => { expect(response.results.predicate).toBe("marries"); }); + test("Update the community", async () => { + const response = await client.graphs.updateCommunity({ + collectionId: collectionId, + communityId: communityId, + name: "Rodion Romanovich Raskolnikov and Avdotya Romanovna Raskolnikova Community", + summary: + "Rodion and Avdotya are siblings, the children of Pulcheria Alexandrovna Raskolnikova", + }); + + expect(response.results).toBeDefined(); + expect(response.results.name).toBe( + "Rodion Romanovich Raskolnikov and Avdotya Romanovna Raskolnikova Community", + ); + expect(response.results.summary).toBe( + "Rodion and Avdotya are siblings, the children of Pulcheria Alexandrovna Raskolnikova", + ); + }); + + test("Retrieve the updated community", async () => { + const response = await client.graphs.getCommunity({ + collectionId: collectionId, + communityId: communityId, + }); + + expect(response.results).toBeDefined(); + expect(response.results.id).toBe(communityId); + expect(response.results.name).toBe( + "Rodion Romanovich Raskolnikov and Avdotya Romanovna Raskolnikova Community", + ); + expect(response.results.summary).toBe( + "Rodion and Avdotya are siblings, the children of Pulcheria Alexandrovna Raskolnikova", + ); + }); + test("Delete the community", async () => { const response = await client.graphs.deleteCommunity({ collectionId: collectionId, diff --git a/js/sdk/__tests__/r2rV2ClientIntegrationSuperUser.test.ts b/js/sdk/__tests__/r2rV2ClientIntegrationSuperUser.test.ts index 594663799..ddeb9a365 100644 --- a/js/sdk/__tests__/r2rV2ClientIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/r2rV2ClientIntegrationSuperUser.test.ts @@ -129,7 +129,7 @@ describe("r2rClient Integration Tests", () => { metadatas: [{ title: "raskolnikov.txt" }, { title: "karamozov.txt" }], }), ).resolves.not.toThrow(); - }); + }, 10000); test("Ingest files in folder", async () => { const files = ["examples/data/folder"]; diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index cdd23c724..eeb695fae 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -1204,9 +1204,6 @@ async def create_community( rating_explanation: Optional[str] = Body( default="", description="Explanation for the rating" ), - attributes: Optional[dict] = Body( - default=None, description="Attributes for the community" - ), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedCommunityResponse: """ @@ -1302,9 +1299,15 @@ async def get_communities( ) -> WrappedCommunitiesResponse: """ Lists all communities in the graph with pagination support. - - By default, all attributes are returned, but this can be limited using the `attributes` parameter. """ + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the specified graph.", + 403, + ) communities, count = ( await self.providers.database.graph_handler.get_communities( @@ -1371,23 +1374,29 @@ async def get_community( ) -> WrappedCommunityResponse: """ Retrieves a specific community by its ID. - - By default, all attributes are returned, but this can be limited using the `attributes` parameter. """ - if not auth_user.is_superuser: + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): raise R2RException( - "Only superusers can access this endpoint.", 403 + "The currently authenticated user does not have access to the specified graph.", + 403, ) - return await self.services[ + results = await self.services[ "kg" ].providers.database.graph_handler.communities.get( - graph_id=collection_id, - community_id=community_id, - auth_user=auth_user, + parent_id=collection_id, + community_ids=[community_id], + store_type="graph", offset=0, limit=1, ) + print(f"results: {results}") + if len(results) == 0 or len(results[0]) == 0: + raise R2RException("Relationship not found", 404) + return results[0][0] @self.router.delete( "/graphs/{collection_id}/communities/{community_id}", @@ -1519,29 +1528,29 @@ async def update_community( name: Optional[str] = Body(None), summary: Optional[str] = Body(None), findings: Optional[list[str]] = Body(None), - rating: Optional[float] = Body(None), + rating: Optional[float] = Body(default=None, ge=1, le=10), rating_explanation: Optional[str] = Body(None), - attributes: Optional[dict] = Body(None), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedCommunityResponse: """ - Updates an existing community's metadata and properties. + Updates an existing community in the graph. """ - if not auth_user.is_superuser: + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): raise R2RException( - "Only superusers can update communities", 403 + "The currently authenticated user does not have access to the specified graph.", + 403, ) - return await self.services["kg"].update_community_v3( - id=collection_id, + return await self.services["kg"].update_community( community_id=community_id, name=name, summary=summary, findings=findings, rating=rating, rating_explanation=rating_explanation, - attributes=attributes, - auth_user=auth_user, ) @self.router.post( diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index b4ea2aa91..f29553f32 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -409,27 +409,25 @@ async def create_community( @telemetry_event("update_community") async def update_community( self, - id: UUID, community_id: UUID, name: Optional[str], summary: Optional[str], findings: Optional[list[str]], rating: Optional[float], rating_explanation: Optional[str], - ): + ) -> Community: + summary_embedding = None if summary is not None: - embedding = str( + summary_embedding = str( await self.providers.embedding.async_get_embedding(summary) ) - else: - embedding = None return await self.providers.database.graph_handler.communities.update( - id=id, community_id=community_id, + store_type="graph", # type: ignore name=name, summary=summary, - embedding=embedding, + summary_embedding=summary_embedding, findings=findings, rating=rating, rating_explanation=rating_explanation, @@ -446,16 +444,16 @@ async def delete_community( community_id=community_id, ) - @telemetry_event("list_communities_v3") - async def list_communities_v3( + @telemetry_event("list_communities") + async def list_communities( self, - id: UUID, + collection_id: UUID, offset: int, limit: int, - **kwargs, ): return await self.providers.database.graph_handler.communities.get( - id=id, + parent_id=collection_id, + store_type="graph", # type: ignore offset=offset, limit=limit, ) @@ -473,7 +471,6 @@ async def get_communities( ): return await self.providers.database.graph_handler.get_communities( collection_id=collection_id, - levels=levels, community_ids=community_ids, offset=offset or 0, limit=limit or -1, diff --git a/py/core/providers/database/graph.py b/py/core/providers/database/graph.py index 4896e651a..6ceefe7b1 100644 --- a/py/core/providers/database/graph.py +++ b/py/core/providers/database/graph.py @@ -870,9 +870,9 @@ async def create( query = f""" INSERT INTO {self._get_table_name(table_name)} - (community_id, name, summary, findings, rating, rating_explanation, description_embedding) + (collection_id, name, summary, findings, rating, rating_explanation, description_embedding) VALUES ($1, $2, $3, $4, $5, $6, $7) - RETURNING id, community_id, name, summary, findings, rating, rating_explanation, created_at, updated_at + RETURNING id, collection_id, name, summary, findings, rating, rating_explanation, created_at, updated_at """ params = [ @@ -893,7 +893,7 @@ async def create( return Community( id=result["id"], - community_id=result["community_id"], + collection_id=result["collection_id"], name=result["name"], summary=result["summary"], findings=result["findings"], @@ -910,49 +910,83 @@ async def create( async def update( self, - id: UUID, community_id: UUID, - name: Optional[str], - summary: Optional[str], - embedding: Optional[str], - findings: Optional[list[str]], - rating: Optional[float], - rating_explanation: Optional[str], + store_type: StoreType, + name: Optional[str] = None, + summary: Optional[str] = None, + summary_embedding: Optional[list[float] | str] = None, + findings: Optional[list[str]] = None, + rating: Optional[float] = None, + rating_explanation: Optional[str] = None, ) -> Community: - + table_name = "graph_community" update_fields = [] - params: list[Any] = [community_id] # type: ignore + params: list[Any] = [] + param_index = 1 + if name is not None: - update_fields.append(f"name = ${len(params)+1}") + update_fields.append(f"name = ${param_index}") params.append(name) + param_index += 1 if summary is not None: - update_fields.append(f"summary = ${len(params)+1}") + update_fields.append(f"summary = ${param_index}") params.append(summary) + param_index += 1 - if embedding is not None: - update_fields.append(f"description_embedding = ${len(params)+1}") - params.append(embedding) + if summary_embedding is not None: + update_fields.append(f"description_embedding = ${param_index}") + params.append(summary_embedding) + param_index += 1 if findings is not None: - update_fields.append(f"findings = ${len(params)+1}") + update_fields.append(f"findings = ${param_index}") params.append(findings) + param_index += 1 if rating is not None: - update_fields.append(f"rating = ${len(params)+1}") + update_fields.append(f"rating = ${param_index}") params.append(rating) + param_index += 1 if rating_explanation is not None: - update_fields.append(f"rating_explanation = ${len(params)+1}") + update_fields.append(f"rating_explanation = ${param_index}") params.append(rating_explanation) + param_index += 1 - update_fields.append(f"updated_at = CURRENT_TIMESTAMP") + if not update_fields: + raise R2RException(status_code=400, message="No fields to update") - QUERY = f""" - UPDATE {self._get_table_name("graph_community")} SET {", ".join(update_fields)} WHERE id = $1 - RETURNING id, graph_id, name, summary, findings, rating, rating_explanation, metadata, level, created_by, updated_by, updated_at + update_fields.append("updated_at = NOW()") + params.append(community_id) + + query = f""" + UPDATE {self._get_table_name(table_name)} + SET {", ".join(update_fields)} + WHERE id = ${param_index}\ + RETURNING id, community_id, name, summary, findings, rating, rating_explanation, created_at, updated_at """ - return await self.connection_manager.fetchrow_query(QUERY, params) + try: + result = await self.connection_manager.fetchrow_query( + query, params + ) + + return Community( + id=result["id"], + community_id=result["community_id"], + name=result["name"], + summary=result["summary"], + findings=result["findings"], + rating=result["rating"], + rating_explanation=result["rating_explanation"], + created_at=result["created_at"], + updated_at=result["updated_at"], + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"An error occurred while updating the community: {e}", + ) async def delete( self, @@ -990,7 +1024,7 @@ async def get( # Do we ever want to get communities from document store? table_name = "graph_community" - conditions = ["graph_id = $1"] + conditions = ["collection_id = $1"] params: list[Any] = [parent_id] param_index = 2 @@ -1005,8 +1039,8 @@ async def get( param_index += 1 select_fields = """ - id, graph_id, name, summary, findings, rating, - rating_explanation, level, metadata, created_at, updated_at + id, community_id, name, summary, findings, rating, + rating_explanation, level, created_at, updated_at """ if include_embeddings: select_fields += ", description_embedding" @@ -1043,15 +1077,6 @@ async def get( for row in rows: community_dict = dict(row) - # Process metadata if it exists and is a string - if isinstance(community_dict["metadata"], str): - try: - community_dict["metadata"] = json.loads( - community_dict["metadata"] - ) - except json.JSONDecodeError: - pass - communities.append(Community(**community_dict)) return communities, count @@ -2506,7 +2531,6 @@ async def get_communities( offset: int, limit: int, community_ids: Optional[list[UUID]] = None, - levels: Optional[list[int]] = None, include_embeddings: bool = False, ) -> tuple[list[Community], int]: """ @@ -2517,7 +2541,6 @@ async def get_communities( offset: Number of records to skip limit: Maximum number of records to return (-1 for no limit) community_ids: Optional list of community IDs to filter by - levels: Optional list of levels to filter by include_embeddings: Whether to include embeddings in the response Returns: @@ -2532,11 +2555,6 @@ async def get_communities( params.append(community_ids) param_index += 1 - if levels: - conditions.append(f"level = ANY(${param_index})") - params.append(levels) - param_index += 1 - select_fields = """ id, collection_id, name, summary, findings, rating, rating_explanation """