diff --git a/.github/workflows/r2r-js-sdk-integration-tests.yml b/.github/workflows/r2r-js-sdk-integration-tests.yml index 2712ab5c1..9f0ae0365 100644 --- a/.github/workflows/r2r-js-sdk-integration-tests.yml +++ b/.github/workflows/r2r-js-sdk-integration-tests.yml @@ -1,10 +1,43 @@ name: R2R JS SDK Integration Tests + on: push: branches: - - '**' # Trigger on all branches + - '**' + jobs: - test: + setup: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Set up Python and install dependencies + uses: ./.github/actions/setup-python-light + with: + os: ubuntu-latest + - name: Setup and start PostgreSQL + uses: ./.github/actions/setup-postgres-ext + with: + os: ubuntu-latest + - name: Start R2R Light server + uses: ./.github/actions/start-r2r-light + - name: Use Node.js + uses: actions/setup-node@v2 + with: + node-version: "20.x" + - name: Install pnpm + uses: pnpm/action-setup@v2 + with: + version: 8.x + run_install: false + - name: Install JS SDK dependencies + working-directory: ./js/sdk + run: pnpm install + - name: Check if R2R server is running + run: | + curl http://localhost:7272/v2/health || echo "Server not responding" + + v2-unit-test: + needs: setup runs-on: ubuntu-latest env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} @@ -42,9 +75,113 @@ jobs: - name: Install JS SDK dependencies working-directory: ./js/sdk run: pnpm install - - name: Check if R2R server is running - run: | - curl http://localhost:7272/v2/health || echo "Server not responding" + - name: Run r2rV2Client tests + working-directory: ./js/sdk + run: pnpm jest r2rV2Client.test.ts + + v2-integration-tests: + needs: v2-unit-test + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + test-group: + - r2rV2ClientIntegrationSuperUser.test.ts + - r2rV2ClientIntegrationUser.test.ts + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }} + AZURE_API_BASE: ${{ secrets.AZURE_API_BASE }} + AZURE_API_VERSION: ${{ secrets.AZURE_API_VERSION }} + TELEMETRY_ENABLED: 'false' + R2R_POSTGRES_HOST: localhost + R2R_POSTGRES_DBNAME: postgres + R2R_POSTGRES_PORT: '5432' + R2R_POSTGRES_PASSWORD: postgres + R2R_POSTGRES_USER: postgres + R2R_PROJECT_NAME: r2r_default + steps: + - uses: actions/checkout@v2 + - name: Set up Python and install dependencies + uses: ./.github/actions/setup-python-light + with: + os: ubuntu-latest + - name: Setup and start PostgreSQL + uses: ./.github/actions/setup-postgres-ext + with: + os: ubuntu-latest + - name: Start R2R Light server + uses: ./.github/actions/start-r2r-light + - name: Use Node.js + uses: actions/setup-node@v2 + with: + node-version: "20.x" + - name: Install pnpm + uses: pnpm/action-setup@v2 + with: + version: 8.x + run_install: false + - name: Install JS SDK dependencies + working-directory: ./js/sdk + run: pnpm install - name: Run integration tests working-directory: ./js/sdk - run: pnpm test + run: pnpm jest ${{ matrix.test-group }} + + v3-integration-tests: + needs: setup + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + test-group: + - ChunksIntegrationSuperUser.test.ts + - CollectionsIntegrationSuperUser.test.ts + - ConversationsIntegrationSuperUser.test.ts + - DocumentsAndCollectionsIntegrationUser.test.ts + - DocumentsIntegrationSuperUser.test.ts + - GraphsIntegrationSuperUser.test.ts + - PromptsIntegrationSuperUser.test.ts + - RetrievalIntegrationSuperUser.test.ts + - SystemIntegrationSuperUser.test.ts + - SystemIntegrationUser.test.ts + - UsersIntegrationSuperUser.test.ts + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }} + AZURE_API_BASE: ${{ secrets.AZURE_API_BASE }} + AZURE_API_VERSION: ${{ secrets.AZURE_API_VERSION }} + TELEMETRY_ENABLED: 'false' + R2R_POSTGRES_HOST: localhost + R2R_POSTGRES_DBNAME: postgres + R2R_POSTGRES_PORT: '5432' + R2R_POSTGRES_PASSWORD: postgres + R2R_POSTGRES_USER: postgres + R2R_PROJECT_NAME: r2r_default + steps: + - uses: actions/checkout@v2 + - name: Set up Python and install dependencies + uses: ./.github/actions/setup-python-light + with: + os: ubuntu-latest + - name: Setup and start PostgreSQL + uses: ./.github/actions/setup-postgres-ext + with: + os: ubuntu-latest + - name: Start R2R Light server + uses: ./.github/actions/start-r2r-light + - name: Use Node.js + uses: actions/setup-node@v2 + with: + node-version: "20.x" + - name: Install pnpm + uses: pnpm/action-setup@v2 + with: + version: 8.x + run_install: false + - name: Install JS SDK dependencies + working-directory: ./js/sdk + run: pnpm install + - name: Run remaining tests + working-directory: ./js/sdk + run: pnpm jest ${{ matrix.test-group }} diff --git a/js/sdk/__tests__/ChunksIntegrationSuperUser.test.ts b/js/sdk/__tests__/ChunksIntegrationSuperUser.test.ts index a15336239..6370b9b32 100644 --- a/js/sdk/__tests__/ChunksIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/ChunksIntegrationSuperUser.test.ts @@ -31,7 +31,7 @@ describe("r2rClient V3 Collections Integration Tests", () => { expect(response.results).toEqual([ { document_id: expect.any(String), - message: "Ingestion task completed successfully.", + message: "Document created and ingested successfully.", }, ]); }, 10000); diff --git a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts index 5d5d0a3e6..68a510337 100644 --- a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts @@ -3,10 +3,14 @@ import { describe, test, beforeAll, expect } from "@jest/globals"; const baseUrl = "http://localhost:7272"; -describe("r2rClient V3 Collections Integration Tests", () => { +describe("r2rClient V3 Graphs Integration Tests", () => { let client: r2rClient; - let graph1Id: string; - let graph2Id: string; + let documentId: string; + let collectionId: string; + let entity1Id: string; + let entity2Id: string; + let relationshipId: string; + let communityId: string; beforeAll(async () => { client = new r2rClient(baseUrl); @@ -16,75 +20,472 @@ describe("r2rClient V3 Collections Integration Tests", () => { }); }); - test("Create a graph with only a name", async () => { - const response = await client.graphs.create({ - name: "Graph 1", + test("Create document with file path", async () => { + const response = await client.documents.create({ + file: { + path: "examples/data/raskolnikov_2.txt", + name: "raskolnikov_2.txt", + }, + metadata: { title: "raskolnikov_2.txt" }, + }); + + expect(response.results.document_id).toBeDefined(); + documentId = response.results.document_id; + }, 10000); + + test("Create new collection", async () => { + const response = await client.collections.create({ + name: "Raskolnikov Collection", + }); + expect(response).toBeTruthy(); + collectionId = response.results.id; + }); + + test("Retrieve collection", async () => { + const response = await client.collections.retrieve({ + id: collectionId, }); expect(response.results).toBeDefined(); - graph1Id = response.results.id; - expect(graph1Id).toEqual(response.results.id); - expect(response.results.name).toEqual("Graph 1"); - expect(response.results.description).toBe(null); + expect(response.results.id).toBe(collectionId); + expect(response.results.name).toBe("Raskolnikov Collection"); }); - test("Create a graph with name and description", async () => { - const response = await client.graphs.create({ - name: "2", - description: "Graph 2", + test("Update graph", async () => { + const response = await client.graphs.update({ + collectionId: collectionId, + name: "Raskolnikov Graph", }); - graph2Id = response.results.id; + expect(response.results).toBeDefined(); - expect(response.results.name).toEqual("2"); - expect(response.results.description).toEqual("Graph 2"); }); - test("Ensure that there are two graphs", async () => { - const response = await client.graphs.list(); + test("Retrieve graph and ensure that update was successful", async () => { + const response = await client.graphs.retrieve({ + collectionId: collectionId, + }); + expect(response.results).toBeDefined(); - expect(response.results.length).toEqual(2); + expect(response.results.name).toBe("Raskolnikov Graph"); + expect(response.results.updated_at).not.toBe(response.results.created_at); }); - test("Retrieve graph 1", async () => { - const response = await client.graphs.retrieve({ id: graph1Id }); + test("List graphs", async () => { + const response = await client.graphs.list({}); + expect(response.results).toBeDefined(); - expect(response.results.name).toEqual("Graph 1"); - expect(response.results.description).toBe(null); }); - test("Retrieve graph 2", async () => { - const response = await client.graphs.retrieve({ id: graph2Id }); + test("Check that there are no entities in the graph", async () => { + const response = await client.graphs.listEntities({ + collectionId: collectionId, + }); + expect(response.results).toBeDefined(); - expect(response.results.name).toEqual("2"); - expect(response.results.description).toEqual("Graph 2"); + expect(response.results.entries).toHaveLength(0); }); - test("Update the name of graph 1", async () => { - const response = await client.graphs.update({ - id: graph1Id, - name: "Graph 1 Updated", + test("Check that there are no relationships in the graph", async () => { + const response = await client.graphs.listRelationships({ + collectionId: collectionId, }); + expect(response.results).toBeDefined(); - expect(response.results.name).toEqual("Graph 1 Updated"); + expect(response.results.entries).toHaveLength; }); - test("Update the desription graph 2", async () => { - const response = await client.graphs.update({ - id: graph2Id, - description: "Graph 2 Updated", + test("Extract entities from the document", async () => { + const response = await client.documents.extract({ + id: documentId, + }); + + await new Promise((resolve) => setTimeout(resolve, 30000)); + + expect(response.results).toBeDefined(); + }, 60000); + + test("Assign document to collection", async () => { + const response = await client.collections.addDocument({ + id: collectionId, + documentId: documentId, + }); + expect(response.results).toBeDefined(); + }); + + test("Pull entities into the graph", async () => { + const response = await client.graphs.pull({ + collectionId: collectionId, + }); + expect(response.results).toBeDefined(); + }); + + test("Check that there are entities in the graph", async () => { + const response = await client.graphs.listEntities({ + collectionId: collectionId, + }); + expect(response.results).toBeDefined(); + expect(response.total_entries).toBeGreaterThanOrEqual(1); + }, 60000); + + test("Check that there are relationships in the graph", async () => { + const response = await client.graphs.listRelationships({ + collectionId: collectionId, + }); + expect(response.results).toBeDefined(); + expect(response.total_entries).toBeGreaterThanOrEqual(1); + }); + + test("Check that there are no communities in the graph prior to building", async () => { + const response = await client.graphs.listCommunities({ + collectionId: collectionId, + }); + + expect(response.results).toBeDefined(); + expect(response.results.entries).toHaveLength(0); + }); + + test("Build communities", async () => { + const response = await client.graphs.buildCommunities({ + collectionId: collectionId, + }); + + await new Promise((resolve) => setTimeout(resolve, 15000)); + + expect(response.results).toBeDefined(); + }, 45000); + + test("Check that there are communities in the graph", async () => { + const response = await client.graphs.listCommunities({ + collectionId: collectionId, + }); + + expect(response.results).toBeDefined(); + expect(response.total_entries).toBeGreaterThanOrEqual(1); + }); + + test("Create a new entity", async () => { + const response = await client.graphs.createEntity({ + collectionId: collectionId, + name: "Razumikhin", + description: "A good friend of Raskolnikov", + category: "Person", + }); + + expect(response.results).toBeDefined(); + entity1Id = response.results.id; + }); + + test("Create another new entity", async () => { + const response = await client.graphs.createEntity({ + collectionId: collectionId, + name: "Dunia", + description: "The sister of Raskolnikov", + category: "Person", }); + expect(response.results).toBeDefined(); - expect(response.results.description).toEqual("Graph 2 Updated"); + entity2Id = response.results.id; }); - test("Delete graph 1", async () => { - const response = await client.graphs.delete({ id: graph1Id }); + test("Retrieve the entity", async () => { + const response = await client.graphs.getEntity({ + collectionId: collectionId, + entityId: entity1Id, + }); + expect(response.results).toBeDefined(); - expect(response.results.success).toBe(true); + expect(response.results.id).toBe(entity1Id); + expect(response.results.name).toBe("Razumikhin"); + expect(response.results.description).toBe("A good friend of Raskolnikov"); }); - test("Delete graph 2", async () => { - const response = await client.graphs.delete({ id: graph2Id }); + test("Retrieve the other entity", async () => { + const response = await client.graphs.getEntity({ + collectionId: collectionId, + entityId: entity2Id, + }); + + expect(response.results).toBeDefined(); + expect(response.results.id).toBe(entity2Id); + expect(response.results.name).toBe("Dunia"); + expect(response.results.description).toBe("The sister of Raskolnikov"); + }); + + test("Check that the entities are in the graph", async () => { + const response = await client.graphs.listEntities({ + collectionId: collectionId, + }); + + expect(response.results).toBeDefined(); + expect(response.results.map((entity) => entity.id)).toContain(entity1Id); + expect(response.results.map((entity) => entity.id)).toContain(entity2Id); + }); + + test("Create a relationship between the entities", async () => { + const response = await client.graphs.createRelationship({ + collectionId: collectionId, + subject: "Razumikhin", + subjectId: entity1Id, + predicate: "falls in love with", + object: "Dunia", + objectId: entity2Id, + }); + + relationshipId = response.results.id; + + expect(response.results).toBeDefined(); + expect(response.results.subject).toBe("Razumikhin"); + expect(response.results.object).toBe("Dunia"); + expect(response.results.predicate).toBe("falls in love with"); + }); + + test("Retrieve the relationship", async () => { + const response = await client.graphs.getRelationship({ + collectionId: collectionId, + relationshipId: relationshipId, + }); + + expect(response.results).toBeDefined(); + expect(response.results.id).toBe(relationshipId); + expect(response.results.subject).toBe("Razumikhin"); + expect(response.results.object).toBe("Dunia"); + expect(response.results.predicate).toBe("falls in love with"); + }); + + test("Create a new community", async () => { + const response = await client.graphs.createCommunity({ + collectionId: collectionId, + name: "Raskolnikov and Dunia Community", + summary: + "Raskolnikov and Dunia are siblings, the children of Pulcheria Alexandrovna", + findings: [ + "Raskolnikov and Dunia are siblings", + "They are the children of Pulcheria Alexandrovna", + "Their family comes from a modest background", + "Dunia works as a governess to support the family", + "Raskolnikov is a former university student", + "Both siblings are intelligent and well-educated", + "They maintain a close relationship despite living apart", + "Their mother Pulcheria writes letters to keep them connected", + ], + rating: 10, + ratingExplanation: + "Raskolnikov and Dunia are central to the story and have a complex relationship", + }); + + communityId = response.results.id; + + expect(response.results).toBeDefined(); + expect(response.results.name).toBe("Raskolnikov and Dunia Community"); + expect(response.results.summary).toBe( + "Raskolnikov and Dunia are siblings, the children of Pulcheria Alexandrovna", + ); + expect(response.results.findings).toContain( + "Raskolnikov and Dunia are siblings", + ); + expect(response.results.findings).toContain( + "They are the children of Pulcheria Alexandrovna", + ); + expect(response.results.findings).toContain( + "Their family comes from a modest background", + ); + expect(response.results.findings).toContain( + "Dunia works as a governess to support the family", + ); + expect(response.results.findings).toContain( + "Raskolnikov is a former university student", + ); + expect(response.results.findings).toContain( + "Both siblings are intelligent and well-educated", + ); + expect(response.results.findings).toContain( + "They maintain a close relationship despite living apart", + ); + expect(response.results.findings).toContain( + "Their mother Pulcheria writes letters to keep them connected", + ); + expect(response.results.rating).toBe(10); + //TODO: Why is this failing? + // expect(response.results.ratingExplanation).toBe( + // "Raskolnikov and Dunia are central to the story and have a complex relationship", + // ); + }); + + test("Update the entity", async () => { + const response = await client.graphs.updateEntity({ + collectionId: collectionId, + entityId: entity1Id, + name: "Dmitri Prokofich Razumikhin", + description: "A good friend of Raskolnikov and Dunia", + category: "Person", + }); + + expect(response.results).toBeDefined(); + expect(response.results.id).toBe(entity1Id); + expect(response.results.name).toBe("Dmitri Prokofich Razumikhin"); + expect(response.results.description).toBe( + "A good friend of Raskolnikov and Dunia", + ); + }); + + test("Retrieve the updated entity", async () => { + const response = await client.graphs.getEntity({ + collectionId: collectionId, + entityId: entity1Id, + }); + + expect(response.results).toBeDefined(); + expect(response.results.id).toBe(entity1Id); + expect(response.results.name).toBe("Dmitri Prokofich Razumikhin"); + expect(response.results.description).toBe( + "A good friend of Raskolnikov and Dunia", + ); + }); + + // This test is failing because we attach a separate name to the relationship, rather + // than use the names of the entities. This needs to be fixed in the backend. + // test("Ensure that the entity was updated in the relationship", async () => { + // const response = await client.graphs.getRelationship({ + // collectionId: collectionId, + // relationshipId: relationshipId, + // }); + + // expect(response.results).toBeDefined(); + // expect(response.results.subject).toBe("Dmitri Prokofich Razumikhin"); + // expect(response.results.object).toBe("Dunia"); + // expect(response.results.predicate).toBe("falls in love with"); + // }); + + test("Update the relationship", async () => { + const response = await client.graphs.updateRelationship({ + collectionId: collectionId, + relationshipId: relationshipId, + subject: "Razumikhin", + subjectId: entity1Id, + predicate: "marries", + object: "Dunia", + objectId: entity2Id, + }); + + expect(response.results).toBeDefined(); + expect(response.results.id).toBe(relationshipId); + expect(response.results.subject).toBe("Razumikhin"); + expect(response.results.object).toBe("Dunia"); + expect(response.results.predicate).toBe("marries"); + }); + + test("Retrieve the updated relationship", async () => { + const response = await client.graphs.getRelationship({ + collectionId: collectionId, + relationshipId: relationshipId, + }); + + expect(response.results).toBeDefined(); + expect(response.results.id).toBe(relationshipId); + expect(response.results.subject).toBe("Razumikhin"); + expect(response.results.object).toBe("Dunia"); + expect(response.results.predicate).toBe("marries"); + }); + + test("Update the community", async () => { + const response = await client.graphs.updateCommunity({ + collectionId: collectionId, + communityId: communityId, + name: "Rodion Romanovich Raskolnikov and Avdotya Romanovna Raskolnikova Community", + summary: + "Rodion and Avdotya are siblings, the children of Pulcheria Alexandrovna Raskolnikova", + }); + + expect(response.results).toBeDefined(); + expect(response.results.name).toBe( + "Rodion Romanovich Raskolnikov and Avdotya Romanovna Raskolnikova Community", + ); + expect(response.results.summary).toBe( + "Rodion and Avdotya are siblings, the children of Pulcheria Alexandrovna Raskolnikova", + ); + }); + + test("Retrieve the updated community", async () => { + const response = await client.graphs.getCommunity({ + collectionId: collectionId, + communityId: communityId, + }); + + expect(response.results).toBeDefined(); + expect(response.results.id).toBe(communityId); + expect(response.results.name).toBe( + "Rodion Romanovich Raskolnikov and Avdotya Romanovna Raskolnikova Community", + ); + expect(response.results.summary).toBe( + "Rodion and Avdotya are siblings, the children of Pulcheria Alexandrovna Raskolnikova", + ); + }); + + test("Delete the community", async () => { + const response = await client.graphs.deleteCommunity({ + collectionId: collectionId, + communityId: communityId, + }); + + expect(response.results).toBeDefined(); + }); + + test("Check that the community was deleted", async () => { + const response = await client.graphs.listCommunities({ + collectionId: collectionId, + }); + + expect(response.results).toBeDefined(); + expect(response.results.entries).toHaveLength(0); + }); + + test("Reset the graph", async () => { + const response = await client.graphs.reset({ + collectionId: collectionId, + }); + + expect(response.results).toBeDefined(); + }); + + test("Check that there are no entities in the graph", async () => { + const response = await client.graphs.listEntities({ + collectionId: collectionId, + }); + + expect(response.results).toBeDefined(); + expect(response.results.entries).toHaveLength(0); + }); + + test("Check that there are no relationships in the graph", async () => { + const response = await client.graphs.listRelationships({ + collectionId: collectionId, + }); + + expect(response.results).toBeDefined(); + expect(response.results.entries).toHaveLength(0); + }); + + test("Delete raskolnikov_2.txt", async () => { + const response = await client.documents.delete({ + id: documentId, + }); + + expect(response.results).toBeDefined(); + }); + + test("Check that the document is not in the collection", async () => { + const response = await client.collections.listDocuments({ + id: collectionId, + }); + + expect(response.results).toBeDefined(); + expect(response.results.entries).toHaveLength(0); + }); + + test("Delete Raskolnikov Collection", async () => { + const response = await client.collections.delete({ + id: collectionId, + }); + expect(response.results).toBeDefined(); - expect(response.results.success).toBe(true); }); }); diff --git a/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts.txt b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts.txt new file mode 100644 index 000000000..31cd60d74 --- /dev/null +++ b/js/sdk/__tests__/GraphsIntegrationSuperUser.test.ts.txt @@ -0,0 +1,90 @@ +import { r2rClient } from "../src/index"; +import { describe, test, beforeAll, expect } from "@jest/globals"; + +const baseUrl = "http://localhost:7272"; + +describe("r2rClient V3 Collections Integration Tests", () => { + let client: r2rClient; + let graph1Id: string; + let graph2Id: string; + + beforeAll(async () => { + client = new r2rClient(baseUrl); + await client.users.login({ + email: "admin@example.com", + password: "change_me_immediately", + }); + }); + + test("Create a graph with only a name", async () => { + const response = await client.graphs.create({ + name: "Graph 1", + }); + expect(response.results).toBeDefined(); + graph1Id = response.results.id; + expect(graph1Id).toEqual(response.results.id); + expect(response.results.name).toEqual("Graph 1"); + expect(response.results.description).toBe(""); + }); + + test("Create a graph with name and description", async () => { + const response = await client.graphs.create({ + name: "2", + description: "Graph 2", + }); + graph2Id = response.results.id; + expect(response.results).toBeDefined(); + expect(response.results.name).toEqual("2"); + expect(response.results.description).toEqual("Graph 2"); + }); + + test("Ensure that there are two graphs", async () => { + const response = await client.graphs.list(); + expect(response.results).toBeDefined(); + expect(response.results.length).toEqual(2); + }); + + test("Retrieve graph 1", async () => { + const response = await client.graphs.retrieve({ id: graph1Id }); + expect(response.results).toBeDefined(); + expect(response.results.name).toEqual("Graph 1"); + expect(response.results.description).toBe(""); + }); + + test("Retrieve graph 2", async () => { + const response = await client.graphs.retrieve({ id: graph2Id }); + expect(response.results).toBeDefined(); + expect(response.results.name).toEqual("2"); + expect(response.results.description).toEqual("Graph 2"); + }); + + test("Update the name of graph 1", async () => { + const response = await client.graphs.update({ + id: graph1Id, + name: "Graph 1 Updated", + }); + expect(response.results).toBeDefined(); + expect(response.results.name).toEqual("Graph 1 Updated"); + }); + + test("Update the description graph 2", async () => { + const response = await client.graphs.update({ + id: graph2Id, + description: "Graph 2 Updated", + }); + expect(response.results).toBeDefined(); + expect(response.results.description).toEqual("Graph 2 Updated"); + }); + + test("Delete graph 1", async () => { + const response = await client.graphs.delete({ id: graph1Id }); + expect(response.results).toBeDefined(); + expect(response.results.success).toBe(true); + }); + + test("Delete graph 2", async () => { + const response = await client.graphs.delete({ id: graph2Id }); + expect(response.results).toBeDefined(); + expect(response.results.success).toBe(true); + }); +}); diff --git a/js/sdk/__tests__/GraphsIntegrationUser.test.ts.txt b/js/sdk/__tests__/GraphsIntegrationUser.test.ts.txt new file mode 100644 index 000000000..4946015c2 --- /dev/null +++ b/js/sdk/__tests__/GraphsIntegrationUser.test.ts.txt @@ -0,0 +1,122 @@ +import { r2rClient } from "../src/index"; +import { describe, test, beforeAll, expect } from "@jest/globals"; + +const baseUrl = "http://localhost:7272"; + +describe("r2rClient V3 Collections Integration Tests", () => { + let client: r2rClient; + + let graph1Id: string; + let graph2Id: string; + + let entity1Id: string; + + beforeAll(async () => { + client = new r2rClient(baseUrl); + await client.users.login({ + email: "admin@example.com", + password: "change_me_immediately", + }); + }); + + test("Create a graph with only a name", async () => { + const response = await client.graphs.create({ + name: "Graph 1", + }); + expect(response.results).toBeDefined(); + graph1Id = response.results.id; + expect(graph1Id).toEqual(response.results.id); + expect(response.results.name).toEqual("Graph 1"); + expect(response.results.description).toBe(null); + }); + + test("Create a graph with name and description", async () => { + const response = await client.graphs.create({ + name: "2", + description: "Graph 2", + }); + graph2Id = response.results.id; + expect(response.results).toBeDefined(); + expect(response.results.name).toEqual("2"); + expect(response.results.description).toEqual("Graph 2"); + }); + + test("Ensure that there are two graphs", async () => { + const response = await client.graphs.list(); + expect(response.results).toBeDefined(); + expect(response.results.length).toEqual(2); + }); + + test("Retrieve graph 1", async () => { + const response = await client.graphs.retrieve({ id: graph1Id }); + expect(response.results).toBeDefined(); + expect(response.results.name).toEqual("Graph 1"); + expect(response.results.description).toBe(null); + }); + + test("Retrieve graph 2", async () => { + const response = await client.graphs.retrieve({ id: graph2Id }); + expect(response.results).toBeDefined(); + expect(response.results.name).toEqual("2"); + expect(response.results.description).toEqual("Graph 2"); + }); + + test("Update the name of graph 1", async () => { + const response = await client.graphs.update({ + id: graph1Id, + name: "Graph 1 Updated", + }); + expect(response.results).toBeDefined(); + expect(response.results.name).toEqual("Graph 1 Updated"); + }); + + test("Update the desription graph 2", async () => { + const response = await client.graphs.update({ + id: graph2Id, + description: "Graph 2 Updated", + }); + expect(response.results).toBeDefined(); + expect(response.results.description).toEqual("Graph 2 Updated"); + }); + + test("Create an entity and add it to graph 1", async () => { + const createResponse = await client.entities.create({ + name: "Entity 1", + description: "Entity 1 Description", + }); + entity1Id = createResponse.results.id; + expect(createResponse.results).toBeDefined(); + expect(createResponse.results.name).toEqual("Entity 1"); + + const addResponse = await client.graphs.addEntity({ + id: graph1Id, + entityId: createResponse.results.id, + }); + expect(addResponse.results).toBeDefined(); + }); + + test("Remove entity from graph 1", async () => { + const response = await client.graphs.removeEntity({ + id: graph1Id, + entityId: entity1Id, + }); + expect(response.results).toBeDefined(); + }); + + test("Delete entity from graph 1", async () => { + const response = await client.entities.delete({ id: entity1Id }); + expect(response.results).toBeDefined(); + }); + + test("Delete graph 1", async () => { + const response = await client.graphs.delete({ id: graph1Id }); + expect(response.results).toBeDefined(); + expect(response.results.success).toBe(true); + }); + + test("Delete graph 2", async () => { + const response = await client.graphs.delete({ id: graph2Id }); + expect(response.results).toBeDefined(); + expect(response.results.success).toBe(true); + }); +}); diff --git a/js/sdk/__tests__/RetrievalIntegrationSuperUser.test.ts b/js/sdk/__tests__/RetrievalIntegrationSuperUser.test.ts index 10e3f6465..89451941e 100644 --- a/js/sdk/__tests__/RetrievalIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/RetrievalIntegrationSuperUser.test.ts @@ -3,16 +3,10 @@ import { describe, test, beforeAll, expect } from "@jest/globals"; const baseUrl = "http://localhost:7272"; -const messages = [ - { - role: "system" as const, - content: "You are a helpful assistant.", - }, - { - role: "user" as const, - content: "Tell me about Sonia.", - }, -]; +const message = { + role: "user" as const, + content: "Tell me about Sonia.", +}; /** * sonia.txt will have an id of 28ce9a4c-4d15-5287-b0c6-67834b9c4546 @@ -85,15 +79,15 @@ describe("r2rClient V3 Documents Integration Tests", () => { test("Agent with no parameters", async () => { const response = await client.retrieval.agent({ - messages: messages, + message: message, }); expect(response.results).toBeDefined(); }, 30000); - test("Streaming RAG", async () => { + test("Streaming agent", async () => { const stream = await client.retrieval.agent({ - messages: messages, + message: message, ragGenerationConfig: { stream: true, }, diff --git a/js/sdk/__tests__/r2rV2ClientIntegrationSuperUser.test.ts b/js/sdk/__tests__/r2rV2ClientIntegrationSuperUser.test.ts index 2e5cfdbde..ddeb9a365 100644 --- a/js/sdk/__tests__/r2rV2ClientIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/r2rV2ClientIntegrationSuperUser.test.ts @@ -129,7 +129,7 @@ describe("r2rClient Integration Tests", () => { metadatas: [{ title: "raskolnikov.txt" }, { title: "karamozov.txt" }], }), ).resolves.not.toThrow(); - }); + }, 10000); test("Ingest files in folder", async () => { const files = ["examples/data/folder"]; @@ -153,7 +153,7 @@ describe("r2rClient Integration Tests", () => { await expect( client.ingestChunks([{ text: "test chunks" }]), ).resolves.not.toThrow(); - }); + }, 10000); test("Ingest chunks", async () => { await expect( @@ -163,7 +163,7 @@ describe("r2rClient Integration Tests", () => { { source: "example" }, ), ).resolves.not.toThrow(); - }); + }, 10000); test("Search documents", async () => { await expect(client.search("test")).resolves.not.toThrow(); diff --git a/js/sdk/examples/data/raskolnikov_2.txt b/js/sdk/examples/data/raskolnikov_2.txt new file mode 100644 index 000000000..895e99965 --- /dev/null +++ b/js/sdk/examples/data/raskolnikov_2.txt @@ -0,0 +1,7 @@ +When Raskolnikov got home, his hair was soaked with sweat and he was +breathing heavily. He went rapidly up the stairs, walked into his +unlocked room and at once fastened the latch. Then in senseless terror +he rushed to the corner, to that hole under the paper where he had put +the things; put his hand in, and for some minutes felt carefully in the +hole, in every crack and fold of the paper. Finding nothing, he got up +and drew a deep breath. diff --git a/js/sdk/src/types.ts b/js/sdk/src/types.ts index 75c2707b3..0c72d2ff3 100644 --- a/js/sdk/src/types.ts +++ b/js/sdk/src/types.ts @@ -47,6 +47,20 @@ export interface CollectionResponse { document_count: number; } +// Community types +export interface CommunityResponse { + id: string; + name: string; + summary: string; + findings: string[]; + communityId?: string; + graphId?: string; + collectionId?: string; + rating?: number; + ratingExplanation?: string; + descriptionEmbedding?: string; +} + // Conversation types export interface ConversationResponse { id: string; @@ -82,13 +96,32 @@ export interface DocumentResponse { size_in_bytes?: number; ingestion_status: string; kg_extraction_status: string; - created_date: string; - updated_date: string; + created_at: string; + updated_at: string; ingestion_attempt_number?: number; summary?: string; summary_embedding?: string; } +// Entity types +export interface EntityResponse { + id: string; + sid?: string; + name: string; + category?: string; + description?: string; + chunk_ids: string[]; + description_embedding?: string; + document_id: string; + document_ids: string[]; + graph_ids: string[]; + user_id: string; + last_modified_by: string; + created_at: string; + updated_at: string; + attributes?: Record; +} + // Graph types export interface GraphResponse { id: string; @@ -134,6 +167,21 @@ export interface PromptResponse { input_types: string[]; } +// Relationship types +export interface RelationshipResponse { + id: string; + subject: string; + predicate: string; + object: string; + description?: string; + subject_id: string; + object_id: string; + weight: number; + chunk_ids: string[]; + parent_id: string; + metadata: Record; +} + // Retrieval types export interface VectorSearchResult { chunk_id: string; @@ -218,6 +266,12 @@ export type WrappedCollectionsResponse = PaginatedResultsWrapper< CollectionResponse[] >; +// Community Responses +export type WrappedCommunityResponse = ResultsWrapper; +export type WrappedCommunitiesResponse = PaginatedResultsWrapper< + CommunityResponse[] +>; + // Conversation Responses export type WrappedConversationMessagesResponse = ResultsWrapper< MessageResponse[] @@ -240,6 +294,10 @@ export type WrappedDocumentsResponse = PaginatedResultsWrapper< DocumentResponse[] >; +// Entity Responses +export type WrappedEntityResponse = ResultsWrapper; +export type WrappedEntitiesResponse = PaginatedResultsWrapper; + // Graph Responses export type WrappedGraphResponse = ResultsWrapper; export type WrappedGraphsResponse = PaginatedResultsWrapper; @@ -254,6 +312,12 @@ export type WrappedListVectorIndicesResponse = ResultsWrapper; export type WrappedPromptResponse = ResultsWrapper; export type WrappedPromptsResponse = PaginatedResultsWrapper; +// Relationship Responses +export type WrappedRelationshipResponse = ResultsWrapper; +export type WrappedRelationshipsResponse = PaginatedResultsWrapper< + RelationshipResponse[] +>; + // Retrieval Responses export type WrappedVectorSearchResponse = ResultsWrapper; export type WrappedSearchResponse = ResultsWrapper; diff --git a/js/sdk/src/v3/clients/chunks.ts b/js/sdk/src/v3/clients/chunks.ts index 86593be7e..a2ce30ada 100644 --- a/js/sdk/src/v3/clients/chunks.ts +++ b/js/sdk/src/v3/clients/chunks.ts @@ -1,3 +1,4 @@ +import { feature } from "../../feature"; import { r2rClient } from "../../r2rClient"; import { UnprocessedChunk, @@ -20,6 +21,7 @@ export class ChunksClient { * @param runWithOrchestration Optional flag to run with orchestration * @returns */ + @feature("chunks.create") async create(options: { chunks: UnprocessedChunk[]; runWithOrchestration?: boolean; @@ -39,6 +41,7 @@ export class ChunksClient { * @param metadata Optional new metadata for the chunk * @returns */ + @feature("chunks.update") async update(options: { id: string; text?: string; @@ -54,6 +57,7 @@ export class ChunksClient { * @param id ID of the chunk to retrieve * @returns */ + @feature("chunks.retrieve") async retrieve(options: { id: string }): Promise { return this.client.makeRequest("GET", `chunks/${options.id}`); } @@ -63,6 +67,7 @@ export class ChunksClient { * @param id ID of the chunk to delete * @returns */ + @feature("chunks.delete") async delete(options: { id: string }): Promise { return this.client.makeRequest("DELETE", `chunks/${options.id}`); } @@ -75,6 +80,7 @@ export class ChunksClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("chunks.list") async list(options?: { includeVectors?: boolean; metadataFilters?: Record; diff --git a/js/sdk/src/v3/clients/collections.ts b/js/sdk/src/v3/clients/collections.ts index a6fe32ddb..c77b1eda5 100644 --- a/js/sdk/src/v3/clients/collections.ts +++ b/js/sdk/src/v3/clients/collections.ts @@ -1,3 +1,4 @@ +import { feature } from "../../feature"; import { r2rClient } from "../../r2rClient"; import { WrappedBooleanResponse, @@ -17,6 +18,7 @@ export class CollectionsClient { * @param description Optional description of the collection * @returns A promise that resolves with the created collection */ + @feature("collections.create") async create(options: { name: string; description?: string; @@ -33,6 +35,7 @@ export class CollectionsClient { * @param limit Optional limit for pagination * @returns */ + @feature("collections.list") async list(options?: { ids?: string[]; offset?: number; @@ -57,6 +60,7 @@ export class CollectionsClient { * @param id Collection ID to retrieve * @returns */ + @feature("collections.retrieve") async retrieve(options: { id: string }): Promise { return this.client.makeRequest("GET", `collections/${options.id}`); } @@ -68,6 +72,7 @@ export class CollectionsClient { * @param description Optional new description for the collection * @returns */ + @feature("collections.update") async update(options: { id: string; name?: string; @@ -88,6 +93,7 @@ export class CollectionsClient { * @param id Collection ID to delete * @returns */ + @feature("collections.delete") async delete(options: { id: string }): Promise { return this.client.makeRequest("DELETE", `collections/${options.id}`); } @@ -99,6 +105,7 @@ export class CollectionsClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("collections.listDocuments") async listDocuments(options: { id: string; offset?: number; @@ -124,6 +131,7 @@ export class CollectionsClient { * @param documentId Document ID to add * @returns */ + @feature("collections.addDocument") async addDocument(options: { id: string; documentId: string; @@ -140,6 +148,7 @@ export class CollectionsClient { * @param documentId Document ID to remove * @returns */ + @feature("collections.removeDocument") async removeDocument(options: { id: string; documentId: string; @@ -157,6 +166,7 @@ export class CollectionsClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("collections.listUsers") async listUsers(options: { id: string; offset?: number; @@ -178,6 +188,7 @@ export class CollectionsClient { * @param userId User ID to add * @returns */ + @feature("collections.addUser") async addUser(options: { id: string; userId: string; @@ -194,6 +205,7 @@ export class CollectionsClient { * @param userId User ID to remove * @returns */ + @feature("collections.removeUser") async removeUser(options: { id: string; userId: string; @@ -203,4 +215,49 @@ export class CollectionsClient { `collections/${options.id}/users/${options.userId}`, ); } + + /** + * Creates communities in the graph by analyzing entity relationships and similarities. + * + * Communities are created through the following process: + * 1. Analyzes entity relationships and metadata to build a similarity graph + * 2. Applies advanced community detection algorithms (e.g. Leiden) to identify densely connected groups + * 3. Creates hierarchical community structure with multiple granularity levels + * 4. Generates natural language summaries and statistical insights for each community + * + * The resulting communities can be used to: + * - Understand high-level graph structure and organization + * - Identify key entity groupings and their relationships + * - Navigate and explore the graph at different levels of detail + * - Generate insights about entity clusters and their characteristics + * + * The community detection process is configurable through settings like: + * - Community detection algorithm parameters + * - Summary generation prompt + * @param collectionId The collection ID corresponding to the graph + * @returns + */ + @feature("collections.extract") + async extract(options: { + collectionId: string; + runType?: string; + settings?: Record; + runWithOrchestration?: boolean; + }): Promise { + const data = { + ...(options.settings && { settings: options.settings }), + ...(options.runType && { run_type: options.runType }), + ...(options.runWithOrchestration && { + run_with_orchestration: options.runWithOrchestration, + }), + }; + + return this.client.makeRequest( + "POST", + `collections/${options.collectionId}/extract`, + { + data, + }, + ); + } } diff --git a/js/sdk/src/v3/clients/conversations.ts b/js/sdk/src/v3/clients/conversations.ts index 425f6dab6..796c1edc7 100644 --- a/js/sdk/src/v3/clients/conversations.ts +++ b/js/sdk/src/v3/clients/conversations.ts @@ -1,3 +1,4 @@ +import { feature } from "../../feature"; import { r2rClient } from "../../r2rClient"; import { WrappedBooleanResponse, @@ -15,6 +16,7 @@ export class ConversationsClient { * Create a new conversation. * @returns */ + @feature("conversations.create") async create(): Promise { return this.client.makeRequest("POST", "conversations"); } @@ -26,6 +28,7 @@ export class ConversationsClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("conversations.list") async list(options?: { ids?: string[]; offset?: number; @@ -51,6 +54,7 @@ export class ConversationsClient { * @param branchID The ID of the branch to retrieve * @returns */ + @feature("conversations.retrieve") async retrieve(options: { id: string; branchID?: string; @@ -69,6 +73,7 @@ export class ConversationsClient { * @param id The ID of the conversation to delete * @returns */ + @feature("conversations.delete") async delete(options: { id: string }): Promise { return this.client.makeRequest("DELETE", `conversations/${options.id}`); } @@ -82,6 +87,7 @@ export class ConversationsClient { * @param metadata Additional metadata to attach to the message * @returns */ + @feature("conversations.addMessage") async addMessage(options: { id: string; content: string; @@ -112,6 +118,7 @@ export class ConversationsClient { * @param content The new content of the message * @returns */ + @feature("conversations.updateMessage") async updateMessage(options: { id: string; messageID: string; @@ -135,6 +142,7 @@ export class ConversationsClient { * @param id The ID of the conversation to list branches for * @returns */ + @feature("conversations.listBranches") async listBranches(options: { id: string; offset?: number; diff --git a/js/sdk/src/v3/clients/documents.ts b/js/sdk/src/v3/clients/documents.ts index 8b328b8ca..bc720b4c3 100644 --- a/js/sdk/src/v3/clients/documents.ts +++ b/js/sdk/src/v3/clients/documents.ts @@ -6,8 +6,11 @@ import { WrappedCollectionsResponse, WrappedDocumentResponse, WrappedDocumentsResponse, + WrappedEntitiesResponse, WrappedIngestionResponse, + WrappedRelationshipsResponse, } from "../../types"; +import { feature } from "../../feature"; let fs: any; if (typeof window === "undefined") { @@ -32,6 +35,7 @@ export class DocumentsClient { * @param runWithOrchestration Optional flag to run with orchestration * @returns */ + @feature("documents.create") async create(options: { file?: FileInput; content?: string; @@ -144,6 +148,7 @@ export class DocumentsClient { * @param runWithOrchestration Whether to run with orchestration * @returns */ + @feature("documents.update") async update(options: { id: string; file?: FileInput; @@ -236,6 +241,7 @@ export class DocumentsClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("documents.retrieve") async retrieve(options: { id: string }): Promise { return this.client.makeRequest("GET", `documents/${options.id}`); } @@ -247,6 +253,7 @@ export class DocumentsClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("documents.list") async list(options?: { ids?: string[]; offset?: number; @@ -271,6 +278,7 @@ export class DocumentsClient { * @param id ID of document to download * @returns */ + @feature("documents.download") async download(options: { id: string }): Promise { return this.client.makeRequest("GET", `documents/${options.id}/download`, { responseType: "blob", @@ -282,6 +290,7 @@ export class DocumentsClient { * @param id ID of document to delete * @returns */ + @feature("documents.delete") async delete(options: { id: string }): Promise { return this.client.makeRequest("DELETE", `documents/${options.id}`); } @@ -294,6 +303,7 @@ export class DocumentsClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("documents.listChunks") async listChunks(options: { id: string; includeVectors?: boolean; @@ -318,6 +328,7 @@ export class DocumentsClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("documents.listCollections") async listCollections(options: { id: string; offset?: number; @@ -337,6 +348,7 @@ export class DocumentsClient { ); } + @feature("documents.deleteByFilter") async deleteByFilter(options: { filters: Record; }): Promise { @@ -344,4 +356,109 @@ export class DocumentsClient { data: options.filters, }); } + + /** + * Extracts entities and relationships from a document. + * + * The entities and relationships extraction process involves: + * 1. Parsing documents into semantic chunks + * 2. Extracting entities and relationships using LLMs + * @param options + * @returns + */ + @feature("documents.extract") + async extract(options: { + id: string; + runType?: string; + runWithOrchestration?: boolean; + }): Promise { + const data: Record = {}; + + if (options.runType) { + data.runType = options.runType; + } + if (options.runWithOrchestration !== undefined) { + data.runWithOrchestration = options.runWithOrchestration; + } + + return this.client.makeRequest("POST", `documents/${options.id}/extract`, { + data, + }); + } + + /** + * Retrieves the entities that were extracted from a document. These + * represent important semantic elements like people, places, + * organizations, concepts, etc. + * + * Users can only access entities from documents they own or have access + * to through collections. Entity embeddings are only included if + * specifically requested. + * + * Results are returned in the order they were extracted from the document. + * @param id Document ID to retrieve entities for + * @param offset Specifies the number of objects to skip. Defaults to 0. + * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. + * @param includeEmbeddings Whether to include vector embeddings in the response. + * @returns + */ + @feature("documents.listEntities") + async listEntities(options: { + id: string; + offset?: number; + limit?: number; + includeVectors?: boolean; + }): Promise { + const params: Record = { + offset: options.offset ?? 0, + limit: options.limit ?? 100, + includeVectors: options.includeVectors ?? false, + }; + + return this.client.makeRequest("GET", `documents/${options.id}/entities`, { + params, + }); + } + + /** + * Retrieves the relationships between entities that were extracted from + * a document. These represent connections and interactions between + * entities found in the text. + * + * Users can only access relationships from documents they own or have + * access to through collections. Results can be filtered by entity names + * and relationship types. + * + * Results are returned in the order they were extracted from the document. + * @param id Document ID to retrieve relationships for + * @param offset Specifies the number of objects to skip. Defaults to 0. + * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. + * @param includeEmbeddings Whether to include vector embeddings in the response. + * @param entityNames Filter relationships by specific entity names. + * @param relationshipTypes Filter relationships by specific relationship types. + * @returns WrappedRelationshipsResponse + */ + @feature("documents.listRelationships") + async listRelationships(options: { + id: string; + offset?: number; + limit?: number; + includeVectors?: boolean; + entityNames?: string[]; + relationshipTypes?: string[]; + }): Promise { + const params: Record = { + offset: options.offset ?? 0, + limit: options.limit ?? 100, + includeVectors: options.includeVectors ?? false, + }; + + return this.client.makeRequest( + "GET", + `documents/${options.id}/relationships`, + { + params, + }, + ); + } } diff --git a/js/sdk/src/v3/clients/graphs.ts b/js/sdk/src/v3/clients/graphs.ts index 2daeba3bb..351b0b2e3 100644 --- a/js/sdk/src/v3/clients/graphs.ts +++ b/js/sdk/src/v3/clients/graphs.ts @@ -1,37 +1,30 @@ +import { feature } from "../../feature"; import { r2rClient } from "../../r2rClient"; import { WrappedGraphResponse, WrappedBooleanResponse, WrappedGraphsResponse, + WrappedEntityResponse, + WrappedEntitiesResponse, + WrappedRelationshipsResponse, + WrappedRelationshipResponse, + WrappedCommunitiesResponse, + WrappedCommunityResponse, } from "../../types"; export class GraphsClient { constructor(private client: r2rClient) {} - /** - * Create a new graph. - * @param name Name of the graph - * @param description Optional description of the graph - * @returns The created graph - */ - async create(options: { - name: string; - description?: string; - }): Promise { - return this.client.makeRequest("POST", "graphs", { - data: options, - }); - } - /** * List graphs with pagination and filtering options. - * @param ids Optional list of graph IDs to filter by + * @param collectionIds Optional list of collection IDs to filter by * @param offset Optional offset for pagination * @param limit Optional limit for pagination * @returns */ + @feature("graphs.list") async list(options?: { - ids?: string[]; + collectionIds?: string[]; offset?: number; limit?: number; }): Promise { @@ -40,8 +33,8 @@ export class GraphsClient { limit: options?.limit ?? 100, }; - if (options?.ids && options.ids.length > 0) { - params.ids = options.ids; + if (options?.collectionIds && options.collectionIds.length > 0) { + params.collectionIds = options.collectionIds; } return this.client.makeRequest("GET", "graphs", { @@ -49,19 +42,48 @@ export class GraphsClient { }); } - async retrieve(options: { id: string }): Promise { - return this.client.makeRequest("GET", `graphs/${options.id}`); + /** + * Get detailed information about a specific graph. + * @param collectionId The collection ID corresponding to the graph + * @returns + */ + @feature("graphs.retrieve") + async retrieve(options: { + collectionId: string; + }): Promise { + return this.client.makeRequest("GET", `graphs/${options.collectionId}`); + } + + /** + * Deletes a graph and all its associated data. + * + * This endpoint permanently removes the specified graph along with all + * entities and relationships that belong to only this graph. + * + * Entities and relationships extracted from documents are not deleted. + * @param collectionId The collection ID corresponding to the graph + * @returns + */ + @feature("graphs.reset") + async reset(options: { + collectionId: string; + }): Promise { + return this.client.makeRequest( + "POST", + `graphs/${options.collectionId}/reset`, + ); } /** * Update an existing graph. - * @param id Graph ID to update + * @param collectionId The collection ID corresponding to the graph * @param name Optional new name for the graph * @param description Optional new description for the graph * @returns */ + @feature("graphs.update") async update(options: { - id: string; + collectionId: string; name?: string; description?: string; }): Promise { @@ -70,17 +92,511 @@ export class GraphsClient { ...(options.description && { description: options.description }), }; - return this.client.makeRequest("POST", `graphs/${options.id}`, { + return this.client.makeRequest("POST", `graphs/${options.collectionId}`, { data, }); } /** + * Creates a new entity in the graph. + * @param collectionId The collection ID corresponding to the graph + * @param entity Entity to add + * @returns + */ + @feature("graphs.createEntity") + async createEntity(options: { + collectionId: string; + name: string; + description?: string; + category?: string; + metadata?: Record; + }): Promise { + const data = { + name: options.name, + ...(options.description && { description: options.description }), + ...(options.category && { category: options.category }), + ...(options.metadata && { metadata: options.metadata }), + }; + + return this.client.makeRequest( + "POST", + `graphs/${options.collectionId}/entities`, + { + data, + }, + ); + } + + /** + * List all entities in a graph. + * @param collectionId The collection ID corresponding to the graph + * @param offset Specifies the number of objects to skip. Defaults to 0. + * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. + * @returns + */ + @feature("graphs.listEntities") + async listEntities(options: { + collectionId: string; + offset?: number; + limit?: number; + }): Promise { + const params: Record = { + offset: options?.offset ?? 0, + limit: options?.limit ?? 100, + }; + + return this.client.makeRequest( + "GET", + `graphs/${options.collectionId}/entities`, + { + params, + }, + ); + } + + /** + * Retrieve an entity from a graph. + * @param collectionId The collection ID corresponding to the graph + * @param entityId Entity ID to retrieve + * @returns + */ + @feature("graphs.getEntity") + async getEntity(options: { + collectionId: string; + entityId: string; + }): Promise { + return this.client.makeRequest( + "GET", + `graphs/${options.collectionId}/entities/${options.entityId}`, + ); + } + + /** + * Updates an existing entity in the graph. + * @param collectionId The collection ID corresponding to the graph + * @param entityId Entity ID to update + * @param entity Entity to update + * @returns + */ + @feature("graphs.updateEntity") + async updateEntity(options: { + collectionId: string; + entityId: string; + name?: string; + description?: string; + category?: string; + metadata?: Record; + }): Promise { + const data = { + ...(options.name && { name: options.name }), + ...(options.description && { description: options.description }), + ...(options.category && { category: options.category }), + ...(options.metadata && { metadata: options.metadata }), + }; + + return this.client.makeRequest( + "POST", + `graphs/${options.collectionId}/entities/${options.entityId}`, + { + data, + }, + ); + } + + /** + * Remove an entity from a graph. + * @param collectionId The collection ID corresponding to the graph + * @param entityId Entity ID to remove + * @returns + */ + @feature("graphs.removeEntity") + async removeEntity(options: { + collectionId: string; + entityId: string; + }): Promise { + return this.client.makeRequest( + "DELETE", + `graphs/${options.collectionId}/entities/${options.entityId}`, + ); + } + /** + * Creates a new relationship in the graph. + * @param collectionId The collection ID corresponding to the graph + * @param relationship Relationship to add + * @returns + */ + @feature("graphs.createRelationship") + async createRelationship(options: { + collectionId: string; + subject: string; + subjectId: string; + predicate: string; + object: string; + objectId: string; + description?: string; + weight?: number; + metadata?: Record; + }): Promise { + const data = { + subject: options.subject, + subject_id: options.subjectId, + predicate: options.predicate, + object: options.object, + object_id: options.objectId, + ...(options.description && { description: options.description }), + ...(options.weight && { weight: options.weight }), + ...(options.metadata && { metadata: options.metadata }), + }; + + return this.client.makeRequest( + "POST", + `graphs/${options.collectionId}/relationships`, + { + data, + }, + ); + } + + /** + * List all relationships in a graph. + * @param collectionId The collection ID corresponding to the graph + * @param offset Specifies the number of objects to skip. Defaults to 0. + * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. + * @returns + */ + @feature("graphs.listRelationships") + async listRelationships(options: { + collectionId: string; + offset?: number; + limit?: number; + }): Promise { + const params: Record = { + offset: options?.offset ?? 0, + limit: options?.limit ?? 100, + }; + + return this.client.makeRequest( + "GET", + `graphs/${options.collectionId}/relationships`, + { + params, + }, + ); + } + + /** + * Retrieve a relationship from a graph. + * @param collectionId The collection ID corresponding to the graph + * @param relationshipId Relationship ID to retrieve + * @returns + */ + @feature("graphs.getRelationship") + async getRelationship(options: { + collectionId: string; + relationshipId: string; + }): Promise { + return this.client.makeRequest( + "GET", + `graphs/${options.collectionId}/relationships/${options.relationshipId}`, + ); + } + + /** + * Updates an existing relationship in the graph. + * @param collectionId The collection ID corresponding to the graph + * @param relationshipId Relationship ID to update + * @param relationship Relationship to update + * @returns WrappedRelationshipResponse + */ + @feature("graphs.updateRelationship") + async updateRelationship(options: { + collectionId: string; + relationshipId: string; + subject?: string; + subjectId?: string; + predicate?: string; + object?: string; + objectId?: string; + description?: string; + weight?: number; + metadata?: Record; + }): Promise { + const data = { + ...(options.subject && { subject: options.subject }), + ...(options.subjectId && { subject_id: options.subjectId }), + ...(options.predicate && { predicate: options.predicate }), + ...(options.object && { object: options.object }), + ...(options.objectId && { object_id: options.objectId }), + ...(options.description && { description: options.description }), + ...(options.weight && { weight: options.weight }), + ...(options.metadata && { metadata: options.metadata }), + }; + + return this.client.makeRequest( + "POST", + `graphs/${options.collectionId}/relationships/${options.relationshipId}`, + { + data, + }, + ); + } + + /** + * Remove a relationship from a graph. + * @param collectionId The collection ID corresponding to the graph + * @param relationshipId Entity ID to remove + * @returns + */ + @feature("graphs.removeRelationship") + async removeRelationship(options: { + collectionId: string; + relationshipId: string; + }): Promise { + return this.client.makeRequest( + "DELETE", + `graphs/${options.collectionId}/relationships/${options.relationshipId}`, + ); + } + + /** + * Creates a new community in the graph. + * + * While communities are typically built automatically via the /graphs/{id}/communities/build endpoint, + * this endpoint allows you to manually create your own communities. + * + * This can be useful when you want to: + * - Define custom groupings of entities based on domain knowledge + * - Add communities that weren't detected by the automatic process + * - Create hierarchical organization structures + * - Tag groups of entities with specific metadata + * + * The created communities will be integrated with any existing automatically detected communities + * in the graph's community structure. + * + * @param collectionId The collection ID corresponding to the graph + * @param name Name of the community + * @param summary Summary of the community + * @param findings Findings or insights about the community + * @param rating Rating of the community + * @param ratingExplanation Explanation of the community rating + * @param attributes Additional attributes to associate with the community + * @returns WrappedCommunityResponse + */ + @feature("graphs.createCommunity") + async createCommunity(options: { + collectionId: string; + name: string; + summary: string; + findings?: string[]; + rating?: number; + ratingExplanation?: string; + attributes?: Record; + }): Promise { + const data = { + name: options.name, + ...(options.summary && { summary: options.summary }), + ...(options.findings && { findings: options.findings }), + ...(options.rating && { rating: options.rating }), + ...(options.ratingExplanation && { + rating_explanation: options.ratingExplanation, + }), + ...(options.attributes && { attributes: options.attributes }), + }; + + return this.client.makeRequest( + "POST", + `graphs/${options.collectionId}/communities`, + { + data, + }, + ); + } + + /** + * List all communities in a graph. + * @param collectionId The collection ID corresponding to the graph + * @param offset Specifies the number of objects to skip. Defaults to 0. + * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. + * @returns + */ + @feature("graphs.listCommunities") + async listCommunities(options: { + collectionId: string; + offset?: number; + limit?: number; + }): Promise { + const params: Record = { + offset: options?.offset ?? 0, + limit: options?.limit ?? 100, + }; + + return this.client.makeRequest( + "GET", + `graphs/${options.collectionId}/communities`, + { + params, + }, + ); + } + + /** + * Retrieve a community from a graph. + * @param collectionId The collection ID corresponding to the graph + * @param communityId Entity ID to retrieve + * @returns + */ + @feature("graphs.getCommunity") + async getCommunity(options: { + collectionId: string; + communityId: string; + }): Promise { + return this.client.makeRequest( + "GET", + `graphs/${options.collectionId}/communities/${options.communityId}`, + ); + } + + /** + * Updates an existing community in the graph. + * @param collectionId The collection ID corresponding to the graph + * @param communityId Community ID to update + * @param entity Entity to update + * @returns WrappedCommunityResponse + */ + @feature("graphs.updateCommunity") + async updateCommunity(options: { + collectionId: string; + communityId: string; + name?: string; + summary?: string; + findings?: string[]; + rating?: number; + ratingExplanation?: string; + attributes?: Record; + }): Promise { + const data = { + ...(options.name && { name: options.name }), + ...(options.summary && { summary: options.summary }), + ...(options.findings && { findings: options.findings }), + ...(options.rating && { rating: options.rating }), + ...(options.ratingExplanation && { + rating_explanation: options.ratingExplanation, + }), + ...(options.attributes && { attributes: options.attributes }), + }; + return this.client.makeRequest( + "POST", + `graphs/${options.collectionId}/communities/${options.communityId}`, + { + data, + }, + ); + } + + /** + * Delete a community in a graph. + * @param collectionId The collection ID corresponding to the graph + * @param communityId Community ID to delete + * @returns + */ + @feature("graphs.deleteCommunity") + async deleteCommunity(options: { + collectionId: string; + communityId: string; + }): Promise { + return this.client.makeRequest( + "DELETE", + `graphs/${options.collectionId}/communities/${options.communityId}`, + ); + } + + /** + * Adds documents to a graph by copying their entities and relationships. + * + * This endpoint: + * 1. Copies document entities to the graph_entity table + * 2. Copies document relationships to the graph_relationship table + * 3. Associates the documents with the graph + * + * When a document is added: + * - Its entities and relationships are copied to graph-specific tables + * - Existing entities/relationships are updated by merging their properties + * - The document ID is recorded in the graph's document_ids array + * + * Documents added to a graph will contribute their knowledge to: + * - Graph analysis and querying + * - Community detection + * - Knowledge graph enrichment + * + * The user must have access to both the graph and the documents being added. + * @param collectionId The collection ID corresponding to the graph + * @returns + */ + @feature("graphs.pull") + async pull(options: { + collectionId: string; + }): Promise { + return this.client.makeRequest( + "POST", + `graphs/${options.collectionId}/pull`, + ); + } + + /** + * Removes a document from a graph and removes any associated entities + * + * This endpoint: + * 1. Removes the document ID from the graph's document_ids array + * 2. Optionally deletes the document's copied entities and relationships + * + * The user must have access to both the graph and the document being removed. + * @param collectionId The collection ID corresponding to the graph + * @param documentId The document ID to remove + * @returns + */ + @feature("graphs.removeDocument") + async removeDocument(options: { + collectionId: string; + documentId: string; + }): Promise { + return this.client.makeRequest( + "DELETE", + `graphs/${options.collectionId}/documents/${options.documentId}`, + ); + } + + /** + * Creates communities in the graph by analyzing entity relationships and similarities. + * + * Communities are created through the following process: + * 1. Analyzes entity relationships and metadata to build a similarity graph + * 2. Applies advanced community detection algorithms (e.g. Leiden) to identify densely connected groups + * 3. Creates hierarchical community structure with multiple granularity levels + * 4. Generates natural language summaries and statistical insights for each community + * + * The resulting communities can be used to: + * - Understand high-level graph structure and organization + * - Identify key entity groupings and their relationships + * - Navigate and explore the graph at different levels of detail + * - Generate insights about entity clusters and their characteristics + * + * The community detection process is configurable through settings like: + * - Community detection algorithm parameters + * - Summary generation prompt * * @param options * @returns */ - async delete(options: { id: string }): Promise { - return this.client.makeRequest("DELETE", `graphs/${options.id}`); + @feature("graphs.buildCommunities") + async buildCommunities(options: { + collectionId: string; + runType?: string; + kgEntichmentSettings?: Record; + runWithOrchestration?: boolean; + }): Promise { + return this.client.makeRequest( + "POST", + `graphs/${options.collectionId}/communities/build`, + ); } } diff --git a/js/sdk/src/v3/clients/indices.ts b/js/sdk/src/v3/clients/indices.ts index 384093a52..cae88989b 100644 --- a/js/sdk/src/v3/clients/indices.ts +++ b/js/sdk/src/v3/clients/indices.ts @@ -1,3 +1,4 @@ +import { feature } from "../../feature"; import { r2rClient } from "../../r2rClient"; import { IndexConfig, @@ -14,6 +15,7 @@ export class IndiciesClient { * @param runWithOrchestration Whether to run index creation as an orchestrated task. * @returns */ + @feature("indices.create") async create(options: { config: IndexConfig; runWithOrchestration?: boolean; @@ -37,6 +39,7 @@ export class IndiciesClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("indices.list") async list(options?: { filters?: Record; offset?: number; @@ -62,6 +65,7 @@ export class IndiciesClient { * @param tableName The name of the table where the index is stored. * @returns */ + @feature("indices.retrieve") async retrieve(options: { tableName: string; indexName: string; @@ -78,6 +82,7 @@ export class IndiciesClient { * @param tableName The name of the table where the index is stored. * @returns */ + @feature("indices.delete") async delete(options: { tableName: string; indexName: string; diff --git a/js/sdk/src/v3/clients/prompts.ts b/js/sdk/src/v3/clients/prompts.ts index 1cdf804bf..d247251f9 100644 --- a/js/sdk/src/v3/clients/prompts.ts +++ b/js/sdk/src/v3/clients/prompts.ts @@ -1,3 +1,4 @@ +import { feature } from "../../feature"; import { r2rClient } from "../../r2rClient"; import { WrappedBooleanResponse, @@ -9,6 +10,17 @@ import { export class PromptsClient { constructor(private client: r2rClient) {} + /** + * Create a new prompt with the given configuration. + * + * This endpoint allows superusers to create a new prompt with a + * specified name, template, and input types. + * @param name The name of the prompt + * @param template The template string for the prompt + * @param inputTypes A dictionary mapping input names to their types + * @returns + */ + @feature("prompts.create") async create(options: { name: string; template: string; @@ -19,10 +31,28 @@ export class PromptsClient { }); } + /** + * List all available prompts. + * + * This endpoint retrieves a list of all prompts in the system. + * Only superusers can access this endpoint. + * @returns + */ + @feature("prompts.list") async list(): Promise { return this.client.makeRequest("GET", "prompts"); } + /** + * Get a specific prompt by name, optionally with inputs and override. + * + * This endpoint retrieves a specific prompt and allows for optional + * inputs and template override. + * Only superusers can access this endpoint. + * @param options + * @returns + */ + @feature("prompts.retrieve") async retrieve(options: { name: string; inputs?: string[]; @@ -41,6 +71,14 @@ export class PromptsClient { }); } + /** + * Update an existing prompt's template and/or input types. + * + * This endpoint allows superusers to update the template and input types of an existing prompt. + * @param options + * @returns + */ + @feature("prompts.update") async update(options: { name: string; template?: string; @@ -61,6 +99,13 @@ export class PromptsClient { }); } + /** + * Delete a prompt by name. + * + * This endpoint allows superusers to delete an existing prompt. + * @param name The name of the prompt to delete + * @returns + */ async delete(options: { name: string }): Promise { return this.client.makeRequest("DELETE", `prompts/${options.name}`); } diff --git a/js/sdk/src/v3/clients/retrieval.ts b/js/sdk/src/v3/clients/retrieval.ts index 914ee2d7e..08b9f5028 100644 --- a/js/sdk/src/v3/clients/retrieval.ts +++ b/js/sdk/src/v3/clients/retrieval.ts @@ -6,10 +6,27 @@ import { KGSearchSettings, GenerationConfig, } from "../../models"; +import { feature } from "../../feature"; export class RetrievalClient { constructor(private client: r2rClient) {} + /** + * Perform a search query on the vector database and knowledge graph and + * any other configured search engines. + * + * This endpoint allows for complex filtering of search results using + * PostgreSQL-based queries. Filters can be applied to various fields + * such as document_id, and internal metadata values. + * + * Allowed operators include: `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, + * `like`, `ilike`, `in`, and `nin`. + * @param query Search query to find relevant documents + * @param VectorSearchSettings Settings for vector-based search + * @param KGSearchSettings Settings for knowledge graph search + * @returns + */ + @feature("retrieval.search") async search(options: { query: string; vectorSearchSettings?: ChunkSearchSettings | Record; @@ -30,6 +47,23 @@ export class RetrievalClient { }); } + /** + * Execute a RAG (Retrieval-Augmented Generation) query. + * + * This endpoint combines search results with language model generation. + * It supports the same filtering capabilities as the search endpoint, + * allowing for precise control over the retrieved context. + * + * The generation process can be customized using the `rag_generation_config` parameter. + * @param query + * @param ragGenerationConfig Configuration for RAG generation + * @param vectorSearchSettings Settings for vector-based search + * @param kgSearchSettings Settings for knowledge graph search + * @param taskPromptOverride Optional custom prompt to override default + * @param includeTitleIfAvailable Include document titles in responses when available + * @returns + */ + @feature("retrieval.rag") async rag(options: { query: string; ragGenerationConfig?: GenerationConfig | Record; @@ -66,6 +100,7 @@ export class RetrievalClient { } } + @feature("retrieval.streamRag") private async streamRag( ragData: Record, ): Promise> { @@ -82,8 +117,55 @@ export class RetrievalClient { ); } + /** + * Engage with an intelligent RAG-powered conversational agent for complex + * information retrieval and analysis. + * + * This advanced endpoint combines retrieval-augmented generation (RAG) + * with a conversational AI agent to provide detailed, context-aware + * responses based on your document collection. + * + * The agent can: + * - Maintain conversation context across multiple interactions + * - Dynamically search and retrieve relevant information from both + * vector and knowledge graph sources + * - Break down complex queries into sub-questions for comprehensive + * answers + * - Cite sources and provide evidence-based responses + * - Handle follow-up questions and clarifications + * - Navigate complex topics with multi-step reasoning + * + * Key Features: + * - Hybrid search combining vector and knowledge graph approaches + * - Contextual conversation management with conversation_id tracking + * - Customizable generation parameters for response style and length + * - Source document citation with optional title inclusion + * - Streaming support for real-time responses + * - Branch management for exploring different conversation paths + * + * Common Use Cases: + * - Research assistance and literature review + * - Document analysis and summarization + * - Technical support and troubleshooting + * - Educational Q&A and tutoring + * - Knowledge base exploration + * + * The agent uses both vector search and knowledge graph capabilities to + * find and synthesize information, providing detailed, factual responses + * with proper attribution to source documents. + * @param message Current message to process + * @param ragGenerationConfig Configuration for RAG generation + * @param vectorSearchSettings Settings for vector-based search + * @param kgSearchSettings Settings for knowledge graph search + * @param taskPromptOverride Optional custom prompt to override default + * @param includeTitleIfAvailable Include document titles in responses when available + * @param conversationId ID of the conversation + * @param branchId ID of the conversation branch + * @returns + */ + @feature("retrieval.agent") async agent(options: { - messages: Message[]; + message: Message; ragGenerationConfig?: GenerationConfig | Record; vectorSearchSettings?: ChunkSearchSettings | Record; kgSearchSettings?: KGSearchSettings | Record; @@ -93,7 +175,7 @@ export class RetrievalClient { branchId?: string; }): Promise> { const data: Record = { - messages: options.messages, + message: options.message, ...(options.vectorSearchSettings && { vectorSearchSettings: options.vectorSearchSettings, }), @@ -126,6 +208,7 @@ export class RetrievalClient { } } + @feature("retrieval.streamAgent") private async streamAgent( agentData: Record, ): Promise> { @@ -142,6 +225,20 @@ export class RetrievalClient { ); } + /** + * Generate completions for a list of messages. + * + * This endpoint uses the language model to generate completions for + * the provided messages. The generation process can be customized using + * the generation_config parameter. + * + * The messages list should contain alternating user and assistant + * messages, with an optional system message at the start. Each message + * should have a 'role' and 'content'. + * @param messages List of messages to generate completion for + * @returns + */ + @feature("retrieval.completion") async completion(options: { messages: Message[]; generationConfig?: GenerationConfig | Record; @@ -162,6 +259,7 @@ export class RetrievalClient { } } + @feature("retrieval.streamCompletion") private async streamCompletion( ragData: Record, ): Promise> { diff --git a/js/sdk/src/v3/clients/system.ts b/js/sdk/src/v3/clients/system.ts index 6d32ec828..7fdf30773 100644 --- a/js/sdk/src/v3/clients/system.ts +++ b/js/sdk/src/v3/clients/system.ts @@ -1,3 +1,4 @@ +import { feature } from "../../feature"; import { r2rClient } from "../../r2rClient"; import { WrappedGenericMessageResponse, @@ -12,6 +13,7 @@ export class SystemClient { /** * Check the health of the R2R server. */ + @feature("system.health") async health(): Promise { return await this.client.makeRequest("GET", "health"); } @@ -21,6 +23,7 @@ export class SystemClient { * @param options * @returns */ + @feature("system.logs") async logs(options: { runTypeFilter?: string; offset?: number; @@ -42,14 +45,17 @@ export class SystemClient { * Get the configuration settings for the R2R server. * @returns */ + @feature("system.settings") async settings(): Promise { return await this.client.makeRequest("GET", "system/settings"); } /** - * Get statistics about the server, including the start time, uptime, CPU usage, and memory usage. + * Get statistics about the server, including the start time, uptime, + * CPU usage, and memory usage. * @returns */ + @feature("system.status") async status(): Promise { return await this.client.makeRequest("GET", "system/status"); } diff --git a/js/sdk/src/v3/clients/users.ts b/js/sdk/src/v3/clients/users.ts index b36e86aac..eadd37a7a 100644 --- a/js/sdk/src/v3/clients/users.ts +++ b/js/sdk/src/v3/clients/users.ts @@ -1,3 +1,4 @@ +import { feature } from "../../feature"; import { r2rClient } from "../../r2rClient"; import { WrappedBooleanResponse, @@ -17,6 +18,7 @@ export class UsersClient { * @param password User's password * @returns */ + @feature("users.register") async register(options: { email: string; password: string; @@ -33,6 +35,7 @@ export class UsersClient { * @param password User's password * @returns */ + @feature("users.delete") async delete(options: { id: string; password: string; @@ -49,6 +52,7 @@ export class UsersClient { * @param email User's email address * @param verificationCode Verification code sent to the user's email */ + @feature("users.verifyEmail") async verifyEmail(options: { email: string; verificationCode: string; @@ -64,6 +68,7 @@ export class UsersClient { * @param password User's password * @returns */ + @feature("users.login") async login(options: { email: string; password: string }): Promise { const response = await this.client.makeRequest("POST", "users/login", { data: { @@ -90,7 +95,7 @@ export class UsersClient { * @param accessToken Existing access token * @returns */ - // FIXME: What is going on here... + @feature("users.loginWithToken") async loginWithToken(options: { accessToken: string }): Promise { this.client.setTokens(options.accessToken, null); @@ -115,6 +120,7 @@ export class UsersClient { * Log out the current user. * @returns */ + @feature("users.logout") async logout(): Promise { const response = await this.client.makeRequest("POST", "users/logout"); this.client.setTokens(null, null); @@ -125,6 +131,7 @@ export class UsersClient { * Refresh the access token using the refresh token. * @returns */ + @feature("users.refreshAccessToken") async refreshAccessToken(): Promise { const refreshToken = this.client.getRefreshToken(); if (!refreshToken) { @@ -160,6 +167,7 @@ export class UsersClient { * @param new_password User's new password * @returns */ + @feature("users.changePassword") async changePassword(options: { current_password: string; new_password: string; @@ -174,6 +182,7 @@ export class UsersClient { * @param email User's email address * @returns */ + @feature("users.requestPasswordReset") async requestPasswordReset(options: { email: string; }): Promise { @@ -182,6 +191,13 @@ export class UsersClient { }); } + /** + * Reset a user's password using a reset token. + * @param reset_token Reset token sent to the user's email + * @param new_password New password for the user + * @returns + */ + @feature("users.resetPassword") async resetPassword(options: { reset_token: string; new_password: string; @@ -200,6 +216,7 @@ export class UsersClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("users.list") async list(options?: { email?: string; is_active?: boolean; @@ -232,6 +249,7 @@ export class UsersClient { * @param id User ID to retrieve * @returns */ + @feature("users.retrieve") async retrieve(options: { id: string }): Promise { return this.client.makeRequest("GET", `users/${options.id}`); } @@ -240,6 +258,7 @@ export class UsersClient { * Get detailed information about the currently authenticated user. * @returns */ + @feature("users.me") async me(): Promise { return this.client.makeRequest("GET", `users/me`); } @@ -254,6 +273,7 @@ export class UsersClient { * @param profilePicture Optional new profile picture for the user * @returns */ + @feature("users.update") async update(options: { id: string; email?: string; @@ -284,6 +304,7 @@ export class UsersClient { * @param limit Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. * @returns */ + @feature("users.listCollections") async listCollections(options: { id: string; offset?: number; @@ -305,6 +326,7 @@ export class UsersClient { * @param collectionId Collection ID to add the user to * @returns */ + @feature("users.addToCollection") async addToCollection(options: { id: string; collectionId: string; @@ -321,6 +343,7 @@ export class UsersClient { * @param collectionId Collection ID to remove the user from * @returns */ + @feature("users.removeFromCollection") async removeFromCollection(options: { id: string; collectionId: string; diff --git a/py/compose.yaml b/py/compose.yaml index 2621c06d7..7024c83ba 100644 --- a/py/compose.yaml +++ b/py/compose.yaml @@ -36,7 +36,7 @@ services: -c max_connections=${R2R_POSTGRES_MAX_CONNECTIONS:-1024} r2r: - image: r2r/test + image: ${R2R_IMAGE:-ragtoriches/prod:latest} build: context: . args: diff --git a/py/core/base/abstractions/__init__.py b/py/core/base/abstractions/__init__.py index 8b4b1a5e5..3cc40c678 100644 --- a/py/core/base/abstractions/__init__.py +++ b/py/core/base/abstractions/__init__.py @@ -65,7 +65,6 @@ KGEntityResult, KGGlobalResult, KGRelationshipResult, - KGSearchMethod, KGSearchResultType, SearchSettings, ) @@ -131,7 +130,6 @@ # Search abstractions "AggregateSearchResult", "GraphSearchResult", - "KGSearchMethod", "KGSearchResultType", "KGEntityResult", "KGRelationshipResult", diff --git a/py/core/base/api/models/__init__.py b/py/core/base/api/models/__init__.py index 3049ac68e..fd691bba9 100644 --- a/py/core/base/api/models/__init__.py +++ b/py/core/base/api/models/__init__.py @@ -22,9 +22,7 @@ Community, Entity, GraphResponse, - KGCreationResponse, KGEnrichmentResponse, - KGEntityDeduplicationResponse, KGTunePromptResponse, Relationship, WrappedCommunitiesResponse, @@ -33,9 +31,7 @@ WrappedEntityResponse, WrappedGraphResponse, WrappedGraphsResponse, - WrappedKGCreationResponse, WrappedKGEnrichmentResponse, - WrappedKGEntityDeduplicationResponse, WrappedKGTunePromptResponse, WrappedRelationshipResponse, WrappedRelationshipsResponse, @@ -104,20 +100,16 @@ "Entity", "Relationship", "Community", - "KGCreationResponse", "KGEnrichmentResponse", "KGTunePromptResponse", - "KGEntityDeduplicationResponse", "WrappedEntityResponse", "WrappedEntitiesResponse", "WrappedRelationshipResponse", "WrappedRelationshipsResponse", "WrappedCommunityResponse", "WrappedCommunitiesResponse", - "WrappedKGCreationResponse", "WrappedKGEnrichmentResponse", "WrappedKGTunePromptResponse", - "WrappedKGEntityDeduplicationResponse", # TODO: Need to review anything above this "GraphResponse", "WrappedGraphResponse", diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index d11d496fd..262c978ad 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -604,7 +604,7 @@ async def list_chunks( class EntityHandler(Handler): @abstractmethod - async def create(self, *args: Any, **kwargs: Any) -> None: + async def create(self, *args: Any, **kwargs: Any) -> Entity: """Create entities in storage.""" pass @@ -614,7 +614,7 @@ async def get(self, *args: Any, **kwargs: Any) -> list[Entity]: pass @abstractmethod - async def update(self, *args: Any, **kwargs: Any) -> None: + async def update(self, *args: Any, **kwargs: Any) -> Entity: """Update entities in storage.""" pass @@ -626,7 +626,7 @@ async def delete(self, *args: Any, **kwargs: Any) -> None: class RelationshipHandler(Handler): @abstractmethod - async def create(self, *args: Any, **kwargs: Any) -> None: + async def create(self, *args: Any, **kwargs: Any) -> Relationship: """Add relationships to storage.""" pass @@ -636,7 +636,7 @@ async def get(self, *args: Any, **kwargs: Any) -> list[Relationship]: pass @abstractmethod - async def update(self, *args: Any, **kwargs: Any) -> None: + async def update(self, *args: Any, **kwargs: Any) -> Relationship: """Update relationships in storage.""" pass @@ -648,7 +648,7 @@ async def delete(self, *args: Any, **kwargs: Any) -> None: class CommunityHandler(Handler): @abstractmethod - async def create(self, *args: Any, **kwargs: Any) -> None: + async def create(self, *args: Any, **kwargs: Any) -> Community: """Create communities in storage.""" pass @@ -658,7 +658,7 @@ async def get(self, *args: Any, **kwargs: Any) -> list[Community]: pass @abstractmethod - async def update(self, *args: Any, **kwargs: Any) -> None: + async def update(self, *args: Any, **kwargs: Any) -> Community: """Update communities in storage.""" pass diff --git a/py/core/main/api/v2/kg_router.py b/py/core/main/api/v2/kg_router.py index e2a1be46a..d4a344eeb 100644 --- a/py/core/main/api/v2/kg_router.py +++ b/py/core/main/api/v2/kg_router.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Optional, Union +from typing import Optional from uuid import UUID import yaml @@ -11,9 +11,6 @@ from core.base.api.models import ( WrappedCommunitiesResponse, WrappedEntitiesResponse, - WrappedKGCreationResponse, - WrappedKGEnrichmentResponse, - WrappedKGEntityDeduplicationResponse, WrappedKGTunePromptResponse, WrappedRelationshipsResponse, ) @@ -38,7 +35,7 @@ def __init__( self, service: KgService, orchestration_provider: Optional[ - Union[HatchetOrchestrationProvider, SimpleOrchestrationProvider] + HatchetOrchestrationProvider | SimpleOrchestrationProvider ] = None, run_type: RunType = RunType.KG, ): @@ -68,7 +65,7 @@ def _register_workflows(self): ) else: workflow_messages["extract-triples"] = ( - "Document entities and relationships extracted successfully. To generate GraphRAG communities, run cluster on the collection this document belongs to." + "Document entities and relationships extracted successfully. To generate GraphRAG communities, POST to `/graphs//communities/build` with a collection this document belongs to." ) workflow_messages["build-communities"] = ( "Graph communities created successfully. You can view the communities at http://localhost:7272/v2/communities" @@ -103,7 +100,7 @@ async def create_graph( ), run_with_orchestration: Optional[bool] = Body(True), auth_user=Depends(self.service.providers.auth.auth_wrapper), - ): # -> WrappedKGCreationResponse: # type: ignore + ): """ Creating a graph on your documents. This endpoint takes input a list of document ids and KGCreationSettings. If document IDs are not provided, the graph will be created on all documents in the system. @@ -188,7 +185,7 @@ async def enrich_graph( ), run_with_orchestration: Optional[bool] = Body(True), auth_user=Depends(self.service.providers.auth.auth_wrapper), - ): # -> WrappedKGEnrichmentResponse: + ): """ This endpoint enriches the graph with additional information. It creates communities of nodes based on their similarity and adds embeddings to the graph. @@ -399,7 +396,7 @@ async def deduplicate_entities( None, description="Settings for the deduplication process." ), auth_user=Depends(self.service.providers.auth.auth_wrapper), - ) -> WrappedKGEntityDeduplicationResponse: + ): """ Deduplicate entities in the knowledge graph. """ diff --git a/py/core/main/api/v3/collections_router.py b/py/core/main/api/v3/collections_router.py index 9312d4552..b5ceacc16 100644 --- a/py/core/main/api/v3/collections_router.py +++ b/py/core/main/api/v3/collections_router.py @@ -8,13 +8,11 @@ from core.base import KGCreationSettings, KGRunType, R2RException, RunType from core.base.api.models import ( GenericBooleanResponse, - GenericMessageResponse, WrappedBooleanResponse, WrappedCollectionResponse, WrappedCollectionsResponse, WrappedDocumentsResponse, WrappedGenericMessageResponse, - WrappedKGCreationResponse, WrappedUsersResponse, ) from core.providers import ( @@ -1066,7 +1064,7 @@ async def extract( description="Whether to run the entities and relationships extraction process with orchestration.", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedKGCreationResponse: # type: ignore + ): """ Extracts entities and relationships from a document. The entities and relationships extraction process involves: diff --git a/py/core/main/api/v3/documents_router.py b/py/core/main/api/v3/documents_router.py index f3b329979..d42f7fa7a 100644 --- a/py/core/main/api/v3/documents_router.py +++ b/py/core/main/api/v3/documents_router.py @@ -12,23 +12,18 @@ from pydantic import Json from core.base import R2RException, RunType, generate_document_id -from core.base.abstractions import ( - Entity, - GraphBuildSettings, - KGCreationSettings, - KGRunType, - Relationship, -) +from core.base.abstractions import KGCreationSettings, KGRunType from core.base.api.models import ( GenericBooleanResponse, - PaginatedResultsWrapper, WrappedBooleanResponse, WrappedChunksResponse, WrappedCollectionsResponse, WrappedDocumentResponse, WrappedDocumentsResponse, + WrappedEntitiesResponse, + WrappedGenericMessageResponse, WrappedIngestionResponse, - WrappedKGCreationResponse, + WrappedRelationshipsResponse, ) from core.providers import ( HatchetOrchestrationProvider, @@ -763,16 +758,16 @@ async def list_chunks( ..., description="The ID of the document to retrieve chunks for.", ), - offset: Optional[int] = Query( + offset: int = Query( 0, ge=0, - description="The offset of the first chunk to retrieve.", + description="Specifies the number of objects to skip. Defaults to 0.", ), - limit: Optional[int] = Query( + limit: int = Query( 100, - ge=0, - le=20_000, - description="The maximum number of chunks to retrieve, up to 20,000.", + ge=1, + le=1000, + description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", ), include_vectors: Optional[bool] = Query( False, @@ -1265,7 +1260,7 @@ async def extract( description="Whether to run the entities and relationships extraction process with orchestration.", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedKGCreationResponse: # type: ignore + ) -> WrappedGenericMessageResponse: """ Extracts entities and relationships from a document. The entities and relationships extraction process involves: @@ -1292,7 +1287,6 @@ async def extract( settings_dict=settings, # type: ignore ) - # If the run type is estimate, return an estimate of the creation cost if run_type is KGRunType.ESTIMATE: return { # type: ignore "message": "Estimate retrieved successfully", @@ -1305,30 +1299,27 @@ async def extract( kg_creation_settings=server_kg_creation_settings, ), } - else: - # Otherwise, create the graph - if run_with_orchestration: - workflow_input = { - "document_id": str(id), - "kg_creation_settings": server_kg_creation_settings.model_dump_json(), - "user": auth_user.json(), - } - return await self.orchestration_provider.run_workflow( # type: ignore - "extract-triples", {"request": workflow_input}, {} - ) - else: - from core.main.orchestration import simple_kg_factory + if run_with_orchestration: + workflow_input = { + "document_id": str(id), + "kg_creation_settings": server_kg_creation_settings.model_dump_json(), + "user": auth_user.json(), + } - logger.info( - "Running extract-triples without orchestration." - ) - simple_kg = simple_kg_factory(self.services["kg"]) - await simple_kg["extract-triples"](workflow_input) # type: ignore - return { # type: ignore - "message": "Graph created successfully.", - "task_id": None, - } + return await self.orchestration_provider.run_workflow( + "extract-triples", {"request": workflow_input}, {} + ) + else: + from core.main.orchestration import simple_kg_factory + + logger.info("Running extract-triples without orchestration.") + simple_kg = simple_kg_factory(self.services["kg"]) + await simple_kg["extract-triples"](workflow_input) + return { # type: ignore + "message": "Graph created successfully.", + "task_id": None, + } @self.router.get( "/documents/{id}/entities", @@ -1359,23 +1350,23 @@ async def get_entities( ..., description="The ID of the document to retrieve entities from.", ), - offset: Optional[int] = Query( + offset: int = Query( 0, ge=0, - description="The offset of the first entity to retrieve.", + description="Specifies the number of objects to skip. Defaults to 0.", ), - limit: Optional[int] = Query( + limit: int = Query( 100, - ge=0, - le=20_000, - description="The maximum number of entities to retrieve, up to 20,000.", + ge=1, + le=1000, + description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", ), include_embeddings: Optional[bool] = Query( False, description="Whether to include vector embeddings in the response.", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> PaginatedResultsWrapper[list[Entity]]: + ) -> WrappedEntitiesResponse: """ Retrieves the entities that were extracted from a document. These represent important semantic elements like people, places, organizations, concepts, etc. @@ -1404,7 +1395,7 @@ async def get_entities( raise R2RException("Document not found.", 404) # Get all entities for this document from the document_entity table - entities, total_count = ( + entities, count = ( await self.providers.database.graph_handler.entities.get( parent_id=id, store_type="document", @@ -1414,7 +1405,7 @@ async def get_entities( ) ) - return entities, {"total_entries": total_count} + return entities, {"total_entries": count} # type: ignore @self.router.get( "/documents/{id}/relationships", @@ -1484,16 +1475,16 @@ async def list_relationships( ..., description="The ID of the document to retrieve relationships for.", ), - offset: Optional[int] = Query( + offset: int = Query( 0, ge=0, - description="The offset of the first relationship to retrieve.", + description="Specifies the number of objects to skip. Defaults to 0.", ), - limit: Optional[int] = Query( + limit: int = Query( 100, - ge=0, - le=20_000, - description="The maximum number of relationships to retrieve, up to 20,000.", + ge=1, + le=1000, + description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", ), entity_names: Optional[list[str]] = Query( None, @@ -1504,7 +1495,7 @@ async def list_relationships( description="Filter relationships by specific relationship types.", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> PaginatedResultsWrapper[list[Relationship]]: + ) -> WrappedRelationshipsResponse: """ Retrieves the relationships between entities that were extracted from a document. These represent connections and interactions between entities found in the text. @@ -1533,7 +1524,7 @@ async def list_relationships( raise R2RException("Document not found.", 404) # Get relationships for this document - relationships, total_count = ( + relationships, count = ( await self.providers.database.graph_handler.relationships.get( parent_id=id, store_type="document", @@ -1544,7 +1535,7 @@ async def list_relationships( ) ) - return relationships, {"total_entries": total_count} + return relationships, {"total_entries": count} # type: ignore @staticmethod async def _process_file(file): diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index eb3e4981d..0f383cb9d 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -3,32 +3,20 @@ from typing import Optional from uuid import UUID -from fastapi import Body, Depends, Path, Query, Request +from fastapi import Body, Depends, Path, Query from core.base import R2RException, RunType -from core.base.abstractions import ( - DataLevel, - Entity, - GraphBuildSettings, - KGCreationSettings, - KGRunType, - Relationship, -) +from core.base.abstractions import KGRunType from core.base.api.models import ( GenericBooleanResponse, - GenericMessageResponse, - PaginatedResultsWrapper, WrappedBooleanResponse, WrappedCommunitiesResponse, WrappedCommunityResponse, WrappedEntitiesResponse, WrappedEntityResponse, - WrappedGenericMessageResponse, WrappedGraphResponse, WrappedGraphsResponse, - WrappedKGCreationResponse, WrappedKGEnrichmentResponse, - WrappedKGEntityDeduplicationResponse, WrappedKGTunePromptResponse, WrappedRelationshipResponse, WrappedRelationshipsResponse, @@ -46,18 +34,6 @@ logger = logging.getLogger() -from enum import Enum - - -class GraphObjectType(str, Enum): - ENTITIES = "entities" - RELATIONSHIPS = "relationships" - COLLECTIONS = "collections" - DOCUMENTS = "documents" - - def __str__(self): - return self.value - class GraphRouter(BaseRouterV3): def __init__( @@ -71,15 +47,6 @@ def __init__( ): super().__init__(providers, services, orchestration_provider, run_type) - def _get_path_level(self, request: Request) -> DataLevel: - path = request.url.path - if "/chunks/" in path: - return DataLevel.CHUNK - elif "/documents/" in path: - return DataLevel.DOCUMENT - else: - return DataLevel.GRAPH - async def _deduplicate_entities( self, collection_id: UUID, @@ -87,7 +54,7 @@ async def _deduplicate_entities( run_type: Optional[KGRunType] = KGRunType.ESTIMATE, run_with_orchestration: bool = True, auth_user=None, - ) -> WrappedKGEntityDeduplicationResponse: + ): """Deduplicates entities in the knowledge graph using LLM-based analysis. The deduplication process: @@ -161,110 +128,6 @@ async def _get_collection_id( return collection_id def _setup_routes(self): - # @self.router.post( - # "/graphs", - # summary="Create a new graph", - # openapi_extra={ - # "x-codeSamples": [ - # { # TODO: Verify - # "lang": "Python", - # "source": textwrap.dedent( - # """ - # from r2r import R2RClient - - # client = R2RClient("http://localhost:7272") - # # when using auth, do client.login(...) - - # result = client.graphs.create( - # graph={ - # "name": "New Graph", - # "description": "New Description" - # } - # ) - # """ - # ), - # }, - # { - # "lang": "JavaScript", - # "source": textwrap.dedent( - # """ - # const { r2rClient } = require("r2r-js"); - - # const client = new r2rClient("http://localhost:7272"); - - # function main() { - # const response = await client.documents.create({ - # name: "New Graph", - # description: "New Description", - # }); - # } - - # main(); - # """ - # ), - # }, - # ] - # }, - # ) - # @self.base_endpoint - # async def create_graph( - # collection_id: Optional[UUID] = Body( - # None, - # description="Collection ID to associate with the graph. If not provided, uses user's default collection.", - # ), - # name: Optional[str] = Body( - # None, description="The name of the graph" - # ), - # description: Optional[str] = Body( - # None, description="An optional description of the graph" - # ), - # auth_user=Depends(self.providers.auth.auth_wrapper), - # ) -> WrappedGraphResponse: - # """ - # Creates a new empty graph. - - # This is the first step in building a knowledge graph. After creating the graph, you can: - - # 1. Add data to the graph: - # - Manually add entities and relationships via the /entities and /relationships endpoints - # - Automatically extract entities and relationships from documents via the /graphs/{id}/documents endpoint - - # 2. Build communities: - # - Build communities of related entities via the /graphs/{collection_id}/communities/build endpoint - - # 3. Update graph metadata: - # - Modify the graph name, description and settings via the /graphs/{collection_id} endpoint - - # The graph ID returned by this endpoint is required for all subsequent operations on the graph. - - # Raises: - # R2RException: If a graph already exists for the given collection. - # """ - - # collection_id = await self._get_collection_id( - # collection_id, auth_user - # ) - - # # Check if a graph already exists for this collection - # existing_graphs = await self.services["kg"].list_graphs( - # collection_id=collection_id, - # offset=0, - # limit=1, - # ) - - # if existing_graphs["total_entries"] > 0: - # raise R2RException( - # f"A graph already exists for collection {collection_id}. Only one graph per collection is allowed.", - # 409, # HTTP 409 Conflict status code - # ) - - # return await self.services["kg"].create_new_graph( - # user_id=auth_user.id, - # collection_id=collection_id, - # name=name, - # description=description, - # ) - @self.router.get( "/graphs", summary="List graphs", @@ -279,12 +142,7 @@ def _setup_routes(self): client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.graphs.create( - graph={ - "name": "New Graph", - "description": "New Description" - } - ) + result = client.graphs.list() """ ), }, @@ -297,7 +155,7 @@ def _setup_routes(self): const client = new r2rClient("http://localhost:7272"); function main() { - const response = await client.graphs.list(); + const response = await client.graphs.list({}); } main(); @@ -381,7 +239,7 @@ async def list_graphs( function main() { const response = await client.graphs.retrieve({ - collection_id: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" }); } @@ -444,7 +302,8 @@ async def build_communities( run_with_orchestration: Optional[bool] = Body(True), auth_user=Depends(self.providers.auth.auth_wrapper), ): # -> WrappedKGEnrichmentResponse: - """Creates communities in the graph by analyzing entity relationships and similarities. + """ + Creates communities in the graph by analyzing entity relationships and similarities. Communities are created through the following process: 1. Analyzes entity relationships and metadata to build a similarity graph @@ -463,8 +322,14 @@ async def build_communities( - Summary generation prompt """ print("collection_id = ", collection_id) - if not auth_user.is_superuser: - logger.warning("Implement permission checks here.") + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the specified graph.", + 403, + ) # If no collection ID is provided, use the default user collection # id = generate_default_user_collection_id(auth_user.id) @@ -541,7 +406,7 @@ async def build_communities( function main() { const response = await client.graphs.reset({ - collection_id: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" }); } @@ -669,7 +534,44 @@ async def update_graph( description=description, ) - @self.router.get("/graphs/{collection_id}/entities") + @self.router.get( + "/graphs/{collection_id}/entities", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.graphs.get_entities(collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7") + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + const response = await client.graphs.get_entities({ + collection_id: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + }); + } + + main(); + """ + ), + }, + ], + }, + ) @self.base_endpoint async def get_entities( collection_id: UUID = Path( @@ -688,18 +590,24 @@ async def get_entities( description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> PaginatedResultsWrapper[list[Entity]]: + ) -> WrappedEntitiesResponse: """Lists all entities in the graph with pagination support.""" - # return await self.services["kg"].get_entities( - # id, offset, limit, auth_user - # ) - entities, count = ( - await self.providers.database.graph_handler.get_entities( - collection_id, offset, limit + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the specified graph.", + 403, ) + + entities, count = await self.services["kg"].get_entities( + parent_id=collection_id, + offset=offset, + limit=limit, ) - print("entities = ", entities) - return entities, { + + return entities, { # type: ignore "total_entries": count, } @@ -710,39 +618,37 @@ async def create_entity( ..., description="The collection ID corresponding to the graph to add the entity to.", ), - entity: Entity = Body(..., description="The entity to create"), + name: str = Body( + ..., description="The name of the entity to create." + ), + description: Optional[str] = Body( + None, description="The description of the entity to create." + ), + category: Optional[str] = Body( + None, description="The category of the entity to create." + ), + metadata: Optional[dict] = Body( + None, description="The metadata of the entity to create." + ), auth_user=Depends(self.providers.auth.auth_wrapper), - ): # -> WrappedEntityResponse: + ) -> WrappedEntityResponse: """Creates a new entity in the graph.""" if ( not auth_user.is_superuser and collection_id not in auth_user.graph_ids ): raise R2RException( - "The currently authenticated user does not have access to this graph.", + "The currently authenticated user does not have access to the specified graph.", 403, ) - # Set parent ID to graph ID - entity.parent_id = collection_id - - # Create entity - created_ids = ( - await self.providers.database.graph_handler.entities.create( - entities=[entity], store_type="graph" - ) - ) - if not created_ids: - raise R2RException("Failed to create entity", 500) - - result = await self.providers.database.graph_handler.entities.get( + return await self.services["kg"].create_entity( + name=name, + description=description, parent_id=collection_id, - store_type="graph", - entity_ids=[created_ids[0]], + category=category, + metadata=metadata, ) - if len(result) == 0: - raise R2RException("Failed to create entity", 500) - return result[0] @self.router.post("/graphs/{collection_id}/relationships") @self.base_endpoint @@ -751,8 +657,32 @@ async def create_relationship( ..., description="The collection ID corresponding to the graph to add the relationship to.", ), - relationship: Relationship = Body( - ..., description="The relationship to create" + subject: str = Body( + ..., description="The subject of the relationship to create." + ), + subject_id: UUID = Body( + ..., + description="The ID of the subject of the relationship to create.", + ), + predicate: str = Body( + ..., description="The predicate of the relationship to create." + ), + object: str = Body( + ..., description="The object of the relationship to create." + ), + object_id: UUID = Body( + ..., + description="The ID of the object of the relationship to create.", + ), + description: Optional[str] = Body( + None, + description="The description of the relationship to create.", + ), + weight: Optional[float] = Body( + None, description="The weight of the relationship to create." + ), + metadata: Optional[dict] = Body( + None, description="The metadata of the relationship to create." ), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedRelationshipResponse: @@ -762,21 +692,64 @@ async def create_relationship( and collection_id not in auth_user.graph_ids ): raise R2RException( - "The currently authenticated user does not have access to this graph.", + "The currently authenticated user does not have access to the specified graph.", 403, ) - # Set parent ID to graph ID - relationship.parent_id = collection_id - - # Create relationship - await self.providers.database.graph_handler.relationships.create( - relationships=[relationship], store_type="graph" + return await self.services["kg"].create_relationship( + subject=subject, + subject_id=subject_id, + predicate=predicate, + object=object, + object_id=object_id, + description=description, + weight=weight, + metadata=metadata, + parent_id=collection_id, ) - return relationship + @self.router.get( + "/graphs/{collection_id}/entities/{entity_id}", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.graphs.get_entity( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + entity_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + ) + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + const response = await client.graphs.get_entity({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + entityId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + }); + } - @self.router.get("/graphs/{collection_id}/entities/{entity_id}") + main(); + """ + ), + }, + ] + }, + ) @self.base_endpoint async def get_entity( collection_id: UUID = Path( @@ -789,9 +762,21 @@ async def get_entity( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedEntityResponse: """Retrieves a specific entity by its ID.""" - # Note: The original was missing implementation, so assuming similar pattern to relationships + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the specified graph.", + 403, + ) + result = await self.providers.database.graph_handler.entities.get( - collection_id, "graph", entity_ids=[entity_id] + parent_id=collection_id, + store_type="graph", + offset=0, + limit=1, + entity_ids=[entity_id], ) if len(result) == 0 or len(result[0]) == 0: raise R2RException("Entity not found", 404) @@ -807,25 +792,81 @@ async def update_entity( entity_id: UUID = Path( ..., description="The ID of the entity to update." ), - entity: Entity = Body( - ..., description="The updated entity object." + name: Optional[str] = Body( + ..., description="The updated name of the entity." + ), + description: Optional[str] = Body( + None, description="The updated description of the entity." + ), + category: Optional[str] = Body( + None, description="The updated category of the entity." + ), + metadata: Optional[dict] = Body( + None, description="The updated metadata of the entity." ), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedEntityResponse: """Updates an existing entity in the graph.""" - entity.id = entity_id - entity.parent_id = ( - entity.parent_id or collection_id - ) # Set parent ID to graph ID - results = await self.providers.database.graph_handler.entities.update( - [entity], - store_type="graph", - # id, entity_id, entity, auth_user + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the specified graph.", + 403, + ) + + return await self.services["kg"].update_entity( + entity_id=entity_id, + name=name, + category=category, + description=description, + metadata=metadata, ) - print("results = ", results) - return entity - @self.router.delete("/graphs/{collection_id}/entities/{entity_id}") + @self.router.delete( + "/graphs/{collection_id}/entities/{entity_id}", + summary="Remove an entity", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.graphs.remove_entity( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + entity_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + ) + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + const response = await client.graphs.removeEntity({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + entityId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + }); + } + + main(); + """ + ), + }, + ] + }, + ) @self.base_endpoint async def delete_entity( collection_id: UUID = Path( @@ -839,12 +880,61 @@ async def delete_entity( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedBooleanResponse: """Removes an entity from the graph.""" - await self.providers.database.graph_handler.entities.delete( - collection_id, [entity_id], "graph" + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the specified graph.", + 403, + ) + + await self.services["kg"].delete_entity( + parent_id=collection_id, + entity_id=entity_id, ) - return {"success": True} - @self.router.get("/graphs/{collection_id}/relationships") + return GenericBooleanResponse(success=True) # type: ignore + + @self.router.get( + "/graphs/{collection_id}/relationships", + description="Lists all relationships in the graph with pagination support.", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.graphs.list_relationships(collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7") + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + const response = await client.graphs.listRelationships({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + }); + } + + main(); + """ + ), + }, + ], + }, + ) @self.base_endpoint async def get_relationships( collection_id: UUID = Path( @@ -863,53 +953,71 @@ async def get_relationships( description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> PaginatedResultsWrapper[list[Relationship]]: + ) -> WrappedRelationshipsResponse: """ Lists all relationships in the graph with pagination support. """ - # Permission check if ( not auth_user.is_superuser and collection_id not in auth_user.graph_ids ): raise R2RException( - "The currently authenticated user does not have access to this graph.", + "The currently authenticated user does not have access to the specified graph.", 403, ) - relationships, count = ( - await self.providers.database.graph_handler.relationships.get( - parent_id=collection_id, - store_type="graph", - offset=offset, - limit=limit, - ) + relationships, count = await self.services["kg"].get_relationships( + parent_id=collection_id, + offset=offset, + limit=limit, ) - return relationships, { + return relationships, { # type: ignore "total_entries": count, } - @self.router.post("/graphs/{collection_id}/relationships") - @self.base_endpoint - async def create_relationship( - collection_id: UUID = Path( - ..., - description="The collection ID corresponding to the graph to add the relationship to.", - ), - relationship_ids: list[UUID] = Body( - ..., - description="The IDs of the relationships to add to the graph.", - ), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedRelationshipResponse: - """Creates a new relationship in the graph.""" - return await self.providers.database.graph_handler.relationships.add_to_graph( - collection_id, relationship_ids, "graph" - ) - @self.router.get( - "/graphs/{collection_id}/relationships/{relationship_id}" + "/graphs/{collection_id}/relationships/{relationship_id}", + description="Retrieves a specific relationship by its ID.", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.graphs.get_relationship( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + relationship_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + ) + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + const response = await client.graphs.getRelationship({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + relationshipId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + }); + } + + main(); + """ + ), + }, + ], + }, ) @self.base_endpoint async def get_relationship( @@ -923,9 +1031,22 @@ async def get_relationship( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedRelationshipResponse: """Retrieves a specific relationship by its ID.""" + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the specified graph.", + 403, + ) + results = ( await self.providers.database.graph_handler.relationships.get( - collection_id, "graph", relationship_ids=[relationship_id] + parent_id=collection_id, + store_type="graph", + offset=0, + limit=1, + relationship_ids=[relationship_id], ) ) if len(results) == 0 or len(results[0]) == 0: @@ -944,20 +1065,97 @@ async def update_relationship( relationship_id: UUID = Path( ..., description="The ID of the relationship to update." ), - relationship: Relationship = Body( - ..., description="The updated relationship object." + subject: Optional[str] = Body( + ..., description="The updated subject of the relationship." + ), + subject_id: Optional[UUID] = Body( + ..., description="The updated subject ID of the relationship." + ), + predicate: Optional[str] = Body( + ..., description="The updated predicate of the relationship." + ), + object: Optional[str] = Body( + ..., description="The updated object of the relationship." + ), + object_id: Optional[UUID] = Body( + ..., description="The updated object ID of the relationship." + ), + description: Optional[str] = Body( + None, + description="The updated description of the relationship.", + ), + weight: Optional[float] = Body( + None, description="The updated weight of the relationship." + ), + metadata: Optional[dict] = Body( + None, description="The updated metadata of the relationship." ), auth_user=Depends(self.providers.auth.auth_wrapper), - ): # -> WrappedRelationshipResponse: + ) -> WrappedRelationshipResponse: """Updates an existing relationship in the graph.""" - relationship.id = relationship_id - relationship.parent_id = relationship.parent_id or collection_id - return await self.providers.database.graph_handler.relationships.update( - [relationship], "graph" + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the specified graph.", + 403, + ) + + return await self.services["kg"].update_relationship( + relationship_id=relationship_id, + subject=subject, + subject_id=subject_id, + predicate=predicate, + object=object, + object_id=object_id, + description=description, + weight=weight, + metadata=metadata, ) @self.router.delete( - "/graphs/{collection_id}/relationships/{relationship_id}" + "/graphs/{collection_id}/relationships/{relationship_id}", + description="Removes a relationship from the graph.", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.graphs.delete_relationship( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + relationship_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + ) + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + const response = await client.graphs.deleteRelationship({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + relationshipId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + }); + } + + main(); + """ + ), + }, + ], + }, ) @self.base_endpoint async def delete_relationship( @@ -972,15 +1170,21 @@ async def delete_relationship( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedBooleanResponse: """Removes a relationship from the graph.""" - # return await self.services[ - # "kg" - # ].documents.graph_handler.relationships.remove_from_graph( - # id, relationship_id, auth_user - # ) - await self.providers.database.graph_handler.relationships.delete( - collection_id, [relationship_id], "graph" + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the specified graph.", + 403, + ) + + await self.services["kg"].delete_relationship( + parent_id=collection_id, + relationship_id=relationship_id, ) - return {"success": True} + + return GenericBooleanResponse(success=True) # type: ignore @self.router.post( "/graphs/{collection_id}/communities", @@ -996,7 +1200,37 @@ async def delete_relationship( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.graphs.communities.create(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", communities=[community1, community2]) + result = client.graphs.create_community( + collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", + name="My Community", + summary="A summary of the community", + findings=["Finding 1", "Finding 2"], + rating=5, + rating_explanation="This is a rating explanation", + ) + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + const response = await client.graphs.createCommunity({ + collectionId: "9fbe403b-c11c-5aae-8ade-ef22980c3ad1", + name: "My Community", + summary: "A summary of the community", + findings: ["Finding 1", "Finding 2"], + rating: 5, + ratingExplanation: "This is a rating explanation", + }); + } + + main(); """ ), }, @@ -1004,7 +1238,7 @@ async def delete_relationship( }, ) @self.base_endpoint - async def create_communities( + async def create_community( collection_id: UUID = Path( ..., description="The collection ID corresponding to the graph to create the community in.", @@ -1014,29 +1248,21 @@ async def create_communities( findings: Optional[list[str]] = Body( default=[], description="Findings about the community" ), - level: Optional[int] = Body( - default=0, - ge=0, - le=100, - description="The level of the community", - ), rating: Optional[float] = Body( default=5, ge=1, le=10, description="Rating between 1 and 10" ), rating_explanation: Optional[str] = Body( default="", description="Explanation for the rating" ), - attributes: Optional[dict] = Body( - default=None, description="Attributes for the community" - ), auth_user=Depends(self.providers.auth.auth_wrapper), - ): + ) -> WrappedCommunityResponse: """ Creates a new community in the graph. While communities are typically built automatically via the /graphs/{id}/communities/build endpoint, - this endpoint allows you to manually create your own communities. This can be useful when you want to: + this endpoint allows you to manually create your own communities. + This can be useful when you want to: - Define custom groupings of entities based on domain knowledge - Add communities that weren't detected by the automatic process - Create hierarchical organization structures @@ -1045,16 +1271,22 @@ async def create_communities( The created communities will be integrated with any existing automatically detected communities in the graph's community structure. """ - return await self.services["kg"].create_community_v3( - graph_id=collection_id, + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): + raise R2RException( + "The currently authenticated user does not have access to the specified graph.", + 403, + ) + + return await self.services["kg"].create_community( + parent_id=collection_id, name=name, summary=summary, findings=findings, rating=rating, rating_explanation=rating_explanation, - level=level, - attributes=attributes, - auth_user=auth_user, ) @self.router.get( @@ -1071,7 +1303,25 @@ async def create_communities( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.graphs.communities.get(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") + result = client.graphs.list_communities(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + const response = await client.graphs.listCommunities({ + collectionId: "9fbe403b-c11c-5aae-8ade-ef22980c3ad1", + }); + } + + main(); """ ), }, @@ -1080,7 +1330,6 @@ async def create_communities( ) @self.base_endpoint async def get_communities( - request: Request, collection_id: UUID = Path( ..., description="The collection ID corresponding to the graph to get communities for.", @@ -1097,27 +1346,26 @@ async def get_communities( description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ): + ) -> WrappedCommunitiesResponse: """ Lists all communities in the graph with pagination support. - - By default, all attributes are returned, but this can be limited using the `attributes` parameter. """ - if not auth_user.is_superuser: + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): raise R2RException( - "Only superusers can access this endpoint.", 403 + "The currently authenticated user does not have access to the specified graph.", + 403, ) - communities, count = await self.services[ - "kg" - ].providers.database.graph_handler.communities.get( - graph_id=collection_id, + communities, count = await self.services["kg"].get_communities( + parent_id=collection_id, offset=offset, limit=limit, - auth_user=auth_user, ) - return communities, { + return communities, { # type: ignore "total_entries": count, } @@ -1135,7 +1383,25 @@ async def get_communities( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.graphs.communities.get(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") + result = client.graphs.get_community(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + const response = await client.graphs.getCommunity({ + collectionId: "9fbe403b-c11c-5aae-8ade-ef22980c3ad1", + }); + } + + main(); """ ), }, @@ -1144,7 +1410,6 @@ async def get_communities( ) @self.base_endpoint async def get_community( - request: Request, collection_id: UUID = Path( ..., description="The ID of the collection to get communities for.", @@ -1154,34 +1419,78 @@ async def get_community( description="The ID of the community to get.", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ): + ) -> WrappedCommunityResponse: """ Retrieves a specific community by its ID. - - By default, all attributes are returned, but this can be limited using the `attributes` parameter. """ - if not auth_user.is_superuser: + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): raise R2RException( - "Only superusers can access this endpoint.", 403 + "The currently authenticated user does not have access to the specified graph.", + 403, ) - return await self.services[ + results = await self.services[ "kg" ].providers.database.graph_handler.communities.get( - graph_id=collection_id, - community_id=community_id, - auth_user=auth_user, + parent_id=collection_id, + community_ids=[community_id], + store_type="graph", offset=0, limit=1, ) + print(f"results: {results}") + if len(results) == 0 or len(results[0]) == 0: + raise R2RException("Community not found", 404) + return results[0][0] @self.router.delete( "/graphs/{collection_id}/communities/{community_id}", summary="Delete a community", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.graphs.delete_community( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + community_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + ) + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + const response = await client.graphs.deleteCommunity({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + communityId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + }); + } + + main(); + """ + ), + }, + ] + }, ) @self.base_endpoint async def delete_community( - request: Request, collection_id: UUID = Path( ..., description="The collection ID corresponding to the graph to delete the community from.", @@ -1192,14 +1501,18 @@ async def delete_community( ), auth_user=Depends(self.providers.auth.auth_wrapper), ): - if not auth_user.is_superuser: + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): raise R2RException( - "Only superusers can access this endpoint.", 403 + "The currently authenticated user does not have access to the specified graph.", + 403, ) - await self.services["kg"].delete_community_v3( - graph_id=collection_id, + + await self.services["kg"].delete_community( + parent_id=collection_id, community_id=community_id, - auth_user=auth_user, ) return GenericBooleanResponse(success=True) # type: ignore @@ -1228,6 +1541,31 @@ async def delete_community( )""" ), }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + async function main() { + const response = await client.graphs.updateCommunity({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + communityId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + communityUpdate: { + metadata: { + topic: "Technology", + description: "Tech companies and products" + } + } + }); + } + + main(); + """ + ), + }, ] }, ) @@ -1238,34 +1576,102 @@ async def update_community( name: Optional[str] = Body(None), summary: Optional[str] = Body(None), findings: Optional[list[str]] = Body(None), - rating: Optional[float] = Body(None), + rating: Optional[float] = Body(default=None, ge=1, le=10), rating_explanation: Optional[str] = Body(None), - level: Optional[int] = Body(None), - attributes: Optional[dict] = Body(None), auth_user=Depends(self.providers.auth.auth_wrapper), - ): + ) -> WrappedCommunityResponse: """ - Updates an existing community's metadata and properties. + Updates an existing community in the graph. """ - if not auth_user.is_superuser: + if ( + not auth_user.is_superuser + and collection_id not in auth_user.graph_ids + ): raise R2RException( - "Only superusers can update communities", 403 + "The currently authenticated user does not have access to the specified graph.", + 403, ) - return await self.services["kg"].update_community_v3( - id=collection_id, + return await self.services["kg"].update_community( community_id=community_id, name=name, summary=summary, findings=findings, rating=rating, rating_explanation=rating_explanation, - level=level, - attributes=attributes, - auth_user=auth_user, ) - async def _pull(collection_id: UUID, auth_user): + @self.router.post( + "/graphs/{collection_id}/pull", + summary="Pull latest entities to the graph", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.graphs.pull( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + )""" + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + async function main() { + const response = await client.graphs.pull({ + collection_id: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + }); + } + + main(); + """ + ), + }, + ] + }, + ) + @self.base_endpoint + async def pull( + collection_id: UUID = Path( + ..., description="The ID of the graph to initialize." + ), + # document_ids: list[UUID] = Body( + # ..., description="List of document IDs to add to the graph." + # ), + auth_user=Depends(self.providers.auth.auth_wrapper), + ) -> WrappedBooleanResponse: + """ + Adds documents to a graph by copying their entities and relationships. + + This endpoint: + 1. Copies document entities to the graph_entity table + 2. Copies document relationships to the graph_relationship table + 3. Associates the documents with the graph + + When a document is added: + - Its entities and relationships are copied to graph-specific tables + - Existing entities/relationships are updated by merging their properties + - The document ID is recorded in the graph's document_ids array + + Documents added to a graph will contribute their knowledge to: + - Graph analysis and querying + - Community detection + - Knowledge graph enrichment + + The user must have access to both the graph and the documents being added. + """ + # Check user permissions for graph if ( not auth_user.is_superuser and collection_id not in auth_user.graph_ids @@ -1313,7 +1719,10 @@ async def _pull(collection_id: UUID, auth_user): ) entities = ( await self.providers.database.graph_handler.entities.get( - document.id, store_type="document" + parent_id=document.id, + store_type="document", + offset=0, + limit=100, ) ) has_document = ( @@ -1344,11 +1753,12 @@ async def _pull(collection_id: UUID, auth_user): logger.warning( f"No documents were added to graph {collection_id}, marking as failed." ) - return success - @self.router.post( - "/graphs/{collection_id}/pull", - summary="Pull latest entities to the graph", + return GenericBooleanResponse(success=success) # type: ignore + + @self.router.delete( + "/graphs/{collection_id}/documents/{document_id}", + summary="Remove document from graph", openapi_extra={ "x-codeSamples": [ { @@ -1360,8 +1770,9 @@ async def _pull(collection_id: UUID, auth_user): client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.graphs.initialize( - collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + result = client.graphs.remove_document( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + document_id="f98db41a-5555-4444-3333-222222222222" )""" ), }, @@ -1374,8 +1785,9 @@ async def _pull(collection_id: UUID, auth_user): const client = new r2rClient("http://localhost:7272"); async function main() { - const response = await client.graphs.addDocuments({ - collection_id: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + const response = await client.graphs.removeDocument({ + collectionId: "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + documentId: "f98db41a-5555-4444-3333-222222222222" }); } @@ -1387,63 +1799,6 @@ async def _pull(collection_id: UUID, auth_user): }, ) @self.base_endpoint - async def pull( - collection_id: UUID = Path( - ..., description="The ID of the graph to initialize." - ), - # document_ids: list[UUID] = Body( - # ..., description="List of document IDs to add to the graph." - # ), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedBooleanResponse: - """ - Adds documents to a graph by copying their entities and relationships. - - This endpoint: - 1. Copies document entities to the graph_entity table - 2. Copies document relationships to the graph_relationship table - 3. Associates the documents with the graph - - When a document is added: - - Its entities and relationships are copied to graph-specific tables - - Existing entities/relationships are updated by merging their properties - - The document ID is recorded in the graph's document_ids array - - Documents added to a graph will contribute their knowledge to: - - Graph analysis and querying - - Community detection - - Knowledge graph enrichment - - The user must have access to both the graph and the documents being added. - """ - # Check user permissions for graph - success = await _pull(collection_id, auth_user) - return GenericBooleanResponse(success=success) - - @self.router.delete( - "/graphs/{collection_id}/documents/{document_id}", - summary="Remove document from graph", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": textwrap.dedent( - """ - from r2r import R2RClient - - client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) - - result = client.graphs.remove_document( - collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", - document_id="f98db41a-5555-4444-3333-222222222222" - )""" - ), - }, - ] - }, - ) - @self.base_endpoint async def remove_document( collection_id: UUID = Path( ..., @@ -1463,7 +1818,6 @@ async def remove_document( The user must have access to both the graph and the document being removed. """ - # Check user permissions for graph if ( not auth_user.is_superuser and collection_id not in auth_user.graph_ids @@ -1473,7 +1827,6 @@ async def remove_document( 403, ) - # Check user permissions for document if ( not auth_user.is_superuser and document_id not in auth_user.document_ids @@ -1490,4 +1843,4 @@ async def remove_document( ) ) - return GenericBooleanResponse(success=success) + return GenericBooleanResponse(success=success) # type: ignore diff --git a/py/core/main/api/v3/prompts_router.py b/py/core/main/api/v3/prompts_router.py index b84efb3cf..5de440f7a 100644 --- a/py/core/main/api/v3/prompts_router.py +++ b/py/core/main/api/v3/prompts_router.py @@ -195,6 +195,95 @@ async def get_prompts( }, ) + @self.router.post( + "/prompts/{name}", + summary="Get a specific prompt", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.prompts.get( + "greeting_prompt", + inputs={"name": "John"}, + prompt_override="Hi, {name}!" + ) + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient("http://localhost:7272"); + + function main() { + const response = await client.prompts.retrieve({ + name: "greeting_prompt", + inputs: { name: "John" }, + promptOverride: "Hi, {name}!", + }); + } + + main(); + """ + ), + }, + { + "lang": "CLI", + "source": textwrap.dedent( + """ + r2r prompts retrieve greeting_prompt --inputs '{"name": "John"}' --prompt-override "Hi, {name}!" + """ + ), + }, + { + "lang": "cURL", + "source": textwrap.dedent( + """ + curl -X POST "https://api.example.com/v3/prompts/greeting_prompt?inputs=%7B%22name%22%3A%22John%22%7D&prompt_override=Hi%2C%20%7Bname%7D!" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """ + ), + }, + ] + }, + ) + @self.base_endpoint + async def get_prompt( + name: str = Path(..., description="Prompt name"), + inputs: Optional[dict[str, str]] = Body( + None, description="Prompt inputs" + ), + prompt_override: Optional[str] = Query( + None, description="Prompt override" + ), + auth_user=Depends(self.providers.auth.auth_wrapper), + ) -> WrappedPromptResponse: + """ + Get a specific prompt by name, optionally with inputs and override. + + This endpoint retrieves a specific prompt and allows for optional inputs and template override. + Only superusers can access this endpoint. + """ + if not auth_user.is_superuser: + raise R2RException( + "Only a superuser can retrieve prompts.", + 403, + ) + result = await self.services["management"].get_prompt( + name, inputs, prompt_override + ) + return result # type: ignore + @self.router.put( "/prompts/{name}", summary="Update an existing prompt", diff --git a/py/core/main/api/v3/retrieval_router.py b/py/core/main/api/v3/retrieval_router.py index 54ceb7302..4c32bfc54 100644 --- a/py/core/main/api/v3/retrieval_router.py +++ b/py/core/main/api/v3/retrieval_router.py @@ -220,7 +220,6 @@ async def search_app( description="Search query to find relevant documents", ), search_settings: SearchSettings = Body( - # alias="searchSettings", default_factory=SearchSettings, description="Settings for vector-based search", ), @@ -381,22 +380,18 @@ async def search_app( async def rag_app( query: str = Body(...), search_settings: SearchSettings = Body( - alias="searchSettings", default_factory=SearchSettings, description="Settings for vector-based search", ), rag_generation_config: GenerationConfig = Body( - alias="ragGenerationConfig", default_factory=GenerationConfig, description="Configuration for RAG generation", ), task_prompt_override: Optional[str] = Body( - alias="taskPromptOverride", default=None, description="Optional custom prompt to override default", ), include_title_if_available: bool = Body( - alias="includeTitleIfAvailable", default=False, description="Include document titles in responses when available", ), @@ -566,32 +561,26 @@ async def agent_app( description="List of messages (deprecated, use message instead)", ), search_settings: SearchSettings = Body( - alias="searchSettings", default_factory=SearchSettings, description="Settings for vector-based search", ), rag_generation_config: GenerationConfig = Body( - alias="ragGenerationConfig", default_factory=GenerationConfig, description="Configuration for RAG generation", ), task_prompt_override: Optional[str] = Body( - alias="taskPromptOverride", default=None, description="Optional custom prompt to override default", ), include_title_if_available: bool = Body( - alias="includeTitleIfAvailable", default=True, description="Include document titles in responses when available", ), conversation_id: Optional[UUID] = Body( - alias="conversationId", default=None, description="ID of the conversation", ), branch_id: Optional[UUID] = Body( - alias="branchId", default=None, description="ID of the conversation branch", ), @@ -772,7 +761,6 @@ async def completion( ], ), generation_config: GenerationConfig = Body( - alias="generationConfig", default_factory=GenerationConfig, description="Configuration for text generation", example={ diff --git a/py/core/main/app.py b/py/core/main/app.py index 5fe984d45..de18c0972 100644 --- a/py/core/main/app.py +++ b/py/core/main/app.py @@ -85,11 +85,11 @@ async def r2r_exception_handler(request: Request, exc: R2RException): def _setup_routes(self): # Include routers in the app - # self.app.include_router(self.ingestion_router, prefix="/v2") - # self.app.include_router(self.management_router, prefix="/v2") - # self.app.include_router(self.retrieval_router, prefix="/v2") - # self.app.include_router(self.auth_router, prefix="/v2") - # self.app.include_router(self.kg_router, prefix="/v2") + self.app.include_router(self.ingestion_router, prefix="/v2") + self.app.include_router(self.management_router, prefix="/v2") + self.app.include_router(self.retrieval_router, prefix="/v2") + self.app.include_router(self.auth_router, prefix="/v2") + self.app.include_router(self.kg_router, prefix="/v2") self.app.include_router(self.documents_router, prefix="/v3") self.app.include_router(self.chunks_router, prefix="/v3") diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index 3fdb4ef6a..74f296298 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -134,15 +134,15 @@ async def kg_relationships_extraction( return await _collect_results(result_gen) - @telemetry_event("create_entities") - async def create_entities( + @telemetry_event("create_entity") + async def create_entity( self, name: str, description: str, - metadata: Optional[dict] = None, + parent_id: UUID, category: Optional[str] = None, - auth_user: Optional[Any] = None, - ): + metadata: Optional[dict] = None, + ) -> Entity: description_embedding = str( await self.providers.embedding.async_get_embedding(description) @@ -150,317 +150,266 @@ async def create_entities( return await self.providers.database.graph_handler.entities.create( name=name, + parent_id=parent_id, + store_type="graph", # type: ignore category=category, description=description, description_embedding=description_embedding, metadata=metadata, - auth_user=auth_user, - ) - - @telemetry_event("list_entities") - async def list_entities( - self, - offset: int, - limit: int, - id: Optional[UUID] = None, - graph_id: Optional[UUID] = None, - document_id: Optional[UUID] = None, - entity_names: Optional[list[str]] = None, - include_embeddings: Optional[bool] = False, - user_id: Optional[UUID] = None, - ): - return await self.providers.database.graph_handler.entities.get( - id=id, - graph_id=graph_id, - document_id=document_id, - entity_names=entity_names, - include_embeddings=include_embeddings, - offset=offset, - limit=limit, - user_id=user_id, ) @telemetry_event("update_entity") - async def update_entity_v3( + async def update_entity( self, - id: UUID, - name: Optional[str], - category: Optional[str], - description: Optional[str], - attributes: Optional[dict], - auth_user: Optional[Any] = None, - ): + entity_id: UUID, + name: Optional[str] = None, + description: Optional[str] = None, + category: Optional[str] = None, + metadata: Optional[dict] = None, + ) -> Entity: + description_embedding = None if description is not None: description_embedding = str( await self.providers.embedding.async_get_embedding(description) ) - else: - description_embedding = None return await self.providers.database.graph_handler.entities.update( - id=id, + entity_id=entity_id, + store_type="graph", # type: ignore name=name, - category=category, description=description, description_embedding=description_embedding, - attributes=attributes, - auth_user=auth_user, + category=category, + metadata=metadata, ) @telemetry_event("delete_entity") - async def delete_entity_v3( + async def delete_entity( self, - id: UUID, + parent_id: UUID, entity_id: UUID, - level: DataLevel, - **kwargs, ): return await self.providers.database.graph_handler.entities.delete( - id=id, - entity_id=entity_id, - level=level, - ) - - @telemetry_event("add_entity_to_graph") - async def add_entity_to_graph( - self, - graph_id: UUID, - entity_id: UUID, - auth_user: Optional[Any] = None, - ): - return ( - await self.providers.database.graph_handler.entities.add_to_graph( - graph_id, entity_id, auth_user - ) + parent_id=parent_id, + entity_ids=[entity_id], + store_type="graph", # type: ignore ) - # TODO: deprecate this @telemetry_event("get_entities") async def get_entities( self, - collection_id: Optional[UUID] = None, - entity_ids: Optional[list[str]] = None, - entity_table_name: str = "entity", - offset: Optional[int] = None, - limit: Optional[int] = None, - **kwargs, - ): - return await self.providers.database.graph_handler.get_entities( - collection_id=collection_id, - entity_ids=entity_ids, - entity_table_name=entity_table_name, - offset=offset or 0, - limit=limit or -1, - ) - - ################### RELATIONSHIPS ################### - - @telemetry_event("list_relationships_v3") - async def list_relationships_v3( - self, - id: UUID, - level: DataLevel, + parent_id: UUID, offset: int, limit: int, + entity_ids: Optional[list[UUID]] = None, entity_names: Optional[list[str]] = None, - relationship_types: Optional[list[str]] = None, - attributes: Optional[list[str]] = None, - relationship_id: Optional[UUID] = None, + include_embeddings: bool = False, ): - return await self.providers.database.graph_handler.relationships.get( - id=id, - level=level, - entity_names=entity_names, - relationship_types=relationship_types, - attributes=attributes, + return await self.providers.database.graph_handler.get_entities( + parent_id=parent_id, offset=offset, limit=limit, - relationship_id=relationship_id, + entity_ids=entity_ids, + entity_names=entity_names, + include_embeddings=include_embeddings, ) - @telemetry_event("create_relationships_v3") - async def create_relationships_v3( + @telemetry_event("create_relationship") + async def create_relationship( self, - relationships: list[Relationship], - **kwargs, - ): - for relationship in relationships: - if relationship.description: - relationships.description_embedding = str( - await self.providers.embedding.async_get_embedding( - relationship.description - ) - ) - - print("relationships = ", relationships) + subject: str, + subject_id: UUID, + predicate: str, + object: str, + object_id: UUID, + parent_id: UUID, + description: str | None = None, + weight: float | None = 1.0, + metadata: Optional[dict[str, Any] | str] = None, + ) -> Relationship: + description_embedding = None + if description: + description_embedding = str( + await self.providers.embedding.async_get_embedding(description) + ) return ( await self.providers.database.graph_handler.relationships.create( - relationships + subject=subject, + subject_id=subject_id, + predicate=predicate, + object=object, + object_id=object_id, + parent_id=parent_id, + description=description, + description_embedding=description_embedding, + weight=weight, + metadata=metadata, + store_type="graph", # type: ignore ) ) - @telemetry_event("delete_relationship_v3") - async def delete_relationship_v3( + @telemetry_event("delete_relationship") + async def delete_relationship( self, - level: DataLevel, - id: UUID, + parent_id: UUID, relationship_id: UUID, - **kwargs, ): return ( await self.providers.database.graph_handler.relationships.delete( - level=level, - id=id, - relationship_id=relationship_id, + parent_id=parent_id, + relationship_ids=[relationship_id], + store_type="graph", # type: ignore ) ) - @telemetry_event("update_relationship_v3") - async def update_relationship_v3( + @telemetry_event("update_relationship") + async def update_relationship( self, - relationship: Relationship, - **kwargs, - ): + relationship_id: UUID, + subject: Optional[str] = None, + subject_id: Optional[UUID] = None, + predicate: Optional[str] = None, + object: Optional[str] = None, + object_id: Optional[UUID] = None, + description: Optional[str] = None, + weight: Optional[float] = None, + metadata: Optional[dict[str, Any] | str] = None, + ) -> Relationship: + + description_embedding = None + if description is not None: + description_embedding = str( + await self.providers.embedding.async_get_embedding(description) + ) + return ( await self.providers.database.graph_handler.relationships.update( - relationship + relationship_id=relationship_id, + subject=subject, + subject_id=subject_id, + predicate=predicate, + object=object, + object_id=object_id, + description=description, + description_embedding=description_embedding, + weight=weight, + metadata=metadata, + store_type="graph", # type: ignore ) ) - # TODO: deprecate this - @telemetry_event("get_triples") + @telemetry_event("get_relationships") async def get_relationships( self, - collection_id: Optional[UUID] = None, + parent_id: UUID, + offset: int, + limit: int, + relationship_ids: Optional[list[UUID]] = None, entity_names: Optional[list[str]] = None, - relationship_ids: Optional[list[str]] = None, - offset: Optional[int] = None, - limit: Optional[int] = None, - **kwargs, ): - return await self.providers.database.graph_handler.get_relationships( - collection_id=collection_id, - entity_names=entity_names, + return await self.providers.database.graph_handler.relationships.get( + parent_id=parent_id, + store_type="graph", # type: ignore + offset=offset, + limit=limit, relationship_ids=relationship_ids, - offset=offset or 0, - limit=limit or -1, + entity_names=entity_names, ) - ################### COMMUNITIES ################### - - @telemetry_event("create_community_v3") - async def create_community_v3( + @telemetry_event("create_community") + async def create_community( self, - graph_id: UUID, + parent_id: UUID, name: str, summary: str, - findings: list[str], + findings: Optional[list[str]], rating: Optional[float], rating_explanation: Optional[str], - level: Optional[int], - attributes: Optional[dict], - auth_user: Any, - **kwargs, - ): - embedding = str( + ) -> Community: + description_embedding = str( await self.providers.embedding.async_get_embedding(summary) ) return await self.providers.database.graph_handler.communities.create( - graph_id=graph_id, + parent_id=parent_id, + store_type="graph", # type: ignore name=name, summary=summary, - embedding=embedding, + description_embedding=description_embedding, findings=findings, rating=rating, rating_explanation=rating_explanation, - level=level, - attributes=attributes, - auth_user=auth_user, ) - @telemetry_event("update_community_v3") - async def update_community_v3( + @telemetry_event("update_community") + async def update_community( self, - id: UUID, community_id: UUID, name: Optional[str], summary: Optional[str], findings: Optional[list[str]], rating: Optional[float], rating_explanation: Optional[str], - level: Optional[int], - attributes: Optional[dict], - auth_user: Any, - **kwargs, - ): + ) -> Community: + summary_embedding = None if summary is not None: - embedding = str( + summary_embedding = str( await self.providers.embedding.async_get_embedding(summary) ) - else: - embedding = None return await self.providers.database.graph_handler.communities.update( - id=id, community_id=community_id, + store_type="graph", # type: ignore name=name, summary=summary, - embedding=embedding, + summary_embedding=summary_embedding, findings=findings, rating=rating, rating_explanation=rating_explanation, - level=level, - attributes=attributes, - auth_user=auth_user, ) - @telemetry_event("delete_community_v3") - async def delete_community_v3( + @telemetry_event("delete_community") + async def delete_community( self, - graph_id: UUID, + parent_id: UUID, community_id: UUID, - auth_user: Any, - **kwargs, - ): - return await self.providers.database.graph_handler.communities.delete( - graph_id=graph_id, + ) -> None: + await self.providers.database.graph_handler.communities.delete( + parent_id=parent_id, community_id=community_id, - auth_user=auth_user, ) - @telemetry_event("list_communities_v3") - async def list_communities_v3( + @telemetry_event("list_communities") + async def list_communities( self, - id: UUID, + collection_id: UUID, offset: int, limit: int, - **kwargs, ): return await self.providers.database.graph_handler.communities.get( - id=id, + parent_id=collection_id, + store_type="graph", # type: ignore offset=offset, limit=limit, ) - # TODO: deprecate this @telemetry_event("get_communities") async def get_communities( self, - collection_id: Optional[UUID] = None, - levels: Optional[list[int]] = None, - community_ids: Optional[list[int]] = None, - offset: Optional[int] = None, - limit: Optional[int] = None, - **kwargs, + parent_id: UUID, + offset: int, + limit: int, + community_ids: Optional[list[UUID]] = None, + community_names: Optional[list[str]] = None, + include_embeddings: bool = False, ): return await self.providers.database.graph_handler.get_communities( - collection_id=collection_id, - levels=levels, + parent_id=parent_id, + offset=offset, + limit=limit, community_ids=community_ids, - offset=offset or 0, - limit=limit or -1, + include_embeddings=include_embeddings, ) # @telemetry_event("create_new_graph") @@ -635,14 +584,6 @@ async def kg_entity_description( return all_results - @telemetry_event("get_graph_status") - async def get_graph_status( - self, - collection_id: UUID, - **kwargs, - ): - raise NotImplementedError("Not implemented") - @telemetry_event("kg_clustering") async def kg_clustering( self, @@ -1241,12 +1182,30 @@ async def store_kg_extractions( "storing len(extraction.entities) = ", len(extraction.entities) ) - if extraction.entities: + for entity in extraction.entities: await self.providers.database.graph_handler.entities.create( - extraction.entities, store_type="document" + name=entity.name, + parent_id=entity.parent_id, + store_type="document", # type: ignore + category=entity.category, + description=entity.description, + description_embedding=entity.description_embedding, + chunk_ids=entity.chunk_ids, + metadata=entity.metadata, ) if extraction.relationships: - await self.providers.database.graph_handler.relationships.create( - extraction.relationships, store_type="document" - ) + for relationship in extraction.relationships: + await self.providers.database.graph_handler.relationships.create( + subject=relationship.subject, + subject_id=relationship.subject_id, + predicate=relationship.predicate, + object=relationship.object, + object_id=relationship.object_id, + parent_id=relationship.parent_id, + description=relationship.description, + description_embedding=relationship.description_embedding, + weight=relationship.weight, + metadata=relationship.metadata, + store_type="document", # type: ignore + ) diff --git a/py/core/main/services/retrieval_service.py b/py/core/main/services/retrieval_service.py index 706afe23c..b599bfac2 100644 --- a/py/core/main/services/retrieval_service.py +++ b/py/core/main/services/retrieval_service.py @@ -315,6 +315,9 @@ async def agent( ) messages = messages or [] + if message and not messages: + messages = [message] + current_message = messages[-1] # type: ignore # Save the new message to the conversation diff --git a/py/core/pipes/kg/clustering.py b/py/core/pipes/kg/clustering.py index 1103c09b5..c9275c240 100644 --- a/py/core/pipes/kg/clustering.py +++ b/py/core/pipes/kg/clustering.py @@ -54,9 +54,8 @@ async def cluster_kg( num_communities = await self.database_provider.graph_handler.perform_graph_clustering( collection_id=collection_id, - # graph_id=graph_id, leiden_params=leiden_params, - ) # type: ignore + ) logger.info( f"Clustering completed. Generated {num_communities} communities." diff --git a/py/core/pipes/kg/deduplication.py b/py/core/pipes/kg/deduplication.py index 5838d1b7c..9ce3f62ba 100644 --- a/py/core/pipes/kg/deduplication.py +++ b/py/core/pipes/kg/deduplication.py @@ -1,10 +1,8 @@ import json import logging -from typing import Any, Union +from typing import Any from uuid import UUID -from fastapi import HTTPException - from core.base import AsyncState from core.base.abstractions import DataLevel, Entity, KGEntityDeduplicationType from core.base.pipes import AsyncPipe @@ -26,14 +24,12 @@ def __init__( self, config: AsyncPipe.PipeConfig, database_provider: PostgresDBProvider, - llm_provider: Union[ - OpenAICompletionProvider, LiteLLMCompletionProvider - ], - embedding_provider: Union[ - LiteLLMEmbeddingProvider, - OpenAIEmbeddingProvider, - OllamaEmbeddingProvider, - ], + llm_provider: OpenAICompletionProvider | LiteLLMCompletionProvider, + embedding_provider: ( + LiteLLMEmbeddingProvider + | OpenAIEmbeddingProvider + | OllamaEmbeddingProvider + ), logging_provider: SqlitePersistentLoggingProvider, **kwargs, ): diff --git a/py/core/pipes/kg/extraction.py b/py/core/pipes/kg/extraction.py index f5ea9700e..6a98d2a85 100644 --- a/py/core/pipes/kg/extraction.py +++ b/py/core/pipes/kg/extraction.py @@ -143,7 +143,7 @@ def parse_fn(response_str: str) -> Any: category=entity_category, description=entity_description, name=entity_value, - document_id=extractions[0].document_id, + parent_id=extractions[0].document_id, chunk_ids=[ extraction.id for extraction in extractions ], @@ -167,7 +167,7 @@ def parse_fn(response_str: str) -> Any: object=object, description=description, weight=weight, - document_id=extractions[0].document_id, + parent_id=extractions[0].document_id, chunk_ids=[ extraction.id for extraction in extractions ], diff --git a/py/core/pipes/kg/storage.py b/py/core/pipes/kg/storage.py index 37eeef3f7..510cd5ca6 100644 --- a/py/core/pipes/kg/storage.py +++ b/py/core/pipes/kg/storage.py @@ -5,7 +5,6 @@ from core.base import AsyncState, KGExtraction, R2RDocumentProcessingError from core.base.pipes.base_pipe import AsyncPipe -from core.providers.database.graph import DataLevel from core.providers.database.postgres import PostgresDBProvider from core.providers.logger.r2r_logger import SqlitePersistentLoggingProvider @@ -64,7 +63,7 @@ async def store( if not extraction.entities[0].chunk_ids: for i in range(len(extraction.entities)): extraction.entities[i].chunk_ids = extraction.chunk_ids - extraction.entities[i].document_id = ( + extraction.entities[i].parent_id = ( extraction.document_id ) diff --git a/py/core/pipes/retrieval/kg_search_pipe.py b/py/core/pipes/retrieval/kg_search_pipe.py index a8ba17177..1e42169cd 100644 --- a/py/core/pipes/retrieval/kg_search_pipe.py +++ b/py/core/pipes/retrieval/kg_search_pipe.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, AsyncGenerator, Optional +from typing import Any, AsyncGenerator from uuid import UUID from core.base import ( @@ -15,7 +15,6 @@ KGCommunityResult, KGEntityResult, KGRelationshipResult, - KGSearchMethod, KGSearchResultType, SearchSettings, ) @@ -180,6 +179,14 @@ async def search( # "document_ids", ], ): + try: + # TODO - remove this nasty hack + search_result["metadata"] = json.loads( + search_result["metadata"] + ) + except: + pass + yield GraphSearchResult( content=KGRelationshipResult( # name=search_result["name"], @@ -234,7 +241,6 @@ async def search( rating_explanation=search_result["rating_explanation"], findings=search_result["findings"], ), - # method=KGSearchMethod.LOCAL, result_type=KGSearchResultType.COMMUNITY, metadata=( { diff --git a/py/core/providers/database/document.py b/py/core/providers/database/document.py index 227fd036c..ba6aba655 100644 --- a/py/core/providers/database/document.py +++ b/py/core/providers/database/document.py @@ -138,7 +138,7 @@ async def upsert_documents_overview( summary = $12, summary_embedding = $13 WHERE document_id = $14 """ - print("db_entry = ", db_entry) + await conn.execute( update_query, db_entry["collection_ids"], @@ -465,7 +465,7 @@ async def get_documents_overview( logger.warning( f"Failed to parse embedding for document {row['document_id']}: {e}" ) - print("row = ", row) + documents.append( DocumentResponse( id=row["document_id"], diff --git a/py/core/providers/database/graph.py b/py/core/providers/database/graph.py index b6e42e7bb..261959bdc 100644 --- a/py/core/providers/database/graph.py +++ b/py/core/providers/database/graph.py @@ -4,7 +4,7 @@ import logging import time from enum import Enum -from typing import Any, AsyncGenerator, List, Optional, Set, Tuple, Union +from typing import Any, AsyncGenerator, Optional, Set, Tuple, Union from uuid import UUID, uuid4 import asyncpg @@ -14,7 +14,6 @@ from core.base.abstractions import ( Community, CommunityInfo, - DataLevel, Entity, Graph, KGCreationSettings, @@ -120,56 +119,66 @@ async def create_tables(self) -> None: await self.connection_manager.execute_query(QUERY) async def create( - self, entities: list[Entity], store_type: StoreType - ) -> list[UUID]: - """Create multiple entities in the specified store.""" + self, + parent_id: UUID, + store_type: StoreType, + name: str, + category: Optional[str] = None, + description: Optional[str] = None, + description_embedding: Optional[list[float] | str] = None, + chunk_ids: Optional[list[UUID]] = None, + metadata: Optional[dict[str, Any] | str] = None, + ) -> Entity: + """Create a new entity in the specified store.""" table_name = self._get_entity_table_for_store(store_type) - values = [] - results = [] - - for entity in entities: - metadata = entity.metadata - if isinstance(metadata, str): - try: - metadata = json.loads(metadata) - except json.JSONDecodeError: - pass - description_embedding = entity.description_embedding - if isinstance(description_embedding, list): - description_embedding = str(description_embedding) + if isinstance(metadata, str): + try: + metadata = json.loads(metadata) + except json.JSONDecodeError: + pass - value = ( - entity.name, - entity.category, - entity.description, - entity.parent_id, - description_embedding, - entity.chunk_ids, - json.dumps(metadata) if metadata else None, - ) - values.append(value) + if isinstance(description_embedding, list): + description_embedding = str(description_embedding) - QUERY = f""" + query = f""" INSERT INTO {self._get_table_name(table_name)} (name, category, description, parent_id, description_embedding, chunk_ids, metadata) VALUES ($1, $2, $3, $4, $5, $6, $7) - RETURNING id + RETURNING id, name, category, description, parent_id, chunk_ids, metadata """ - for value in values: - print("inserting len(values) into graph = ", len(values)) - result = await self.connection_manager.fetchrow_query(QUERY, value) - results.append(result["id"]) + params = [ + name, + category, + description, + parent_id, + description_embedding, + chunk_ids, + json.dumps(metadata) if metadata else None, + ] - return results + result = await self.connection_manager.fetchrow_query( + query=query, + params=params, + ) + + return Entity( + id=result["id"], + name=result["name"], + category=result["category"], + description=result["description"], + parent_id=result["parent_id"], + chunk_ids=result["chunk_ids"], + metadata=result["metadata"], + ) async def get( self, parent_id: UUID, store_type: StoreType, - offset: int = 0, - limit: int = 100, + offset: int, + limit: int, entity_ids: Optional[list[UUID]] = None, entity_names: Optional[list[str]] = None, include_embeddings: bool = False, @@ -178,7 +187,7 @@ async def get( table_name = self._get_entity_table_for_store(store_type) conditions = ["parent_id = $1"] - params = [parent_id] + params: list[Any] = [parent_id] param_index = 2 if entity_ids: @@ -246,70 +255,91 @@ async def get( return entities, count async def update( - self, entities: list[Entity], store_type: StoreType - ) -> list[UUID]: - """Update multiple entities in the specified store.""" + self, + entity_id: UUID, + store_type: StoreType, + name: Optional[str] = None, + description: Optional[str] = None, + description_embedding: Optional[list[float] | str] = None, + category: Optional[str] = None, + metadata: Optional[dict] = None, + ) -> Entity: + """Update an entity in the specified store.""" table_name = self._get_entity_table_for_store(store_type) - results = [] + update_fields = [] + params: list[Any] = [] + param_index = 1 - print("entities = ", entities) - QUERY = f""" - UPDATE {self._get_table_name(table_name)} - SET - name = $1, - category = $2, - description = $3, - description_embedding = $4, - chunk_ids = $5, - metadata = $6, - updated_at = CURRENT_TIMESTAMP - WHERE id = $7 AND parent_id = $8 - RETURNING id - """ + if isinstance(metadata, str): + try: + metadata = json.loads(metadata) + except json.JSONDecodeError: + pass - for entity in entities: - metadata = entity.metadata - if isinstance(metadata, str): - try: - metadata = json.loads(metadata) - except json.JSONDecodeError: - pass + if name is not None: + update_fields.append(f"name = ${param_index}") + params.append(name) + param_index += 1 - description_embedding = entity.description_embedding - if isinstance(description_embedding, list): - description_embedding = str(description_embedding) + if description is not None: + update_fields.append(f"description = ${param_index}") + params.append(description) + param_index += 1 - params = [ - entity.name, - entity.category, - entity.description, - description_embedding, - entity.chunk_ids, - json.dumps(metadata) if metadata else None, - entity.id, - entity.parent_id, - ] - print("QUERY = ", QUERY) + if description_embedding is not None: + update_fields.append(f"description_embedding = ${param_index}") + params.append(description_embedding) + param_index += 1 + + if category is not None: + update_fields.append(f"category = ${param_index}") + params.append(category) + param_index += 1 + + if metadata is not None: + update_fields.append(f"metadata = ${param_index}") + params.append(json.dumps(metadata)) + param_index += 1 + + if not update_fields: + raise R2RException(status_code=400, message="No fields to update") + + update_fields.append("updated_at = NOW()") + params.append(entity_id) + query = f""" + UPDATE {self._get_table_name(table_name)} + SET {', '.join(update_fields)} + WHERE id = ${param_index}\ + RETURNING id, name, category, description, parent_id, chunk_ids, metadata + """ + try: result = await self.connection_manager.fetchrow_query( - QUERY, params + query=query, + params=params, ) - if not result: - raise R2RException( - f"Entity {entity.id} not found in {store_type} store or no permission to update", - 404, - ) - results.append(result["id"]) - - return results + return Entity( + id=result["id"], + name=result["name"], + category=result["category"], + description=result["description"], + parent_id=result["parent_id"], + chunk_ids=result["chunk_ids"], + metadata=result["metadata"], + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"An error occurred while updating the entity: {e}", + ) async def delete( self, parent_id: UUID, entity_ids: Optional[list[UUID]] = None, store_type: StoreType = StoreType.GRAPH, - ) -> list[UUID]: + ) -> None: """ Delete entities from the specified store. If entity_ids is not provided, deletes all entities for the given parent_id. @@ -344,6 +374,9 @@ async def delete( WHERE id = ANY($1) AND parent_id = $2 RETURNING id """ + print("QUERY = ", QUERY) + print("entity_ids = ", entity_ids) + print("parent_id = ", parent_id) results = await self.connection_manager.fetch_query( QUERY, [entity_ids, parent_id] ) @@ -356,8 +389,6 @@ async def delete( 404, ) - return [row["id"] for row in results] - class PostgresRelationshipHandler(RelationshipHandler): def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -436,65 +467,84 @@ async def create_tables(self) -> None: await self.connection_manager.execute_query(QUERY) async def create( - self, relationships: list[Relationship], store_type: StoreType - ) -> list[UUID]: - """Create multiple relationships in the specified store.""" + self, + subject: str, + subject_id: UUID, + predicate: str, + object: str, + object_id: UUID, + parent_id: UUID, + store_type: StoreType, + description: str | None = None, + weight: float | None = 1.0, + chunk_ids: Optional[list[UUID]] = None, + description_embedding: Optional[list[float] | str] = None, + metadata: Optional[dict[str, Any] | str] = None, + ) -> Relationship: + """Create a new relationship in the specified store.""" table_name = self._get_relationship_table_for_store(store_type) - values = [] - results = [] - for relationship in relationships: - metadata = relationship.metadata - if isinstance(metadata, str): - try: - metadata = json.loads(metadata) - except json.JSONDecodeError: - pass + if isinstance(metadata, str): + try: + metadata = json.loads(metadata) + except json.JSONDecodeError: + pass - description_embedding = relationship.description_embedding - if isinstance(description_embedding, list): - description_embedding = str(description_embedding) + if isinstance(description_embedding, list): + description_embedding = str(description_embedding) - value = ( - relationship.subject, - relationship.predicate, - relationship.object, - relationship.description, - relationship.subject_id, - relationship.object_id, - relationship.weight, - relationship.chunk_ids, - relationship.parent_id, - description_embedding, - json.dumps(metadata) if metadata else None, - ) - values.append(value) - - QUERY = f""" + query = f""" INSERT INTO {self._get_table_name(table_name)} (subject, predicate, object, description, subject_id, object_id, weight, chunk_ids, parent_id, description_embedding, metadata) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) - RETURNING id + RETURNING id, subject, predicate, object, description, subject_id, object_id, weight, chunk_ids, parent_id, metadata """ - for value in values: - result = await self.connection_manager.fetchrow_query(QUERY, value) - results.append(result["id"]) + params = [ + subject, + predicate, + object, + description, + subject_id, + object_id, + weight, + chunk_ids, + parent_id, + description_embedding, + json.dumps(metadata) if metadata else None, + ] + + result = await self.connection_manager.fetchrow_query( + query=query, + params=params, + ) - return results + return Relationship( + id=result["id"], + subject=result["subject"], + predicate=result["predicate"], + object=result["object"], + description=result["description"], + subject_id=result["subject_id"], + object_id=result["object_id"], + weight=result["weight"], + chunk_ids=result["chunk_ids"], + parent_id=result["parent_id"], + metadata=result["metadata"], + ) async def get( self, parent_id: UUID, store_type: StoreType, - offset: int = 0, - limit: int = 100, + offset: int, + limit: int, relationship_ids: Optional[list[UUID]] = None, entity_names: Optional[list[str]] = None, relationship_types: Optional[list[str]] = None, include_metadata: bool = False, - ) -> tuple[list[Relationship], int]: + ): """ Get relationships from the specified store. @@ -514,7 +564,7 @@ async def get( table_name = self._get_relationship_table_for_store(store_type) conditions = ["parent_id = $1"] - params = [parent_id] + params: list[Any] = [parent_id] param_index = 2 if relationship_ids: @@ -584,74 +634,122 @@ async def get( ) except json.JSONDecodeError: pass + elif not include_metadata: + relationship_dict.pop("metadata", None) relationships.append(Relationship(**relationship_dict)) return relationships, count async def update( - self, relationships: list[Relationship], store_type: StoreType - ) -> list[UUID]: + self, + relationship_id: UUID, + store_type: StoreType, + subject: Optional[str], + subject_id: Optional[UUID], + predicate: Optional[str], + object: Optional[str], + object_id: Optional[UUID], + description: Optional[str], + description_embedding: Optional[list[float] | str], + weight: Optional[float], + metadata: Optional[dict[str, Any] | str], + ) -> Relationship: """Update multiple relationships in the specified store.""" table_name = self._get_relationship_table_for_store(store_type) - results = [] + update_fields = [] + params: list = [] + param_index = 1 - QUERY = f""" - UPDATE {self._get_table_name(table_name)} - SET - subject = $1, - predicate = $2, - object = $3, - description = $4, - subject_id = $5, - object_id = $6, - weight = $7, - chunk_ids = $8, - metadata = $9, - updated_at = CURRENT_TIMESTAMP - WHERE id = $10 AND parent_id = $11 - RETURNING id - """ + if isinstance(metadata, str): + try: + metadata = json.loads(metadata) + except json.JSONDecodeError: + pass - for relationship in relationships: - metadata = relationship.metadata - if isinstance(metadata, str): - try: - metadata = json.loads(metadata) - except json.JSONDecodeError: - pass + if subject is not None: + update_fields.append(f"subject = ${param_index}") + params.append(subject) + param_index += 1 - params = [ - relationship.subject, - relationship.predicate, - relationship.object, - relationship.description, - relationship.subject_id, - relationship.object_id, - relationship.weight, - relationship.chunk_ids, - json.dumps(metadata) if metadata else None, - relationship.id, - relationship.parent_id, - ] + if subject_id is not None: + update_fields.append(f"subject_id = ${param_index}") + params.append(subject_id) + param_index += 1 + + if predicate is not None: + update_fields.append(f"predicate = ${param_index}") + params.append(predicate) + param_index += 1 + + if object is not None: + update_fields.append(f"object = ${param_index}") + params.append(object) + param_index += 1 + + if object_id is not None: + update_fields.append(f"object_id = ${param_index}") + params.append(object_id) + param_index += 1 + + if description is not None: + update_fields.append(f"description = ${param_index}") + params.append(description) + param_index += 1 + + if description_embedding is not None: + update_fields.append(f"description_embedding = ${param_index}") + params.append(description_embedding) + param_index += 1 + + if weight is not None: + update_fields.append(f"weight = ${param_index}") + params.append(weight) + param_index += 1 + + if not update_fields: + raise R2RException(status_code=400, message="No fields to update") + + update_fields.append("updated_at = NOW()") + params.append(relationship_id) + + query = f""" + UPDATE {self._get_table_name(table_name)} + SET {', '.join(update_fields)} + WHERE id = ${param_index} + RETURNING id, subject, predicate, object, description, subject_id, object_id, weight, chunk_ids, parent_id, metadata + """ + try: result = await self.connection_manager.fetchrow_query( - QUERY, params + query=query, + params=params, ) - if not result: - raise R2RException( - f"Relationship {relationship.id} not found in {store_type} store or no permission to update", - 404, - ) - results.append(result["id"]) - return results + return Relationship( + id=result["id"], + subject=result["subject"], + predicate=result["predicate"], + object=result["object"], + description=result["description"], + subject_id=result["subject_id"], + object_id=result["object_id"], + weight=result["weight"], + chunk_ids=result["chunk_ids"], + parent_id=result["parent_id"], + metadata=result["metadata"], + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"An error occurred while updating the relationship: {e}", + ) async def delete( self, parent_id: UUID, relationship_ids: Optional[list[UUID]] = None, store_type: StoreType = StoreType.GRAPH, - ) -> list[UUID]: + ) -> None: """ Delete relationships from the specified store. If relationship_ids is not provided, deletes all relationships for the given parent_id. @@ -695,8 +793,6 @@ async def delete( 404, ) - return [row["id"] for row in results] - class PostgresCommunityHandler(CommunityHandler): @@ -721,7 +817,7 @@ async def create_tables(self) -> None: node TEXT NOT NULL, cluster UUID NOT NULL, parent_cluster INT, - level INT NOT NULL, + level INT, is_final_cluster BOOLEAN NOT NULL, relationship_ids UUID[] NOT NULL, graph_id UUID, @@ -740,207 +836,275 @@ async def create_tables(self) -> None: graph_id UUID, collection_id UUID, community_id UUID, - level INT NOT NULL, + level INT, name TEXT NOT NULL, summary TEXT NOT NULL, - findings TEXT[] NOT NULL, - rating FLOAT NOT NULL, - rating_explanation TEXT NOT NULL, + findings TEXT[], + rating FLOAT, + rating_explanation TEXT, description_embedding {vector_column_str} NOT NULL, created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, metadata JSONB, UNIQUE (community_id, level, graph_id, collection_id) );""" - # created_by UUID REFERENCES {self._get_table_name("users")}(user_id), - # updated_by UUID REFERENCES {self._get_table_name("users")}(user_id), await self.connection_manager.execute_query(query) async def create( self, - graph_id: UUID, + parent_id: UUID, + store_type: StoreType, name: str, summary: str, - embedding: str, - findings: list[str], + findings: Optional[list[str]], rating: Optional[float], rating_explanation: Optional[str], - level: Optional[int], - metadata: Optional[dict], - auth_user: Any, - ) -> None: + description_embedding: Optional[list[float] | str] = None, + ) -> Community: + # Do we ever want to get communities from document store? + table_name = "graph_community" - if not auth_user.is_superuser: - if not await self._check_permissions(graph_id, auth_user.id): - raise R2RException( - "You do not have permission to create this community.", - 403, - ) + if isinstance(description_embedding, list): + description_embedding = str(description_embedding) - QUERY = f""" - INSERT INTO {self._get_table_name("graph_community")} - (graph_id, name, summary, findings, rating, rating_explanation, description_embedding, level, metadata, created_by, updated_by) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) - RETURNING id, graph_id, name, summary, findings, rating, rating_explanation, level, metadata, created_by, updated_by + query = f""" + INSERT INTO {self._get_table_name(table_name)} + (collection_id, name, summary, findings, rating, rating_explanation, description_embedding) + VALUES ($1, $2, $3, $4, $5, $6, $7) + RETURNING id, collection_id, name, summary, findings, rating, rating_explanation, created_at, updated_at """ params = [ - graph_id, + parent_id, name, summary, findings, rating, rating_explanation, - embedding, - level, - metadata, - auth_user.id, - auth_user.id, + description_embedding, ] - return await self.connection_manager.fetchrow_query(QUERY, params) + try: + result = await self.connection_manager.fetchrow_query( + query=query, + params=params, + ) + + return Community( + id=result["id"], + collection_id=result["collection_id"], + name=result["name"], + summary=result["summary"], + findings=result["findings"], + rating=result["rating"], + rating_explanation=result["rating_explanation"], + created_at=result["created_at"], + updated_at=result["updated_at"], + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"An error occurred while creating the community: {e}", + ) async def update( self, - id: UUID, community_id: UUID, - name: Optional[str], - summary: Optional[str], - embedding: Optional[str], - findings: Optional[list[str]], - rating: Optional[float], - rating_explanation: Optional[str], - level: Optional[int], - metadata: Optional[dict], - auth_user: Any, - ) -> None: - - if not auth_user.is_superuser: - if not await self._check_permissions(id, auth_user.id): - raise R2RException( - "You do not have permission to update this community.", - 403, - ) - + store_type: StoreType, + name: Optional[str] = None, + summary: Optional[str] = None, + summary_embedding: Optional[list[float] | str] = None, + findings: Optional[list[str]] = None, + rating: Optional[float] = None, + rating_explanation: Optional[str] = None, + ) -> Community: + table_name = "graph_community" update_fields = [] - params = [community_id] # type: ignore + params: list[Any] = [] + param_index = 1 + if name is not None: - update_fields.append(f"name = ${len(params)+1}") + update_fields.append(f"name = ${param_index}") params.append(name) + param_index += 1 if summary is not None: - update_fields.append(f"summary = ${len(params)+1}") + update_fields.append(f"summary = ${param_index}") params.append(summary) + param_index += 1 - if embedding is not None: - update_fields.append(f"description_embedding = ${len(params)+1}") - params.append(embedding) + if summary_embedding is not None: + update_fields.append(f"description_embedding = ${param_index}") + params.append(summary_embedding) + param_index += 1 if findings is not None: - update_fields.append(f"findings = ${len(params)+1}") + update_fields.append(f"findings = ${param_index}") params.append(findings) + param_index += 1 if rating is not None: - update_fields.append(f"rating = ${len(params)+1}") + update_fields.append(f"rating = ${param_index}") params.append(rating) + param_index += 1 if rating_explanation is not None: - update_fields.append(f"rating_explanation = ${len(params)+1}") + update_fields.append(f"rating_explanation = ${param_index}") params.append(rating_explanation) + param_index += 1 - if level is not None: - update_fields.append(f"level = ${len(params)+1}") - params.append(level) - - if metadata is not None: - update_fields.append(f"metadata = ${len(params)+1}") - params.append(metadata) - - update_fields.append(f"updated_by = ${len(params)+1}") - params.append(auth_user.id) + if not update_fields: + raise R2RException(status_code=400, message="No fields to update") - update_fields.append(f"updated_at = CURRENT_TIMESTAMP") + update_fields.append("updated_at = NOW()") + params.append(community_id) - QUERY = f""" - UPDATE {self._get_table_name("graph_community")} SET {", ".join(update_fields)} WHERE id = $1 - RETURNING id, graph_id, name, summary, findings, rating, rating_explanation, metadata, level, created_by, updated_by, updated_at + query = f""" + UPDATE {self._get_table_name(table_name)} + SET {", ".join(update_fields)} + WHERE id = ${param_index}\ + RETURNING id, community_id, name, summary, findings, rating, rating_explanation, created_at, updated_at """ - return await self.connection_manager.fetchrow_query(QUERY, params) + try: + result = await self.connection_manager.fetchrow_query( + query, params + ) + + return Community( + id=result["id"], + community_id=result["community_id"], + name=result["name"], + summary=result["summary"], + findings=result["findings"], + rating=result["rating"], + rating_explanation=result["rating_explanation"], + created_at=result["created_at"], + updated_at=result["updated_at"], + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"An error occurred while updating the community: {e}", + ) async def delete( - self, graph_id: UUID, community_id: UUID, auth_user: Any + self, + parent_id: UUID, + community_id: UUID, ) -> None: + table_name = "graph_community" - if not auth_user.is_superuser: - if not await self._check_permissions(graph_id, auth_user.id): - raise R2RException( - "You do not have permission to delete this community.", - 403, - ) + query = f""" + DELETE FROM {self._get_table_name(table_name)} + WHERE id = $1 AND collection_id = $2 + """ + print("query = ", query) + print("parent_id = ", parent_id) + print("community_id = ", community_id) + params = [community_id, parent_id] - QUERY = f""" - DELETE FROM {self._get_table_name("graph_community")} WHERE id = $1 + try: + results = await self.connection_manager.execute_query( + query, params + ) + print("results = ", results) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"An error occurred while deleting the community: {e}", + ) + table_name = "graph_community_info" + query = f""" + DELETE FROM {self._get_table_name(table_name)} + WHERE id = $1 AND collection_id = $2 """ - await self.connection_manager.execute_query(QUERY, [community_id]) + print("query = ", query) + print("parent_id = ", parent_id) + print("community_id = ", community_id) + params = [community_id, parent_id] + + try: + results = await self.connection_manager.execute_query( + query, params + ) + print("results = ", results) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"An error occurred while deleting the community: {e}", + ) async def get( self, - graph_id: UUID, + parent_id: UUID, + store_type: StoreType, offset: int, limit: int, - community_id: Optional[UUID] = None, - auth_user: Optional[Any] = None, + community_ids: Optional[list[UUID]] = None, + community_names: Optional[list[str]] = None, + include_embeddings: bool = False, ): + """Retrieve communities from the specified store.""" + # Do we ever want to get communities from document store? + table_name = "graph_community" - if not auth_user.is_superuser: - if not await self._check_permissions(graph_id, auth_user.id): - raise R2RException( - "You do not have permission to access this graph.", - 403, - ) + conditions = ["collection_id = $1"] + params: list[Any] = [parent_id] + param_index = 2 - if community_id is None: + if community_ids: + conditions.append(f"id = ANY(${param_index})") + params.append(community_ids) + param_index += 1 - QUERY = f""" - SELECT - id, graph_id, name, summary, findings, rating, rating_explanation, level, metadata, created_by, updated_by, created_at, updated_at - FROM {self._get_table_name("graph_community")} WHERE graph_id = $1 - OFFSET $2 LIMIT $3 - """ - params = [graph_id, offset, limit] - communities = [ - Community(**row) - for row in await self.connection_manager.fetch_query( - QUERY, params - ) - ] + if community_names: + conditions.append(f"name = ANY(${param_index})") + params.append(community_names) + param_index += 1 - QUERY_COUNT = f""" - SELECT COUNT(*) FROM {self._get_table_name("graph_community")} WHERE graph_id = $1 - """ - count = ( - await self.connection_manager.fetch_query( - QUERY_COUNT, [graph_id] - ) - )[0]["count"] + select_fields = """ + id, community_id, name, summary, findings, rating, + rating_explanation, level, created_at, updated_at + """ + if include_embeddings: + select_fields += ", description_embedding" - return communities, count + COUNT_QUERY = f""" + SELECT COUNT(*) + FROM {self._get_table_name(table_name)} + WHERE {' AND '.join(conditions)} + """ - else: - QUERY = f""" - SELECT - id, graph_id, name, summary, findings, rating, rating_explanation, level, metadata, created_by, updated_by, created_at, updated_at - FROM {self._get_table_name("graph_community")} WHERE graph_id = $1 AND id = $2 - """ - params = [graph_id, community_id] - return [ - Community( - **await self.connection_manager.fetchrow_query( - QUERY, params - ) - ) - ] + count = ( + await self.connection_manager.fetch_query( + COUNT_QUERY, params[: param_index - 1] + ) + )[0]["count"] + + QUERY = f""" + SELECT {select_fields} + FROM {self._get_table_name(table_name)} + WHERE {' AND '.join(conditions)} + ORDER BY created_at + OFFSET ${param_index} + """ + params.append(offset) + param_index += 1 + + if limit != -1: + QUERY += f" LIMIT ${param_index}" + params.append(limit) + + rows = await self.connection_manager.fetch_query(QUERY, params) + + communities = [] + for row in rows: + community_dict = dict(row) + + communities.append(Community(**community_dict)) + + return communities, count class PostgresGraphHandler(GraphHandler): @@ -1117,7 +1281,7 @@ async def reset(self, graph_id: UUID) -> None: # Delete all graph communities and community info community_delete_queries = [ f"""DELETE FROM {self._get_table_name("graph_community_info")} - WHERE graph_id = $1""", + WHERE collection_id = $1""", f"""DELETE FROM {self._get_table_name("graph_community")} WHERE collection_id = $1""", ] @@ -1668,58 +1832,6 @@ async def add_entities_v3( # return True - async def add_relationships_v3( - self, id: UUID, relationship_ids: list[UUID], copy_data: bool = True - ) -> bool: - """ - Add relationships to the graph. - """ - QUERY = f""" - UPDATE {self._get_table_name("relationship")} - SET graph_ids = array_append(graph_ids, $1) - WHERE id = ANY($2) - """ - await self.connection_manager.execute_query( - QUERY, [id, relationship_ids] - ) - - if copy_data: - QUERY = f""" - INSERT INTO {self._get_table_name("graph_relationship")} - SELECT * FROM {self._get_table_name("relationship")} - WHERE id = ANY($1) - """ - await self.connection_manager.execute_query( - QUERY, [relationship_ids] - ) - - return True - - async def remove_relationships( - self, id: UUID, relationship_ids: list[UUID], delete_data: bool = True - ) -> bool: - """ - Remove relationships from the graph. - """ - QUERY = f""" - UPDATE {self._get_table_name("relationship")} - SET graph_ids = array_remove(graph_ids, $1) - WHERE id = ANY($2) - """ - await self.connection_manager.execute_query( - QUERY, [id, relationship_ids] - ) - - if delete_data: - QUERY = f""" - DELETE FROM {self._get_table_name("graph_relationship")} WHERE id = ANY($1) - """ - await self.connection_manager.execute_query( - QUERY, [relationship_ids] - ) - - return True - async def update( self, graph_id: UUID, @@ -1751,7 +1863,7 @@ async def update( UPDATE {self._get_table_name("graph")} SET {', '.join(update_fields)} WHERE id = ${param_index} - RETURNING id, name, description, status, created_at, updated_at, collection_id, document_ids, + RETURNING id, name, description, status, created_at, updated_at, collection_id, document_ids """ try: @@ -2013,7 +2125,7 @@ async def get_deduplication_estimate( async def get_entities( self, - graph_id: UUID, + parent_id: UUID, offset: int, limit: int, entity_ids: Optional[list[UUID]] = None, @@ -2026,7 +2138,7 @@ async def get_entities( Args: offset: Number of records to skip limit: Maximum number of records to return (-1 for no limit) - graph_id: UUID of the graph + parent_id: UUID of the collection entity_ids: Optional list of entity IDs to filter by entity_names: Optional list of entity names to filter by include_embeddings: Whether to include embeddings in the response @@ -2035,7 +2147,7 @@ async def get_entities( Tuple of (list of entities, total count) """ conditions = ["parent_id = $1"] - params = [graph_id] + params: list[Any] = [parent_id] param_index = 2 if entity_ids: @@ -2100,6 +2212,7 @@ async def get_entities( ) except json.JSONDecodeError: pass + entities.append(Entity(**entity_dict)) return entities, count @@ -2157,7 +2270,6 @@ async def delete_node_via_document_id( # Execute separate DELETE queries delete_queries = [ - f"DELETE FROM {self._get_table_name('chunk_entity')} WHERE document_id = $1", f"DELETE FROM {self._get_table_name('relationship')} WHERE document_id = $1", f"DELETE FROM {self._get_table_name('entity')} WHERE document_id = $1", ] @@ -2195,30 +2307,6 @@ async def delete_node_via_document_id( return None return None - ##################### RELATIONSHIP METHODS ##################### - - # DEPRECATED - async def add_relationships( - self, - relationships: list[Relationship], - table_name: str = "relationship", - ): # type: ignore - """ - Upsert relationships into the relationship table. These are raw relationships extracted from the document. - - Args: - relationships: list[Relationship]: list of relationships to upsert - table_name: str: name of the table to upsert into - - Returns: - result: asyncpg.Record: result of the upsert operation - """ - return await _add_objects( - objects=[ele.to_dict() for ele in relationships], - full_table_name=self._get_table_name(table_name), - connection_manager=self.connection_manager, - ) - async def get_all_relationships( self, collection_id: UUID | None, @@ -2273,8 +2361,8 @@ async def get( self, parent_id: UUID, store_type: StoreType, - offset: int = 0, - limit: int = 100, + offset: int, + limit: int, entity_names: Optional[list[str]] = None, relationship_types: Optional[list[str]] = None, relationship_id: Optional[UUID] = None, @@ -2319,7 +2407,7 @@ async def get( # Build conditions and parameters for listing relationships conditions = ["parent_id = $1"] - params = [parent_id] + params: list[Any] = [parent_id] param_index = 2 if entity_names: @@ -2374,6 +2462,9 @@ async def get( ) except json.JSONDecodeError: pass + elif not include_metadata: + relationship_dict.pop("metadata", None) + relationships.append(Relationship(**relationship_dict)) return relationships, count @@ -2425,7 +2516,7 @@ async def has_document(self, graph_id: UUID, document_id: UUID) -> bool: # return [item["community_id"] for item in community_ids] async def check_communities_exist( - self, collection_id: UUID, community_ids: List[UUID] + self, collection_id: UUID, community_ids: list[UUID] ) -> Set[UUID]: """ Check which communities already exist in the database. @@ -2480,56 +2571,71 @@ async def add_community_info( async def get_communities( self, + parent_id: UUID, offset: int, limit: int, - collection_id: Optional[UUID] = None, - levels: Optional[list[int]] = None, community_ids: Optional[list[UUID]] = None, - ) -> dict: - conditions = [] - params: list = [collection_id] - param_index = 2 + include_embeddings: bool = False, + ) -> tuple[list[Community], int]: + """ + Get communities for a graph. - if levels is not None: - conditions.append(f"level = ANY(${param_index})") - params.append(levels) - param_index += 1 + Args: + collection_id: UUID of the collection + offset: Number of records to skip + limit: Maximum number of records to return (-1 for no limit) + community_ids: Optional list of community IDs to filter by + include_embeddings: Whether to include embeddings in the response + + Returns: + Tuple of (list of communities, total count) + """ + conditions = ["collection_id = $1"] + params: list[Any] = [parent_id] + param_index = 2 - if community_ids is not None: - conditions.append(f"community_id = ANY(${param_index})") + if community_ids: + conditions.append(f"id = ANY(${param_index})") params.append(community_ids) param_index += 1 - pagination_params = [] - if offset: - pagination_params.append(f"OFFSET ${param_index}") - params.append(offset) - param_index += 1 + select_fields = """ + id, collection_id, name, summary, findings, rating, rating_explanation + """ + if include_embeddings: + select_fields += ", description_embedding" + + COUNT_QUERY = f""" + SELECT COUNT(*) + FROM {self._get_table_name("graph_community")} + WHERE {' AND '.join(conditions)} + """ + count = ( + await self.connection_manager.fetch_query(COUNT_QUERY, params) + )[0]["count"] + + QUERY = f""" + SELECT {select_fields} + FROM {self._get_table_name("graph_community")} + WHERE {' AND '.join(conditions)} + ORDER BY created_at + OFFSET ${param_index} + """ + params.append(offset) + param_index += 1 if limit != -1: - pagination_params.append(f"LIMIT ${param_index}") + QUERY += f" LIMIT ${param_index}" params.append(limit) - param_index += 1 - pagination_clause = " ".join(pagination_params) + rows = await self.connection_manager.fetch_query(QUERY, params) - query = f""" - SELECT id, community_id, collection_id, level, name, summary, findings, rating, rating_explanation, COUNT(*) OVER() AS total_entries - FROM {self._get_table_name('graph_community')} - WHERE collection_id = $1 - {" AND " + " AND ".join(conditions) if conditions else ""} - ORDER BY community_id - {pagination_clause} - """ + communities = [] + for row in rows: + community_dict = dict(row) + communities.append(Community(**community_dict)) - results = await self.connection_manager.fetch_query(query, params) - total_entries = results[0]["total_entries"] if results else 0 - communities = [Community(**community) for community in results] - - return { - "communities": communities, - "total_entries": total_entries, - } + return communities, count async def get_community_details( self, @@ -2659,7 +2765,6 @@ async def delete_graph_for_collection( # TODO: make these queries more efficient. Pass the document_ids as params. if cascade: DELETE_QUERIES += [ - f"DELETE FROM {self._get_table_name('chunk_entity')} WHERE document_id = ANY($1::uuid[]);", f"DELETE FROM {self._get_table_name('relationship')} WHERE document_id = ANY($1::uuid[]);", f"DELETE FROM {self._get_table_name('entity')} WHERE document_id = ANY($1::uuid[]);", f"DELETE FROM {self._get_table_name('graph_entity')} WHERE collection_id = $1;", @@ -2693,10 +2798,8 @@ async def delete_graph_for_collection( async def perform_graph_clustering( self, - collection_id: UUID | None, - # graph_id: UUID | None, + collection_id: UUID, leiden_params: dict[str, Any], - use_community_cache: bool = False, ) -> int: """ Leiden clustering algorithm to cluster the knowledge graph relationships into communities. @@ -2715,8 +2818,6 @@ async def perform_graph_clustering( check_directed: bool = True, """ - start_time = time.time() - # # relationships = await self.get_all_relationships( # # collection_id, collection_id # , graph_id # # ) @@ -2772,7 +2873,7 @@ async def perform_graph_clustering( # relationship_ids_cache, leiden_params, collection_id # ) # else: - num_communities = await self._cluster_and_add_community_info( + return await self._cluster_and_add_community_info( relationships=relationships, relationship_ids_cache=relationship_ids_cache, leiden_params=leiden_params, @@ -2780,10 +2881,6 @@ async def perform_graph_clustering( # graph_id=collection_id, ) - return num_communities - - ####################### MANAGEMENT METHODS ####################### - async def get_entity_map( self, offset: int, limit: int, document_id: UUID ) -> dict[str, dict[str, list[dict[str, Any]]]]: @@ -2852,64 +2949,6 @@ async def get_entity_map( return entity_map - async def get_graph_status(self, collection_id: UUID) -> dict: - # check document_info table for the documents in the collection and return the status of each document - kg_extraction_statuses = await self.connection_manager.fetch_query( - f"SELECT document_id, extraction_status FROM {self._get_table_name('document_info')} WHERE collection_id = $1", - [collection_id], - ) - - document_ids = [ - doc_id["document_id"] for doc_id in kg_extraction_statuses - ] - - graph_cluster_statuses = await self.connection_manager.fetch_query( - f"SELECT enrichment_status FROM {self._get_table_name(PostgresCollectionHandler.TABLE_NAME)} WHERE id = $1", - [collection_id], - ) - - # entity and relationship counts - chunk_entity_count = await self.connection_manager.fetch_query( - f"SELECT COUNT(*) FROM {self._get_table_name('chunk_entity')} WHERE document_id = ANY($1)", - [document_ids], - ) - - relationship_count = await self.connection_manager.fetch_query( - f"SELECT COUNT(*) FROM {self._get_table_name('relationship')} WHERE document_id = ANY($1)", - [document_ids], - ) - - entity_count = await self.connection_manager.fetch_query( - f"SELECT COUNT(*) FROM {self._get_table_name('entity')} WHERE document_id = ANY($1)", - [document_ids], - ) - - graph_entity_count = await self.connection_manager.fetch_query( - f"SELECT COUNT(*) FROM {self._get_table_name('graph_entity')} WHERE collection_id = $1", - [collection_id], - ) - - community_count = await self.connection_manager.fetch_query( - f"SELECT COUNT(*) FROM {self._get_table_name('community')} WHERE collection_id = $1", - [collection_id], - ) - - return { - "kg_extraction_statuses": kg_extraction_statuses, - "graph_cluster_status": graph_cluster_statuses[0][ - "enrichment_status" - ], - "chunk_entity_count": chunk_entity_count[0]["count"], - "relationship_count": relationship_count[0]["count"], - "entity_count": entity_count[0]["count"], - "graph_entity_count": graph_entity_count[0]["count"], - "community_count": community_count[0]["count"], - } - - ####################### ESTIMATION METHODS ####################### - - ####################### GRAPH SEARCH METHODS ####################### - def _build_filters( self, filters: dict, parameters: list[Union[str, int, bytes]] ) -> str: @@ -3090,8 +3129,6 @@ async def graph_search( print("output = ", output) yield output - ####################### GRAPH CLUSTERING METHODS ####################### - async def _create_graph_and_cluster( self, relationships: list[Relationship], leiden_params: dict[str, Any] ) -> Any: @@ -3107,11 +3144,7 @@ async def _create_graph_and_cluster( logger.info(f"Graph has {len(G.nodes)} nodes and {len(G.edges)} edges") - hierarchical_communities = await self._compute_leiden_communities( - G, leiden_params - ) - - return hierarchical_communities + return await self._compute_leiden_communities(G, leiden_params) async def _cluster_and_add_community_info( self, @@ -3395,8 +3428,6 @@ async def _compute_leiden_communities( except ImportError as e: raise ImportError("Please install the graspologic package.") from e - ####################### UTILITY METHODS ####################### - async def get_existing_document_entity_chunk_ids( self, document_id: UUID ) -> list[str]: @@ -3410,23 +3441,6 @@ async def get_existing_document_entity_chunk_ids( ) ] - async def create_vector_index(self): - # need to implement this. Just call vector db provider's create_vector_index method. - # this needs to be run periodically for every collection. - raise NotImplementedError - - async def structured_query(self): - raise NotImplementedError - - async def update_extraction_prompt(self): - raise NotImplementedError - - async def update_kg_search_prompt(self): - raise NotImplementedError - - async def upsert_relationships(self): - raise NotImplementedError - async def get_entity_count( self, collection_id: Optional[UUID] = None, @@ -3475,41 +3489,6 @@ async def get_entity_count( "count" ] - async def get_relationship_count( - self, - collection_id: Optional[UUID] = None, - document_id: Optional[UUID] = None, - ) -> int: - if collection_id is None and document_id is None: - raise ValueError( - "Either collection_id or document_id must be provided." - ) - - conditions = [] - params = [] - - if collection_id: - conditions.append( - f""" - document_id = ANY( - SELECT document_id FROM {self._get_table_name("document_info")} - WHERE $1 = ANY(collection_ids) - ) - """ - ) - params.append(str(collection_id)) - else: - conditions.append("document_id = $1") - params.append(str(document_id)) - - QUERY = f""" - SELECT COUNT(*) FROM {self._get_table_name("relationship")} - WHERE {" AND ".join(conditions)} - """ - return (await self.connection_manager.fetch_query(QUERY, params))[0][ - "count" - ] - async def update_entity_descriptions(self, entities: list[Entity]): query = f""" @@ -3521,7 +3500,7 @@ async def update_entity_descriptions(self, entities: list[Entity]): inputs = [ ( entity.name, - entity.graph_id, + entity.parent_id, entity.description, entity.description_embedding, ) @@ -3530,8 +3509,6 @@ async def update_entity_descriptions(self, entities: list[Entity]): await self.connection_manager.execute_many(query, inputs) # type: ignore - ####################### PRIVATE METHODS ########################## - def _json_serialize(obj): if isinstance(obj, UUID): diff --git a/py/sdk/models.py b/py/sdk/models.py index 197a17397..2891c4216 100644 --- a/py/sdk/models.py +++ b/py/sdk/models.py @@ -12,7 +12,6 @@ KGGlobalResult, KGRelationshipResult, KGRunType, - KGSearchMethod, KGSearchResultType, Message, MessageType, @@ -23,9 +22,7 @@ ) from shared.api.models import ( CombinedSearchResponse, - KGCreationResponse, KGEnrichmentResponse, - KGEntityDeduplicationResponse, RAGResponse, UserResponse, ) @@ -40,7 +37,6 @@ "KGGlobalResult", "KGRelationshipResult", "KGRunType", - "KGSearchMethod", "GraphSearchResult", "KGSearchResultType", "GraphSearchSettings", @@ -52,8 +48,6 @@ "ChunkSearchResult", "SearchSettings", "KGEntityDeduplicationSettings", - "KGEntityDeduplicationResponse", - "KGCreationResponse", "KGEnrichmentResponse", "RAGResponse", "CombinedSearchResponse", diff --git a/py/sdk/v2/kg.py b/py/sdk/v2/kg.py index 6d930b812..e3c8412d3 100644 --- a/py/sdk/v2/kg.py +++ b/py/sdk/v2/kg.py @@ -4,7 +4,6 @@ from ..models import ( KGCreationSettings, KGEnrichmentSettings, - KGEntityDeduplicationResponse, KGEntityDeduplicationSettings, KGRunType, ) @@ -216,7 +215,7 @@ async def deduplicate_entities( deduplication_settings: Optional[ Union[dict, KGEntityDeduplicationSettings] ] = None, - ) -> KGEntityDeduplicationResponse: + ): """ Deduplicate entities in the knowledge graph. Args: diff --git a/py/sdk/v2/sync_kg.py b/py/sdk/v2/sync_kg.py index bd48898d8..d0c1cd88b 100644 --- a/py/sdk/v2/sync_kg.py +++ b/py/sdk/v2/sync_kg.py @@ -4,7 +4,6 @@ from ..models import ( KGCreationSettings, KGEnrichmentSettings, - KGEntityDeduplicationResponse, KGEntityDeduplicationSettings, KGRunType, ) @@ -216,7 +215,7 @@ def deduplicate_entities( deduplication_settings: Optional[ Union[dict, KGEntityDeduplicationSettings] ] = None, - ) -> KGEntityDeduplicationResponse: + ): """ Deduplicate entities in the knowledge graph. Args: diff --git a/py/sdk/v3/collections.py b/py/sdk/v3/collections.py index 2b4b05e29..ef2ac9fd1 100644 --- a/py/sdk/v3/collections.py +++ b/py/sdk/v3/collections.py @@ -36,7 +36,7 @@ async def create( return await self.client._make_request( "POST", "collections", - json=data, # {"config": data} + json=data, version="v3", ) @@ -111,7 +111,7 @@ async def update( return await self.client._make_request( "POST", f"collections/{str(id)}", - json=data, # {"config": data} + json=data, version="v3", ) diff --git a/py/sdk/v3/conversations.py b/py/sdk/v3/conversations.py index bc6b0aad6..4fa3887a8 100644 --- a/py/sdk/v3/conversations.py +++ b/py/sdk/v3/conversations.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional from uuid import UUID from shared.api.models.base import WrappedBooleanResponse @@ -125,7 +125,7 @@ async def add_message( Returns: dict: Result of the operation, including the new message ID """ - data = { + data: dict[str, Any] = { "content": content, "role": role, } diff --git a/py/sdk/v3/documents.py b/py/sdk/v3/documents.py index 0050dab99..9fdbbc2bf 100644 --- a/py/sdk/v3/documents.py +++ b/py/sdk/v3/documents.py @@ -324,3 +324,23 @@ async def delete_by_filter( params={"filters": filters_json}, version="v3", ) + + async def extract( + self, + id: str | UUID, + run_type: Optional[str] = "estimate", + run_with_orchestration: Optional[bool] = True, + ): + data = {} + + if run_type: + data["run_type"] = run_type + if run_with_orchestration is not None: + data["run_with_orchestration"] = str(run_with_orchestration) + + return await self.client._make_request( + "POST", + f"documents/{str(id)}/extract", + data=data, + version="v3", + ) diff --git a/py/sdk/v3/graphs.py b/py/sdk/v3/graphs.py index 99905029c..c5732ae7b 100644 --- a/py/sdk/v3/graphs.py +++ b/py/sdk/v3/graphs.py @@ -1,9 +1,19 @@ -from typing import Optional +from typing import Any, Optional from uuid import UUID -from core.base.abstractions import DataLevel, KGRunType +from shared.api.models.base import WrappedBooleanResponse +from shared.api.models.kg.responses import ( + WrappedCommunitiesResponse, + WrappedCommunityResponse, + WrappedEntitiesResponse, + WrappedEntityResponse, + WrappedGraphResponse, + WrappedGraphsResponse, + WrappedRelationshipResponse, + WrappedRelationshipsResponse, +) -from ..models import KGCreationSettings, KGRunType +_list = list # Required for type hinting since we have a list method class GraphsSDK: @@ -14,254 +24,206 @@ class GraphsSDK: def __init__(self, client): self.client = client - async def create( + async def list( self, - collection_id: str | UUID, - run_type: Optional[str | KGRunType] = None, - settings: Optional[dict | KGCreationSettings] = None, - run_with_orchestration: Optional[bool] = True, - ): + collection_ids: Optional[list[str | UUID]] = None, + offset: Optional[int] = 0, + limit: Optional[int] = 100, + ) -> WrappedGraphsResponse: """ - Create a new knowledge graph for a collection. + List graphs with pagination and filtering options. Args: - collection_id (str | UUID): Collection ID to create graph for - settings (Optional[dict]): Graph creation settings - run_with_orchestration (Optional[bool]): Whether to run with task orchestration + ids (Optional[list[str | UUID]]): Filter graphs by ids + offset (int, optional): Specifies the number of objects to skip. Defaults to 0. + limit (int, optional): Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. Returns: - WrappedKGCreationResponse: Creation results + dict: List of graphs and pagination information """ - if isinstance(settings, KGCreationSettings): - settings = settings.model_dump() - - data = { - # "collection_id": str(collection_id) if collection_id else None, - "run_type": str(run_type) if run_type else None, - "settings": settings or {}, - "run_with_orchestration": run_with_orchestration or True, + params: dict = { + "offset": offset, + "limit": limit, } + if collection_ids: + params["collection_ids"] = collection_ids - return await self.client._make_request("POST", f"graphs/{collection_id}", json=data) # type: ignore + return await self.client._make_request( + "GET", "graphs", params=params, version="v3" + ) - async def get_status(self, collection_id: str | UUID) -> dict: + async def retrieve( + self, + collection_id: str | UUID, + ) -> WrappedGraphResponse: """ - Get the status of a graph. + Get detailed information about a specific graph. Args: - collection_id (str | UUID): Collection ID to get graph status for + collection_id (str | UUID): Graph ID to retrieve Returns: - dict: Graph status information + dict: Detailed graph information """ return await self.client._make_request( - "GET", f"graphs/{str(collection_id)}" + "GET", f"graphs/{str(collection_id)}", version="v3" ) - async def delete( + async def reset( self, collection_id: str | UUID, - cascade: bool = False, - ) -> dict: + ) -> WrappedBooleanResponse: """ - Delete a graph. + Deletes a graph and all its associated data. + + This endpoint permanently removes the specified graph along with all + entities and relationships that belong to only this graph. + + Entities and relationships extracted from documents are not deleted. Args: - collection_id (str | UUID): Collection ID of graph to delete - cascade (bool): Whether to delete associated entities and relationships + collection_id (str | UUID): Graph ID to reset Returns: - dict: Deletion confirmation + dict: Success message """ - params = {"cascade": cascade} return await self.client._make_request( - "DELETE", f"graphs/{str(collection_id)}", params=params + "POST", f"graphs/{str(collection_id)}/reset", version="v3" ) - # Entity operations - async def create_entity( + async def update( self, collection_id: str | UUID, - entity: dict, - ) -> dict: + name: Optional[str] = None, + description: Optional[str] = None, + ) -> WrappedGraphResponse: """ - Create a new entity in the graph. + Update graph information. Args: - collection_id (str | UUID): Collection ID to create entity in - entity (dict): Entity data including name, type, and metadata + collection_id (str | UUID): The collection ID corresponding to the graph + name (Optional[str]): Optional new name for the graph + description (Optional[str]): Optional new description for the graph Returns: - dict: Created entity information + dict: Updated graph information """ + data = {} + if name is not None: + data["name"] = name + if description is not None: + data["description"] = description + return await self.client._make_request( "POST", - f"graphs/{str(collection_id)}/entities", - json=entity, + f"graphs/{str(collection_id)}", + json=data, version="v3", ) - async def get_entity( + # TODO: create entity + + async def list_entities( self, collection_id: str | UUID, - entity_id: str | int, - include_embeddings: bool = False, - ) -> dict: + offset: Optional[int] = 0, + limit: Optional[int] = 100, + ) -> WrappedEntitiesResponse: """ - Get details of a specific entity. + List entities in a graph. Args: - collection_id (str | UUID): Collection ID containing the entity - entity_id (str | UUID): Entity ID to retrieve - include_embeddings (bool): Whether to include vector embeddings + collection_id (str | UUID): Graph ID to list entities from + offset (int, optional): Specifies the number of objects to skip. Defaults to 0. + limit (int, optional): Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. Returns: - dict: Entity details + dict: List of entities and pagination information """ - params = {"include_embeddings": include_embeddings} + params: dict = { + "offset": offset, + "limit": limit, + } + return await self.client._make_request( "GET", - f"graphs/{str(collection_id)}/entities/{str(entity_id)}", + f"graphs/{str(collection_id)}/entities", params=params, version="v3", ) - async def update_entity( + async def get_entity( self, collection_id: str | UUID, entity_id: str | UUID, - entity_update: dict, - ) -> dict: + ) -> WrappedEntityResponse: """ - Update an existing entity. + Get entity information in a graph. Args: - collection_id (str | UUID): Collection ID containing the entity - entity_id (str | UUID): Entity ID to update - entity_update (dict): Updated entity data + collection_id (str | UUID): The collection ID corresponding to the graph + entity_id (str | UUID): Entity ID to get from the graph Returns: - dict: Updated entity information + dict: Entity information """ return await self.client._make_request( - "POST", + "GET", f"graphs/{str(collection_id)}/entities/{str(entity_id)}", - json=entity_update, version="v3", ) - async def delete_entity( + # TODO: update entity + + async def remove_entity( self, collection_id: str | UUID, entity_id: str | UUID, - cascade: bool = False, - ) -> dict: + ) -> WrappedBooleanResponse: """ - Delete an entity. + Remove an entity from a graph. Args: - collection_id (str | UUID): Collection ID containing the entity - entity_id (str | UUID): Entity ID to delete - cascade (bool): Whether to delete related relationships + collection_id (str | UUID): The collection ID corresponding to the graph + entity_id (str | UUID): Entity ID to remove from the graph Returns: - dict: Deletion confirmation + dict: Success message """ - params = {"cascade": cascade} return await self.client._make_request( "DELETE", f"graphs/{str(collection_id)}/entities/{str(entity_id)}", - params=params, version="v3", ) - async def list_entities( + # TODO: create relationship + + async def list_relationships( self, collection_id: str | UUID, - level=DataLevel.DOCUMENT, - include_embeddings: bool = False, offset: Optional[int] = 0, limit: Optional[int] = 100, - ) -> dict: + ) -> WrappedRelationshipsResponse: """ - List entities in the graph. + List relationships in a graph. Args: - collection_id (str | UUID): Collection ID to list entities from - level (DataLevel): Entity level filter - include_embeddings (bool): Whether to include vector embeddings + collection_id (str | UUID): The collection ID corresponding to the graph offset (int, optional): Specifies the number of objects to skip. Defaults to 0. limit (int, optional): Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. Returns: - dict: List of entities and pagination information + dict: List of relationships and pagination information """ - params = { - "level": level, + params: dict = { "offset": offset, "limit": limit, - "include_embeddings": include_embeddings, } - return await self.client._make_request( - "GET", - f"graphs/{str(collection_id)}/entities", - params=params, - version="v3", - ) - - async def deduplicate_entities( - self, - collection_id: str | UUID, - settings: Optional[dict] = None, - run_type: str = "ESTIMATE", - run_with_orchestration: bool = True, - ): - """ - Deduplicate entities in the graph. - - Args: - collection_id (str | UUID): Collection ID to deduplicate entities in - settings (Optional[dict]): Deduplication settings - run_type (str): Whether to estimate cost or run deduplication - run_with_orchestration (bool): Whether to run with task orchestration - - Returns: - WrappedKGEntityDeduplicationResponse: Deduplication results or cost estimate - """ - params = { - "run_type": run_type, - "run_with_orchestration": run_with_orchestration, - } - data = {} - if settings: - data["settings"] = settings return await self.client._make_request( - "POST", - f"graphs/{str(collection_id)}/entities/deduplicate", - json=data, - params=params, - version="v3", - ) - - # Relationship operations - async def create_relationship( - self, collection_id: str | UUID, relationship: dict - ) -> dict: - """ - Create a new relationship between entities. - - Args: - collection_id (str | UUID): Collection ID to create relationship in - relationship (dict): Relationship data including source, target, and type - - Returns: - dict: Created relationship information - """ - return await self.client._make_request( - "POST", + "GET", f"graphs/{str(collection_id)}/relationships", - json=relationship, + params=params, version="v3", ) @@ -269,16 +231,16 @@ async def get_relationship( self, collection_id: str | UUID, relationship_id: str | UUID, - ) -> dict: + ) -> WrappedRelationshipResponse: """ - Get details of a specific relationship. + Get relationship information in a graph. Args: - collection_id (str | UUID): Collection ID containing the relationship - relationship_id (str | UUID): Relationship ID to retrieve + collection_id (str | UUID): The collection ID corresponding to the graph + relationship_id (str | UUID): Relationship ID to get from the graph Returns: - dict: Relationship details + dict: Relationship information """ return await self.client._make_request( "GET", @@ -286,123 +248,88 @@ async def get_relationship( version="v3", ) - async def update_relationship( + # TODO: update relationship + + async def remove_relationship( self, collection_id: str | UUID, relationship_id: str | UUID, - relationship_update: dict, - ) -> dict: + ) -> WrappedBooleanResponse: """ - Update an existing relationship. + Remove a relationship from a graph. Args: - collection_id (str | UUID): Collection ID containing the relationship - relationship_id (str | UUID): Relationship ID to update - relationship_update (dict): Updated relationship data + collection_id (str | UUID): The collection ID corresponding to the graph + relationship_id (str | UUID): Relationship ID to remove from the graph Returns: - dict: Updated relationship information + dict: Success message """ return await self.client._make_request( - "POST", + "DELETE", f"graphs/{str(collection_id)}/relationships/{str(relationship_id)}", - json=relationship_update, version="v3", ) - async def delete_relationship( + async def build( self, collection_id: str | UUID, - relationship_id: str | UUID, - ) -> dict: + settings: dict, + run_type: str = "estimate", + run_with_orchestration: bool = True, + ) -> WrappedBooleanResponse: """ - Delete a relationship. + Build a graph. Args: - collection_id (str | UUID): Collection ID containing the relationship - relationship_id (str | UUID): Relationship ID to delete + collection_id (str | UUID): The collection ID corresponding to the graph + settings (dict): Settings for the build + run_type (str, optional): Type of build to run. Defaults to "estimate". + run_with_orchestration (bool, optional): Whether to run with orchestration. Defaults to True. Returns: - dict: Deletion confirmation + dict: Success message """ + data = { + "settings": settings, + "run_type": run_type, + "run_with_orchestration": run_with_orchestration, + } + return await self.client._make_request( - "DELETE", - f"graphs/{str(collection_id)}/relationships/{str(relationship_id)}", + "POST", + f"graphs/{str(collection_id)}/build", + json=data, version="v3", ) - async def list_relationships( + # TODO: create community + + async def list_communities( self, collection_id: str | UUID, - source_id: Optional[str | UUID] = None, - target_id: Optional[str | UUID] = None, - relationship_type: Optional[str] = None, offset: Optional[int] = 0, limit: Optional[int] = 100, - ) -> dict: + ) -> WrappedCommunitiesResponse: """ - List relationships in the graph. + List communities in a graph. Args: - collection_id (str | UUID): Collection ID to list relationships from - source_id (Optional[str | UUID]): Filter by source entity - target_id (Optional[str | UUID]): Filter by target entity - relationship_type (Optional[str]): Filter by relationship type + collection_id (str | UUID): The collection ID corresponding to the graph offset (int, optional): Specifies the number of objects to skip. Defaults to 0. limit (int, optional): Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. Returns: - dict: List of relationships and pagination information + dict: List of communities and pagination information """ - params = { + params: dict = { "offset": offset, "limit": limit, } - if source_id: - params["source_id"] = str(source_id) - if target_id: - params["target_id"] = str(target_id) - if relationship_type: - params["relationship_type"] = relationship_type return await self.client._make_request( "GET", - f"graphs/{str(collection_id)}/relationships", - params=params, - version="v3", - ) - - # Community operations - async def create_communities( - self, - collection_id: str | UUID, - run_type: Optional[str | KGRunType] = None, - settings: Optional[dict] = None, - run_with_orchestration: bool = True, - ): # -> WrappedKGCommunitiesResponse: - """ - Create communities in the graph. - - Args: - collection_id (str | UUID): Collection ID to create communities in - settings (Optional[dict]): Community detection settings - run_with_orchestration (bool): Whether to run with task orchestration - - Returns: - WrappedKGCommunitiesResponse: Community creation results - """ - params = {"run_with_orchestration": run_with_orchestration} - data = {} - if settings: - data["settings"] = settings - - if run_type: - data["run_type"] = str(run_type) - - return await self.client._make_request( - "POST", f"graphs/{str(collection_id)}/communities", - json=data, params=params, version="v3", ) @@ -411,16 +338,16 @@ async def get_community( self, collection_id: str | UUID, community_id: str | UUID, - ) -> dict: + ) -> WrappedCommunityResponse: """ - Get details of a specific community. + Get community information in a graph. Args: - collection_id (str | UUID): Collection ID containing the community - community_id (str | UUID): Community ID to retrieve + collection_id (str | UUID): The collection ID corresponding to the graph + community_id (str | UUID): Community ID to get from the graph Returns: - dict: Community details + dict: Community information """ return await self.client._make_request( "GET", @@ -432,56 +359,51 @@ async def update_community( self, collection_id: str | UUID, community_id: str | UUID, - community_update: dict, - ) -> dict: + name: Optional[str] = None, + summary: Optional[str] = None, + findings: Optional[_list[str]] = None, + rating: Optional[int] = None, + rating_explanation: Optional[str] = None, + level: Optional[int] = None, + attributes: Optional[dict] = None, + ) -> WrappedCommunityResponse: """ - Update a community. + Update community information. Args: - collection_id (str | UUID): Collection ID containing the community + collection_id (str | UUID): The collection ID corresponding to the graph community_id (str | UUID): Community ID to update - community_update (dict): Updated community data + name (Optional[str]): Optional new name for the community + summary (Optional[str]): Optional new summary for the community + findings (Optional[list[str]]): Optional new findings for the community + rating (Optional[int]): Optional new rating for the community + rating_explanation (Optional[str]): Optional new rating explanation for the community + level (Optional[int]): Optional new level for the community + attributes (Optional[dict]): Optional new attributes for the community Returns: dict: Updated community information """ - return await self.client._make_request( - "POST", - f"graphs/{str(collection_id)}/communities/{str(community_id)}", - json=community_update, - version="v3", - ) - - async def list_communities( - self, - collection_id: str | UUID, - level: Optional[int] = None, - offset: Optional[int] = 0, - limit: Optional[int] = 100, - ) -> dict: - """ - List communities in the graph. - - Args: - collection_id (str | UUID): Collection ID to list communities from - level (Optional[int]): Filter by community level - offset (int, optional): Specifies the number of objects to skip. Defaults to 0. - limit (int, optional): Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100. - - Returns: - dict: List of communities and pagination information - """ - params = { - "offset": offset, - "limit": limit, - } + data: dict[str, Any] = {} + if name is not None: + data["name"] = name + if summary is not None: + data["summary"] = summary + if findings is not None: + data["findings"] = findings + if rating is not None: + data["rating"] = str(rating) + if rating_explanation is not None: + data["rating_explanation"] = rating_explanation if level is not None: - params["level"] = level + data["level"] = level + if attributes is not None: + data["attributes"] = attributes return await self.client._make_request( - "GET", - f"graphs/{str(collection_id)}/communities", - params=params, + "POST", + f"graphs/{str(collection_id)}/communities/{str(community_id)}", + json=data, version="v3", ) @@ -489,16 +411,16 @@ async def delete_community( self, collection_id: str | UUID, community_id: str | UUID, - ) -> dict: + ) -> WrappedBooleanResponse: """ - Delete a specific community. + Remove a community from a graph. Args: - collection_id (str | UUID): Collection ID containing the community - community_id (str | UUID): Community ID to delete + collection_id (str | UUID): The collection ID corresponding to the graph + community_id (str | UUID): Community ID to remove from the graph Returns: - dict: Deletion confirmation + dict: Success message """ return await self.client._make_request( "DELETE", @@ -506,67 +428,50 @@ async def delete_community( version="v3", ) - async def delete_communities( + async def pull( self, collection_id: str | UUID, - level: Optional[int] = None, - ) -> dict: + ) -> WrappedBooleanResponse: """ - Delete communities from the graph. + Adds documents to a graph by copying their entities and relationships. - Args: - collection_id (str | UUID): Collection ID to delete communities from - level (Optional[int]): Specific level to delete, or None for all levels + This endpoint: + 1. Copies document entities to the graph_entity table + 2. Copies document relationships to the graph_relationship table + 3. Associates the documents with the graph - Returns: - dict: Deletion confirmation - """ - params = {} - if level is not None: - params["level"] = level + When a document is added: + - Its entities and relationships are copied to graph-specific tables + - Existing entities/relationships are updated by merging their properties + - The document ID is recorded in the graph's document_ids array + Documents added to a graph will contribute their knowledge to: + - Graph analysis and querying + - Community detection + - Knowledge graph enrichment + """ return await self.client._make_request( - "DELETE", - f"graphs/{str(collection_id)}/communities", - params=params, + "POST", + f"graphs/{str(collection_id)}/pull", version="v3", ) - async def tune_prompt( + async def remove_document( self, collection_id: str | UUID, - prompt_name: str, - documents_offset: Optional[int] = 0, - documents_limit: Optional[int] = 100, - chunks_offset: Optional[int] = 0, - chunks_limit: Optional[int] = 100, - ): # -> WrappedKGTunePromptResponse: + document_id: str | UUID, + ) -> WrappedBooleanResponse: """ - Tune a graph-related prompt using collection data. + Removes a document from a graph and removes any associated entities - Args: - collection_id (Union[str, UUID]): Collection ID to tune prompt for - prompt_name (str): Name of prompt to tune (graphrag_relationships_extraction_few_shot, - graphrag_entity_description, or graphrag_communities) - documents_offset (int): Document pagination offset - documents_limit (int): Maximum number of documents to use - chunks_offset (int): Chunk pagination offset - chunks_limit (int): Maximum number of chunks to use + This endpoint: + 1. Removes the document ID from the graph's document_ids array + 2. Optionally deletes the document's copied entities and relationships - Returns: - WrappedKGTunePromptResponse: Tuned prompt results + The user must have access to both the graph and the document being removed. """ - data = { - "prompt_name": prompt_name, - "documents_offset": documents_offset, - "documents_limit": documents_limit, - "chunks_offset": chunks_offset, - "chunks_limit": chunks_limit, - } - return await self.client._make_request( - "POST", - f"graphs/{str(collection_id)}/tune-prompt", - json=data, + "DELETE", + f"graphs/{str(collection_id)}/documents/{str(document_id)}", version="v3", ) diff --git a/py/shared/abstractions/__init__.py b/py/shared/abstractions/__init__.py index 8c55df1b8..8e70a0008 100644 --- a/py/shared/abstractions/__init__.py +++ b/py/shared/abstractions/__init__.py @@ -50,7 +50,6 @@ KGEntityResult, KGGlobalResult, KGRelationshipResult, - KGSearchMethod, KGSearchResultType, SearchSettings, ) @@ -110,7 +109,6 @@ # Search abstractions "AggregateSearchResult", "GraphSearchResult", - "KGSearchMethod", "KGSearchResultType", "KGEntityResult", "KGRelationshipResult", diff --git a/py/shared/abstractions/graph.py b/py/shared/abstractions/graph.py index 146909b9c..71221546f 100644 --- a/py/shared/abstractions/graph.py +++ b/py/shared/abstractions/graph.py @@ -36,24 +36,15 @@ class Entity(R2RSerializable): """An entity extracted from a document.""" name: str - # id is Union of UUID and int for backwards compatibility - # we will migrate to UUID only in the future - # sid is also deprecated and needs to be removed in the future - id: Optional[UUID | int] = None - category: Optional[str] = None description: Optional[str] = None + category: Optional[str] = None + metadata: Optional[dict[str, Any] | str] = None + + id: Optional[UUID] = None parent_id: Optional[UUID] = None # graph_id | document_id - # document_ids: list[UUID] = [] description_embedding: Optional[list[float] | str] = None - chunk_ids: Optional[list[UUID]] = [] - # we don't use these yet - # name_embedding: Optional[list[float]] = None - # graph_embedding: Optional[list[float]] = None - # rank: Optional[int] = None - metadata: Optional[dict[str, Any] | str] = None - def __str__(self): return f"{self.name}:{self.category}" @@ -69,7 +60,6 @@ def __init__(self, **kwargs): class Relationship(R2RSerializable): """A relationship between two entities. This is a generic relationship, and can be used to represent any type of relationship between any two entities.""" - # id is Union of UUID and int for backwards compatibility id: Optional[UUID] = None subject: str predicate: str @@ -78,7 +68,7 @@ class Relationship(R2RSerializable): subject_id: Optional[UUID] = None object_id: Optional[UUID] = None weight: float | None = 1.0 - chunk_ids: list[UUID] = [] + chunk_ids: Optional[list[UUID]] = [] parent_id: Optional[UUID] = None description_embedding: Optional[list[float] | str] = None @@ -99,10 +89,10 @@ class CommunityInfo(R2RSerializable): node: str cluster: UUID - level: int - id: Optional[UUID | int] = None + level: Optional[int] parent_cluster: int | None is_final_cluster: bool + id: Optional[UUID | int] = None graph_id: Optional[UUID] = None collection_id: Optional[UUID] = None # for backwards compatibility relationship_ids: Optional[list[UUID]] = None @@ -114,10 +104,10 @@ def __init__(self, **kwargs): @dataclass class Community(R2RSerializable): - level: int name: str = "" summary: str = "" + level: Optional[int] = None findings: list[str] = [] id: Optional[int | UUID] = None community_id: Optional[UUID] = None @@ -127,6 +117,12 @@ class Community(R2RSerializable): rating_explanation: str | None = None description_embedding: list[float] | None = None attributes: dict[str, Any] | None = None + created_at: datetime = Field( + default_factory=datetime.utcnow, + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, + ) def __init__(self, **kwargs): if isinstance(kwargs.get("attributes", None), str): diff --git a/py/shared/abstractions/search.py b/py/shared/abstractions/search.py index 991588282..fd3502691 100644 --- a/py/shared/abstractions/search.py +++ b/py/shared/abstractions/search.py @@ -62,10 +62,6 @@ class KGSearchResultType(str, Enum): COMMUNITY = "community" -class KGSearchMethod(str, Enum): - LOCAL = "local" - - class KGEntityResult(R2RSerializable): name: str description: str diff --git a/py/shared/api/models/__init__.py b/py/shared/api/models/__init__.py index d4880e3d6..4cfc484fa 100644 --- a/py/shared/api/models/__init__.py +++ b/py/shared/api/models/__init__.py @@ -18,14 +18,10 @@ ) from shared.api.models.kg.responses import ( # TODO: Need to review anything above this GraphResponse, - KGCreationResponse, KGEnrichmentResponse, - KGEntityDeduplicationResponse, WrappedGraphResponse, WrappedGraphsResponse, - WrappedKGCreationResponse, WrappedKGEnrichmentResponse, - WrappedKGEntityDeduplicationResponse, ) from shared.api.models.management.responses import ( # Chunk Responses; Conversation Responses; Document Responses; Collection Responses; Prompt Responses; System Responses; User Responses AnalyticsResponse, @@ -78,12 +74,8 @@ "WrappedUpdateResponse", "WrappedMetadataUpdateResponse", # Restructure Responses - "KGCreationResponse", "KGEnrichmentResponse", - "KGEntityDeduplicationResponse", - "WrappedKGCreationResponse", "WrappedKGEnrichmentResponse", - "WrappedKGEntityDeduplicationResponse", # TODO: Need to review anything above this "GraphResponse", "WrappedGraphResponse", diff --git a/py/shared/api/models/kg/responses.py b/py/shared/api/models/kg/responses.py index a72d606ed..e58f3d15f 100644 --- a/py/shared/api/models/kg/responses.py +++ b/py/shared/api/models/kg/responses.py @@ -119,25 +119,6 @@ class KGDeduplicationEstimate(R2RSerializable): ) -class KGCreationResponse(BaseModel): - message: str = Field( - ..., - description="A message describing the result of the KG creation request.", - ) - id: Optional[UUID] = Field( - None, - description="The ID of the created object.", - ) - task_id: Optional[UUID] = Field( - None, - description="The task ID of the KG creation request.", - ) - estimate: Optional[KGCreationEstimate] = Field( - None, - description="The estimation of the KG creation request.", - ) - - class Config: json_schema_extra = { "example": { @@ -195,40 +176,6 @@ class Config: } -class KGEntityDeduplicationResponse(BaseModel): - """Response for knowledge graph entity deduplication.""" - - message: str = Field( - ..., - description="The message to display to the user.", - ) - - task_id: Optional[UUID] = Field( - None, - description="The task ID of the KG entity deduplication request.", - ) - - estimate: Optional[KGDeduplicationEstimate] = Field( - None, - description="The estimation of the KG entity deduplication request.", - ) - - class Config: - json_schema_extra = { - "example": { - "message": "Entity deduplication queued successfully.", - "task_id": "c68dc72e-fc23-5452-8f49-d7bd46088a96", - "estimate": { - "num_entities": 1000, - "estimated_llm_calls": "1000", - "estimated_total_in_out_tokens_in_millions": "1000", - "estimated_cost_in_usd": "1000", - "estimated_total_time_in_minutes": "1000", - }, - } - } - - class KGTunePromptResponse(R2RSerializable): """Response containing just the tuned prompt string.""" @@ -250,12 +197,8 @@ class Config: # CREATE -WrappedKGCreationResponse = ResultsWrapper[KGCreationResponse] WrappedKGEnrichmentResponse = ResultsWrapper[KGEnrichmentResponse] WrappedKGTunePromptResponse = ResultsWrapper[KGTunePromptResponse] -WrappedKGEntityDeduplicationResponse = ResultsWrapper[ - KGEntityDeduplicationResponse -] class GraphResponse(BaseModel):