Skip to content

Commit

Permalink
Update community tests
Browse files Browse the repository at this point in the history
  • Loading branch information
NolanTrem committed Nov 30, 2024
1 parent 555f732 commit a00fab8
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 80 deletions.
34 changes: 34 additions & 0 deletions js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,40 @@ describe("r2rClient V3 Graphs Integration Tests", () => {
expect(response.results.predicate).toBe("marries");
});

test("Update the community", async () => {
const response = await client.graphs.updateCommunity({
collectionId: collectionId,
communityId: communityId,
name: "Rodion Romanovich Raskolnikov and Avdotya Romanovna Raskolnikova Community",
summary:
"Rodion and Avdotya are siblings, the children of Pulcheria Alexandrovna Raskolnikova",
});

expect(response.results).toBeDefined();
expect(response.results.name).toBe(
"Rodion Romanovich Raskolnikov and Avdotya Romanovna Raskolnikova Community",
);
expect(response.results.summary).toBe(
"Rodion and Avdotya are siblings, the children of Pulcheria Alexandrovna Raskolnikova",
);
});

test("Retrieve the updated community", async () => {
const response = await client.graphs.getCommunity({
collectionId: collectionId,
communityId: communityId,
});

expect(response.results).toBeDefined();
expect(response.results.id).toBe(communityId);
expect(response.results.name).toBe(
"Rodion Romanovich Raskolnikov and Avdotya Romanovna Raskolnikova Community",
);
expect(response.results.summary).toBe(
"Rodion and Avdotya are siblings, the children of Pulcheria Alexandrovna Raskolnikova",
);
});

test("Delete the community", async () => {
const response = await client.graphs.deleteCommunity({
collectionId: collectionId,
Expand Down
2 changes: 1 addition & 1 deletion js/sdk/__tests__/r2rV2ClientIntegrationSuperUser.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ describe("r2rClient Integration Tests", () => {
metadatas: [{ title: "raskolnikov.txt" }, { title: "karamozov.txt" }],
}),
).resolves.not.toThrow();
});
}, 10000);

test("Ingest files in folder", async () => {
const files = ["examples/data/folder"];
Expand Down
53 changes: 31 additions & 22 deletions py/core/main/api/v3/graph_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,9 +1204,6 @@ async def create_community(
rating_explanation: Optional[str] = Body(
default="", description="Explanation for the rating"
),
attributes: Optional[dict] = Body(
default=None, description="Attributes for the community"
),
auth_user=Depends(self.providers.auth.auth_wrapper),
) -> WrappedCommunityResponse:
"""
Expand Down Expand Up @@ -1302,9 +1299,15 @@ async def get_communities(
) -> WrappedCommunitiesResponse:
"""
Lists all communities in the graph with pagination support.
By default, all attributes are returned, but this can be limited using the `attributes` parameter.
"""
if (
not auth_user.is_superuser
and collection_id not in auth_user.graph_ids
):
raise R2RException(
"The currently authenticated user does not have access to the specified graph.",
403,
)

communities, count = (
await self.providers.database.graph_handler.get_communities(
Expand Down Expand Up @@ -1371,23 +1374,29 @@ async def get_community(
) -> WrappedCommunityResponse:
"""
Retrieves a specific community by its ID.
By default, all attributes are returned, but this can be limited using the `attributes` parameter.
"""
if not auth_user.is_superuser:
if (
not auth_user.is_superuser
and collection_id not in auth_user.graph_ids
):
raise R2RException(
"Only superusers can access this endpoint.", 403
"The currently authenticated user does not have access to the specified graph.",
403,
)

return await self.services[
results = await self.services[
"kg"
].providers.database.graph_handler.communities.get(
graph_id=collection_id,
community_id=community_id,
auth_user=auth_user,
parent_id=collection_id,
community_ids=[community_id],
store_type="graph",
offset=0,
limit=1,
)
print(f"results: {results}")
if len(results) == 0 or len(results[0]) == 0:
raise R2RException("Relationship not found", 404)
return results[0][0]

@self.router.delete(
"/graphs/{collection_id}/communities/{community_id}",
Expand Down Expand Up @@ -1519,29 +1528,29 @@ async def update_community(
name: Optional[str] = Body(None),
summary: Optional[str] = Body(None),
findings: Optional[list[str]] = Body(None),
rating: Optional[float] = Body(None),
rating: Optional[float] = Body(default=None, ge=1, le=10),
rating_explanation: Optional[str] = Body(None),
attributes: Optional[dict] = Body(None),
auth_user=Depends(self.providers.auth.auth_wrapper),
) -> WrappedCommunityResponse:
"""
Updates an existing community's metadata and properties.
Updates an existing community in the graph.
"""
if not auth_user.is_superuser:
if (
not auth_user.is_superuser
and collection_id not in auth_user.graph_ids
):
raise R2RException(
"Only superusers can update communities", 403
"The currently authenticated user does not have access to the specified graph.",
403,
)

return await self.services["kg"].update_community_v3(
id=collection_id,
return await self.services["kg"].update_community(
community_id=community_id,
name=name,
summary=summary,
findings=findings,
rating=rating,
rating_explanation=rating_explanation,
attributes=attributes,
auth_user=auth_user,
)

@self.router.post(
Expand Down
23 changes: 10 additions & 13 deletions py/core/main/services/kg_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,27 +409,25 @@ async def create_community(
@telemetry_event("update_community")
async def update_community(
self,
id: UUID,
community_id: UUID,
name: Optional[str],
summary: Optional[str],
findings: Optional[list[str]],
rating: Optional[float],
rating_explanation: Optional[str],
):
) -> Community:
summary_embedding = None
if summary is not None:
embedding = str(
summary_embedding = str(
await self.providers.embedding.async_get_embedding(summary)
)
else:
embedding = None

return await self.providers.database.graph_handler.communities.update(
id=id,
community_id=community_id,
store_type="graph", # type: ignore
name=name,
summary=summary,
embedding=embedding,
summary_embedding=summary_embedding,
findings=findings,
rating=rating,
rating_explanation=rating_explanation,
Expand All @@ -446,16 +444,16 @@ async def delete_community(
community_id=community_id,
)

@telemetry_event("list_communities_v3")
async def list_communities_v3(
@telemetry_event("list_communities")
async def list_communities(
self,
id: UUID,
collection_id: UUID,
offset: int,
limit: int,
**kwargs,
):
return await self.providers.database.graph_handler.communities.get(
id=id,
parent_id=collection_id,
store_type="graph", # type: ignore
offset=offset,
limit=limit,
)
Expand All @@ -473,7 +471,6 @@ async def get_communities(
):
return await self.providers.database.graph_handler.get_communities(
collection_id=collection_id,
levels=levels,
community_ids=community_ids,
offset=offset or 0,
limit=limit or -1,
Expand Down
Loading

0 comments on commit a00fab8

Please sign in to comment.