Skip to content

Commit

Permalink
bulk_write method (#239)
Browse files Browse the repository at this point in the history
* bulk_write and required infra, sync impl / no test yet

* bulk_write (sync): add unordered support and all tests

* async bulk_write + all tests

* add test for BulkWriteResult reduction
  • Loading branch information
hemidactylus authored Mar 6, 2024
1 parent 619bfe6 commit 8c5dbd2
Show file tree
Hide file tree
Showing 6 changed files with 898 additions and 5 deletions.
96 changes: 91 additions & 5 deletions astrapy/idiomatic/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

from __future__ import annotations

import asyncio
import json
from typing import Any, Dict, Iterable, List, Optional, Union
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, Iterable, List, Optional, Union, TYPE_CHECKING

from astrapy.db import AstraDBCollection, AsyncAstraDBCollection
from astrapy.idiomatic.types import (
Expand All @@ -30,11 +32,17 @@
InsertManyResult,
InsertOneResult,
UpdateResult,
BulkWriteResult,
)
from astrapy.idiomatic.cursors import AsyncCursor, Cursor


if TYPE_CHECKING:
from astrapy.idiomatic.operations import AsyncBaseOperation, BaseOperation


INSERT_MANY_CONCURRENCY = 20
BULK_WRITE_CONCURRENCY = 10


def _prepare_update_info(status: Dict[str, Any]) -> Dict[str, Any]:
Expand Down Expand Up @@ -143,6 +151,7 @@ def insert_one(
if io_response["status"]["insertedIds"]:
inserted_id = io_response["status"]["insertedIds"][0]
return InsertOneResult(
raw_result=io_response,
inserted_id=inserted_id,
)
else:
Expand Down Expand Up @@ -195,7 +204,11 @@ def insert_many(
if isinstance(response, dict)
for ins_id in (response.get("status") or {}).get("insertedIds", [])
]
return InsertManyResult(inserted_ids=inserted_ids)
return InsertManyResult(
# if we are here, cim_responses are all dicts (no exceptions)
raw_result=cim_responses, # type: ignore[arg-type]
inserted_ids=inserted_ids,
)

def find(
self,
Expand Down Expand Up @@ -473,7 +486,7 @@ def delete_many(
raw_result=dm_responses,
)
else:
# expected a non-negative integer (None :
# per API specs, deleted_count has to be a non-negative integer.
return DeleteResult(
deleted_count=deleted_count,
raw_result=dm_responses,
Expand All @@ -484,6 +497,37 @@ def delete_many(
f"(gotten '${json.dumps(dm_responses)}')"
)

def bulk_write(
self,
requests: Iterable[BaseOperation],
*,
ordered: bool = True,
) -> BulkWriteResult:
# lazy importing here against circular-import error
from astrapy.idiomatic.operations import reduce_bulk_write_results

if ordered:
bulk_write_results = [
operation.execute(self, operation_i)
for operation_i, operation in enumerate(requests)
]
return reduce_bulk_write_results(bulk_write_results)
else:
with ThreadPoolExecutor(max_workers=BULK_WRITE_CONCURRENCY) as executor:
bulk_write_futures = [
executor.submit(
operation.execute,
self,
operation_i,
)
for operation_i, operation in enumerate(requests)
]
bulk_write_results = [
bulk_write_future.result()
for bulk_write_future in bulk_write_futures
]
return reduce_bulk_write_results(bulk_write_results)


class AsyncCollection:
def __init__(
Expand Down Expand Up @@ -579,6 +623,7 @@ async def insert_one(
if io_response["status"]["insertedIds"]:
inserted_id = io_response["status"]["insertedIds"][0]
return InsertOneResult(
raw_result=io_response,
inserted_id=inserted_id,
)
else:
Expand Down Expand Up @@ -631,7 +676,11 @@ async def insert_many(
if isinstance(response, dict)
for ins_id in (response.get("status") or {}).get("insertedIds", [])
]
return InsertManyResult(inserted_ids=inserted_ids)
return InsertManyResult(
# if we are here, cim_responses are all dicts (no exceptions)
raw_result=cim_responses, # type: ignore[arg-type]
inserted_ids=inserted_ids,
)

def find(
self,
Expand Down Expand Up @@ -916,7 +965,7 @@ async def delete_many(
raw_result=dm_responses,
)
else:
# expected a non-negative integer (None :
# per API specs, deleted_count has to be a non-negative integer.
return DeleteResult(
deleted_count=deleted_count,
raw_result=dm_responses,
Expand All @@ -926,3 +975,40 @@ async def delete_many(
"Could not complete a chunked_delete_many operation. "
f"(gotten '${json.dumps(dm_responses)}')"
)

async def bulk_write(
self,
requests: Iterable[AsyncBaseOperation],
*,
ordered: bool = True,
) -> BulkWriteResult:
# lazy importing here against circular-import error
from astrapy.idiomatic.operations import reduce_bulk_write_results

if ordered:
bulk_write_results = [
await operation.execute(self, operation_i)
for operation_i, operation in enumerate(requests)
]
return reduce_bulk_write_results(bulk_write_results)
else:
sem = asyncio.Semaphore(BULK_WRITE_CONCURRENCY)

async def concurrent_execute_operation(
operation: AsyncBaseOperation,
collection: AsyncCollection,
index_in_bulk_write: int,
) -> BulkWriteResult:
async with sem:
return await operation.execute(
collection=collection, index_in_bulk_write=index_in_bulk_write
)

tasks = [
asyncio.create_task(
concurrent_execute_operation(operation, self, operation_i)
)
for operation_i, operation in enumerate(requests)
]
bulk_write_results = await asyncio.gather(*tasks)
return reduce_bulk_write_results(bulk_write_results)
Loading

0 comments on commit 8c5dbd2

Please sign in to comment.