Skip to content

Commit

Permalink
Refactor Python SDK for Intellisense, Thread Safety (#1430)
Browse files Browse the repository at this point in the history
* Refactor Python SDK

* Fix CLI after SDK changes

* Add convo to agent

* Update conversation error handling, JS

* Remove unused, bad import
  • Loading branch information
NolanTrem authored Oct 20, 2024
1 parent 65a3a51 commit c9b6549
Show file tree
Hide file tree
Showing 46 changed files with 534 additions and 887 deletions.
17 changes: 12 additions & 5 deletions js/sdk/__tests__/r2rClientIntegrationSuperUser.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,23 @@ describe("r2rClient Integration Tests", () => {
{ role: "user", content: "Tell me about Raskolnikov." },
];

const stream = await client.agent(messages, undefined, undefined, {
stream: true,
});
const stream = await client.agent(messages, { stream: true });

expect(stream).toBeDefined();

let fullResponse = "";

for await (const chunk of stream) {
fullResponse += chunk;
if (stream && stream.getReader) {
const reader = stream.getReader();
while (true) {
const { done, value } = await reader.read();
if (done) {
break;
}
fullResponse += new TextDecoder().decode(value);
}
} else {
throw new Error('Stream is not a ReadableStream');
}

expect(fullResponse.length).toBeGreaterThan(0);
Expand Down
2 changes: 1 addition & 1 deletion js/sdk/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion js/sdk/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "r2r-js",
"version": "0.3.8",
"version": "0.3.9",
"description": "",
"main": "dist/index.js",
"browser": "dist/index.browser.js",
Expand Down
10 changes: 8 additions & 2 deletions js/sdk/src/r2rClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1518,21 +1518,25 @@ export class r2rClient {
/**
* Performs a single turn in a conversation with a RAG agent.
* @param messages The messages to send to the agent.
* @param rag_generation_config RAG generation configuration.
* @param vector_search_settings Vector search settings.
* @param kg_search_settings KG search settings.
* @param rag_generation_config RAG generation configuration.
* @param task_prompt_override Task prompt override.
* @param include_title_if_available Include title if available.
* @param conversation_id The ID of the conversation, if not a new conversation.
* @param branch_id The ID of the branch to use, if not a new branch.
* @returns A promise that resolves to the response from the server.
*/
@feature("agent")
async agent(
messages: Message[],
rag_generation_config?: GenerationConfig | Record<string, any>,
vector_search_settings?: VectorSearchSettings | Record<string, any>,
kg_search_settings?: KGSearchSettings | Record<string, any>,
rag_generation_config?: GenerationConfig | Record<string, any>,
task_prompt_override?: string,
include_title_if_available?: boolean,
conversation_id?: string,
branch_id?: string,
): Promise<any | AsyncGenerator<string, void, unknown>> {
this._ensureAuthenticated();

Expand All @@ -1543,6 +1547,8 @@ export class r2rClient {
rag_generation_config,
task_prompt_override,
include_title_if_available,
conversation_id,
branch_id,
};

Object.keys(json_data).forEach(
Expand Down
4 changes: 2 additions & 2 deletions py/cli/command_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from asyncclick import pass_context
from asyncclick.exceptions import Exit

from sdk.client import R2RClient
from r2r import R2RAsyncClient


@click.group()
Expand All @@ -13,7 +13,7 @@
async def cli(ctx, base_url):
"""R2R CLI for all core operations."""

ctx.obj = R2RClient(base_url=base_url)
ctx.obj = R2RAsyncClient(base_url=base_url)

# Override the default exit behavior
def silent_exit(self, code=0):
Expand Down
39 changes: 16 additions & 23 deletions py/cli/commands/ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from cli.utils.timer import timer


def ingest_files_from_urls(client, urls):
async def ingest_files_from_urls(client, urls):
"""Download and ingest files from given URLs."""
files_to_ingest = []
metadatas = []
Expand Down Expand Up @@ -45,7 +45,7 @@ def ingest_files_from_urls(client, urls):
# TODO: use the utils function generate_document_id
document_ids.append(uuid.uuid5(uuid.NAMESPACE_DNS, url))

response = client.ingest_files(
response = await client.ingest_files(
files_to_ingest, metadatas=metadatas, document_ids=document_ids
)

Expand All @@ -67,24 +67,15 @@ def ingest_files_from_urls(client, urls):
"--metadatas", type=JSON, help="Metadatas for ingestion as a JSON string"
)
@pass_context
async def ingest_files(ctx, file_paths, document_ids, metadatas):
    """Ingest files into R2R.

    Forwards the given file paths, plus optional metadatas and document
    IDs, to the async client and echoes the server response as JSON.
    """
    client = ctx.obj
    with timer():
        # Click hands us tuples; the SDK expects lists (or None for
        # "not provided").
        file_paths = list(file_paths)
        document_ids = list(document_ids) if document_ids else None
        response = await client.ingest_files(
            file_paths, metadatas, document_ids
        )
    click.echo(json.dumps(response, indent=2))


Expand All @@ -101,7 +92,7 @@ def retry_ingest_files(ctx, document_ids):
"--metadatas", type=JSON, help="Metadatas for updating as a JSON string"
)
@pass_context
def update_files(ctx, file_paths, document_ids, metadatas):
async def update_files(ctx, file_paths, document_ids, metadatas):
"""Update existing files in R2R."""
client = ctx.obj
with timer():
Expand All @@ -119,7 +110,9 @@ def update_files(ctx, file_paths, document_ids, metadatas):
"Metadatas must be a JSON string representing a list of dictionaries or a single dictionary"
)

response = client.update_files(file_paths, document_ids, metadatas)
response = await client.update_files(
file_paths, document_ids, metadatas
)
click.echo(json.dumps(response, indent=2))


Expand All @@ -128,21 +121,21 @@ def update_files(ctx, file_paths, document_ids, metadatas):
"--v2", is_flag=True, help="use aristotle_v2.txt (a smaller file)"
)
@pass_context
async def ingest_sample_file(ctx, v2=False):
    """Ingest the first sample file into R2R.

    With --v2, uses aristotle_v2.txt (a smaller fixture) instead of
    aristotle.txt.
    """
    sample_file_url = f"https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/aristotle{'_v2' if v2 else ''}.txt"
    client = ctx.obj

    with timer():
        response = await ingest_files_from_urls(client, [sample_file_url])
    click.echo(
        f"Sample file ingestion completed. Ingest files response:\n\n{response}"
    )


@cli.command()
@pass_context
def ingest_sample_files(ctx):
async def ingest_sample_files(ctx):
"""Ingest multiple sample files into R2R."""
client = ctx.obj
urls = [
Expand All @@ -157,7 +150,7 @@ def ingest_sample_files(ctx):
"https://raw.githubusercontent.com/SciPhi-AI/R2R/main/py/core/examples/data/pg_essay_2.html",
]
with timer():
response = ingest_files_from_urls(client, urls)
response = await ingest_files_from_urls(client, urls)

click.echo(
f"Sample files ingestion completed. Ingest files response:\n\n{response}"
Expand All @@ -166,7 +159,7 @@ def ingest_sample_files(ctx):

@cli.command()
@pass_context
def ingest_sample_files_from_unstructured(ctx):
async def ingest_sample_files_from_unstructured(ctx):
"""Ingest multiple sample files from URLs into R2R."""
client = ctx.obj

Expand All @@ -184,7 +177,7 @@ def ingest_sample_files_from_unstructured(ctx):
file_paths = [os.path.join(folder, file) for file in os.listdir(folder)]

with timer():
response = client.ingest_files(file_paths)
response = await client.ingest_files(file_paths)

click.echo(
f"Sample files ingestion completed. Ingest files response:\n\n{response}"
Expand Down
23 changes: 10 additions & 13 deletions py/cli/commands/kg.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
help="Force the graph creation process.",
)
@pass_context
def create_graph(
async def create_graph(
ctx, collection_id, run, kg_creation_settings, force_kg_creation
):
client = ctx.obj
Expand All @@ -52,7 +52,7 @@ def create_graph(
kg_creation_settings = {"force_kg_creation": True}

with timer():
response = client.create_graph(
response = await client.create_graph(
collection_id=collection_id,
run_type=run_type,
kg_creation_settings=kg_creation_settings,
Expand Down Expand Up @@ -138,7 +138,7 @@ def deduplicate_entities(
help="Settings for the graph enrichment process.",
)
@pass_context
def enrich_graph(
async def enrich_graph(
ctx, collection_id, run, force_kg_enrichment, kg_enrichment_settings
):
"""
Expand All @@ -163,7 +163,7 @@ def enrich_graph(
kg_enrichment_settings = {"force_kg_enrichment": True}

with timer():
response = client.enrich_graph(
response = await client.enrich_graph(
collection_id, run_type, kg_enrichment_settings
)

Expand Down Expand Up @@ -193,20 +193,15 @@ def enrich_graph(
multiple=True,
help="Entity IDs to filter by.",
)
@click.option(
"--with-description",
is_flag=True,
help="Include entity descriptions in the response.",
)
@pass_context
def get_entities(ctx, collection_id, offset, limit, entity_ids):
async def get_entities(ctx, collection_id, offset, limit, entity_ids):
"""
Retrieve entities from the knowledge graph.
"""
client = ctx.obj

with timer():
response = client.get_entities(
response = await client.get_entities(
collection_id,
offset,
limit,
Expand Down Expand Up @@ -245,14 +240,16 @@ def get_entities(ctx, collection_id, offset, limit, entity_ids):
help="Entity names to filter by.",
)
@pass_context
def get_triples(ctx, collection_id, offset, limit, triple_ids, entity_names):
async def get_triples(
ctx, collection_id, offset, limit, triple_ids, entity_names
):
"""
Retrieve triples from the knowledge graph.
"""
client = ctx.obj

with timer():
response = client.get_triples(
response = await client.get_triples(
collection_id,
offset,
limit,
Expand Down
27 changes: 15 additions & 12 deletions py/cli/commands/management.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,24 @@
@click.option("--filters", type=JSON, help="Filters for analytics as JSON")
@click.option("--analysis-types", type=JSON, help="Analysis types as JSON")
@pass_context
async def analytics(
    ctx, filters: Dict[str, Any], analysis_types: Dict[str, Any]
):
    """Retrieve analytics data.

    Fix: the docstring must be the first statement of the function body;
    it previously sat after `client = ctx.obj`, where it was a no-op
    string expression rather than a docstring.
    """
    client = ctx.obj
    with timer():
        response = await client.analytics(filters, analysis_types)

    click.echo(response)


@cli.command()
@pass_context
async def app_settings(ctx):
    """Retrieve application settings and echo them to stdout."""
    # ctx.obj holds the R2RAsyncClient installed by the CLI group.
    client = ctx.obj
    with timer():
        response = await client.app_settings()

    click.echo(response)

Expand All @@ -44,13 +47,13 @@ def app_settings(client):
help="The maximum number of nodes to return. Defaults to 100.",
)
@pass_context
def users_overview(ctx, user_ids, offset, limit):
async def users_overview(ctx, user_ids, offset, limit):
"""Get an overview of users."""
client = ctx.obj
user_ids = list(user_ids) if user_ids else None

with timer():
response = client.users_overview(user_ids, offset, limit)
response = await client.users_overview(user_ids, offset, limit)

if "results" in response:
click.echo("\nUser Overview:")
Expand All @@ -73,7 +76,7 @@ def users_overview(ctx, user_ids, offset, limit):
help="Filters for deletion in the format key:operator:value",
)
@pass_context
def delete(ctx, filter):
async def delete(ctx, filter):
"""Delete documents based on filters."""
client = ctx.obj
filters = {}
Expand All @@ -84,7 +87,7 @@ def delete(ctx, filter):
filters[key][f"${operator}"] = value

with timer():
response = client.delete(filters=filters)
response = await client.delete(filters=filters)

click.echo(response)

Expand All @@ -102,13 +105,13 @@ def delete(ctx, filter):
help="The maximum number of nodes to return. Defaults to 100.",
)
@pass_context
def documents_overview(ctx, document_ids, offset, limit):
async def documents_overview(ctx, document_ids, offset, limit):
"""Get an overview of documents."""
client = ctx.obj
document_ids = list(document_ids) if document_ids else None

with timer():
response = client.documents_overview(document_ids, offset, limit)
response = await client.documents_overview(document_ids, offset, limit)

for document in response["results"]:
click.echo(document)
Expand All @@ -133,15 +136,15 @@ def documents_overview(ctx, document_ids, offset, limit):
help="Should the vector be included in the response chunks",
)
@pass_context
def document_chunks(ctx, document_id, offset, limit, include_vectors):
async def document_chunks(ctx, document_id, offset, limit, include_vectors):
"""Get chunks of a specific document."""
client = ctx.obj
if not document_id:
click.echo("Error: Document ID is required.")
return

with timer():
chunks_data = client.document_chunks(
chunks_data = await client.document_chunks(
document_id, offset, limit, include_vectors
)

Expand Down
Loading

0 comments on commit c9b6549

Please sign in to comment.