diff --git a/README.rst b/README.rst
index c1564528..5ca0652b 100644
--- a/README.rst
+++ b/README.rst
@@ -68,6 +68,45 @@ Asynchronous `scroll `_
+
+.. code-block:: python
+
+    import asyncio
+
+    from aioelasticsearch import Elasticsearch
+    from aioelasticsearch.helpers import bulk
+
+    def gen_data():
+        for i in range(10):
+            yield { "_index" : "test",
+                    "_type" : "_doc",
+                    "_id" : str(i),
+                    "FIELD1": "TEXT",
+                    }
+    def gen_data2():
+        for i in range(10):
+            yield { "_index" : "test",
+                    "_type" : "_doc",
+                    "_id" : str(i),
+                    "_source":{
+                        "FIELD1": "TEXT",
+                    }
+            }
+
+
+    async def go():
+        async with Elasticsearch() as es:
+            success, fails = \
+                await bulk(es, gen_data())
+
+
+    loop = asyncio.get_event_loop()
+    loop.run_until_complete(go())
+    loop.close()
+
+
 Thanks
 ------
diff --git a/aioelasticsearch/helpers.py b/aioelasticsearch/helpers.py
index 484f035d..d627ab12 100644
--- a/aioelasticsearch/helpers.py
+++ b/aioelasticsearch/helpers.py
@@ -1,9 +1,11 @@
 import asyncio
 import logging
+from operator import methodcaller
 
 from aioelasticsearch import NotFoundError
-from elasticsearch.helpers import ScanError
+from elasticsearch.helpers import ScanError, _chunk_actions, expand_action
+from elasticsearch.exceptions import TransportError
 
 from .compat import PY_352
@@ -147,3 +149,136 @@ def _update_state(self, resp):
         self._successful_shards = resp['_shards']['successful']
         self._total_shards = resp['_shards']['total']
         self._done = not self._hits or self._scroll_id is None
+
+
+async def _process_bulk(client, datas, actions, **kwargs):
+    try:
+        resp = await client.bulk("\n".join(actions) + '\n', **kwargs)
+    except TransportError as e:
+        return e, datas
+    fail_actions = []
+    finish_count = 0
+    for data, (op_type, item) in zip(datas, map(methodcaller('popitem'),
+                                                resp['items'])):
+        ok = 200 <= item.get('status', 500) < 300
+        if not ok:
+            fail_actions.append(data)
+        else:
+            finish_count += 1
+    return finish_count, fail_actions
+
+
+async def _retry_handler(client, coroutine, max_retries, initial_backoff,
+                         max_backoff, **kwargs):
+
+    finish = 0
+    bulk_data = []
+    for attempt in range(max_retries + 1):
+        bulk_action = []
+        lazy_exception = None
+
+        if attempt:
+            sleep = min(max_backoff, initial_backoff * 2 ** (attempt - 1))
+            logger.debug('Retry %d count, sleep %d second.', attempt, sleep)
+            await asyncio.sleep(sleep, loop=client.loop)
+
+        result = await coroutine
+        if isinstance(result[0], int):
+            finish += result[0]
+        else:
+            lazy_exception = result[0]
+
+        bulk_data = result[1]
+
+        for tuple_data in bulk_data:
+            data = None
+            if len(tuple_data) == 2:
+                data = tuple_data[1]
+            action = tuple_data[0]
+
+            action = client.transport.serializer.dumps(action)
+            bulk_action.append(action)
+            if data is not None:
+                data = client.transport.serializer.dumps(data)
+                bulk_action.append(data)
+
+        if not bulk_action or attempt == max_retries:
+            break
+
+        coroutine = _process_bulk(client, bulk_data, bulk_action, **kwargs)
+
+    if lazy_exception:
+        raise lazy_exception
+
+    return finish, bulk_data
+
+
+async def bulk(client, actions, chunk_size=500, max_retries=0,
+               max_chunk_bytes=100 * 1024 * 1024,
+               expand_action_callback=expand_action, initial_backoff=2,
+               max_backoff=600, stats_only=False, **kwargs):
+    actions = map(expand_action_callback, actions)
+
+    finish_count = 0
+    if stats_only:
+        fail_datas = 0
+    else:
+        fail_datas = []
+
+    chunk_action_iter = _chunk_actions(actions, chunk_size, max_chunk_bytes,
+                                       client.transport.serializer)
+
+    for bulk_data, bulk_action in chunk_action_iter:
+        coroutine = _process_bulk(client, bulk_data, bulk_action, **kwargs)
+        count, fails = await _retry_handler(client,
+                                            coroutine,
+                                            max_retries,
+                                            initial_backoff,
+                                            max_backoff,
+                                            **kwargs)
+
+        finish_count += count
+        if stats_only:
+            fail_datas += len(fails)
+        else:
+            fail_datas.extend(fails)
+
+    return finish_count, fail_datas
+
+
+async def concurrency_bulk(client, actions, concurrency_count=4,
+                           chunk_size=500, max_retries=0,
+                           max_chunk_bytes=100 * 1024 * 1024,
+                           expand_action_callback=expand_action,
+                           initial_backoff=2, max_backoff=600, **kwargs):
+
+    async def concurrency_wrapper(action_iter):
+        p_count = p_fails = 0
+        for bulk_data, bulk_action in action_iter:
+            coroutine = _process_bulk(client, bulk_data, bulk_action, **kwargs)
+            count, fails = await _retry_handler(client,
+                                                coroutine,
+                                                max_retries,
+                                                initial_backoff,
+                                                max_backoff, **kwargs)
+            p_count += count
+            p_fails += len(fails)
+        return p_count, p_fails
+
+    actions = map(expand_action_callback, actions)
+    chunk_action_iter = _chunk_actions(actions, chunk_size, max_chunk_bytes,
+                                       client.transport.serializer)
+
+    tasks = []
+    for i in range(concurrency_count):
+        tasks.append(concurrency_wrapper(chunk_action_iter))
+
+    results = await asyncio.gather(*tasks, loop=client.loop)
+
+    finish_count = 0
+    fail_count = 0
+    for p_finish, p_fail in results:
+        finish_count += p_finish
+        fail_count += p_fail
+
+    return finish_count, fail_count
diff --git a/tests/test_bulk.py b/tests/test_bulk.py
new file mode 100644
index 00000000..f558ac62
--- /dev/null
+++ b/tests/test_bulk.py
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+import asyncio
+import logging
+
+import pytest
+
+from aioelasticsearch.helpers import bulk, concurrency_bulk, _retry_handler
+from aioelasticsearch import Elasticsearch, TransportError
+
+logger = logging.getLogger('elasticsearch')
+
+
+def gen_data1():
+    for i in range(10):
+        yield {"_index": "test_aioes",
+               "_type": "type_3",
+               "_id": str(i),
+               "foo": "1"}
+
+
+def gen_data2():
+    for i in range(10, 20):
+        yield {"_index": "test_aioes",
+               "_type": "type_3",
+               "_id": str(i),
+               "_source": {"foo": "1"}
+               }
+
+
+@pytest.mark.run_loop
+async def test_bulk_simple(es):
+    success, fails = await bulk(es, gen_data1(),
+                                stats_only=True)
+    assert success == 10
+    assert fails == 0
+
+    success, fails = await bulk(es, gen_data2(),
+                                stats_only=True)
+    assert success == 10
+    assert fails == 0
+
+    success, fails = await bulk(es, gen_data1(),
+                                stats_only=False)
+    assert success == 10
+    assert fails == []
+
+
+@pytest.mark.run_loop
+async def test_bulk_fails(es):
+    datas = [{'_op_type': 'delete',
+              '_index': 'test_aioes',
+              '_type': 'type_3', '_id': "999"}
+             ]
+
+    success, fails = await bulk(es, datas, stats_only=True, max_retries=1)
+    assert success == 0
+    assert fails == 1
+
+
+@pytest.mark.run_loop
+async def test_concurrency_bulk(es):
+    success, fails = await concurrency_bulk(es, gen_data1())
+    assert success == 10
+    assert fails == 0
+
+    success, fails = await concurrency_bulk(es, gen_data2())
+    assert success == 10
+    assert fails == 0
+
+
+@pytest.mark.run_loop
+async def test_bulk_raise_exception(loop):
+
+    asyncio.set_event_loop(loop)
+    es = Elasticsearch()
+    datas = [{'_op_type': 'delete',
+              '_index': 'test_aioes',
+              '_type': 'type_3', '_id': "999"}
+             ]
+    with pytest.raises(TransportError):
+        success, fails = await bulk(es, datas, stats_only=True)
+
+
+@pytest.mark.run_loop
+async def test_retry_handler(es):
+    async def mock_data():
+        # finish_count, [(es_action, source_data), ...]
+        return 0, [(
+            {'index': {'_index': 'test_aioes', '_type': 'test_aioes', '_id': 100}},
+            {'name': 'Karl 1', 'email': 'karl@example.com'}),
+            ({'index': {'_index': 'test_aioes', '_type': 'test_aioes', '_id': 101}},
+             {'name': 'Karl 2', 'email': 'karl@example.com'})]
+
+    done, fail = await _retry_handler(es,
+                                      mock_data(),
+                                      max_retries=1,
+                                      initial_backoff=2,
+                                      max_backoff=600)
+    assert done == 2
+    assert fail == []
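Below is a minimal usage sketch for the new ``concurrency_bulk`` helper added in this patch, mirroring the ``bulk`` example in the README hunk above. It assumes a reachable local Elasticsearch node; the index name, document shape, and generator are illustrative only.

.. code-block:: python

    import asyncio

    from aioelasticsearch import Elasticsearch
    from aioelasticsearch.helpers import concurrency_bulk

    def gen_data():
        # illustrative documents; any iterable of bulk actions works
        for i in range(10):
            yield {"_index": "test",
                   "_type": "_doc",
                   "_id": str(i),
                   "FIELD1": "TEXT"}

    async def go():
        async with Elasticsearch() as es:
            # chunks the action stream and indexes the chunks from several
            # concurrent coroutines (concurrency_count=4 by default)
            success, fails = await concurrency_bulk(es, gen_data())
            print(success, fails)

    loop = asyncio.get_event_loop()
    loop.run_until_complete(go())
    loop.close()

Unlike ``bulk``, ``concurrency_bulk`` always reports failures as a count (as with ``stats_only=True``), so the second return value is an integer rather than a list of failed actions.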