diff --git a/nearpy/storage/storage_mongo.py b/nearpy/storage/storage_mongo.py index 19fc01e..da5962e 100644 --- a/nearpy/storage/storage_mongo.py +++ b/nearpy/storage/storage_mongo.py @@ -29,13 +29,20 @@ import numpy import scipy +from nearpy.utils.utils import convert2unicode + try: import cPickle as pickle except ImportError: import pickle -from future.builtins import bytes +try: + from pymongo import InsertOne +except ImportError: + pass + from nearpy.storage.storage import Storage +from future.builtins import zip class MongoStorage(Storage): @@ -45,7 +52,22 @@ def __init__(self, mongo_object): """ Uses specified pymongo object for storage. """ self.mongo_object = mongo_object + def store_many_vectors(self, hash_name, bucket_keys, vs, data): + requests = [] + + for v, d, bk in zip(vs, data, bucket_keys): + vc = self._get_vector(hash_name, bk, v, d) + + requests.append(InsertOne(vc)) + + self.mongo_object.bulk_write(requests, ordered=False) + def store_vector(self, hash_name, bucket_key, v, data): + val_dict = self._get_vector(hash_name, bucket_key, v, data) + + self.mongo_object.insert_one(val_dict) + + def _get_vector(self, hash_name, bucket_key, v, data): """ Stores vector and JSON-serializable data in MongoDB with specified key. """ @@ -83,8 +105,9 @@ def store_vector(self, hash_name, bucket_key, v, data): if data is not None: val_dict['data'] = data - # Push JSON representation of dict to end of bucket list - self.mongo_object.insert_one(val_dict) + convert2unicode(val_dict) + + return val_dict def _format_mongo_key(self, hash_name, bucket_key): return '{}{}'.format(self._format_hash_prefix(hash_name), bucket_key) @@ -147,7 +170,7 @@ def get_bucket(self, hash_name, bucket_key): shape=(val_dict['dim'], 1)) else: - vector = numpy.fromstring(val_dict['vector'], + vector = numpy.frombuffer(val_dict['vector'], dtype=val_dict['dtype']) [val_dict.pop(k) for k in ['vector', 'dtype', '_id']] # Add data to result tuple, if present @@ -186,5 +209,6 @@ def load_hash_configuration(self, hash_name): conf = self.mongo_object.find_one( {'hash_conf_name': hash_name + '_conf'} ) + return pickle.loads(conf['hash_configuration']) if conf is not None\ else None diff --git a/nearpy/utils/utils.py b/nearpy/utils/utils.py index 69f9b4c..15a6a34 100644 --- a/nearpy/utils/utils.py +++ b/nearpy/utils/utils.py @@ -90,3 +90,10 @@ def want_string(arg, encoding='utf-8'): rv = arg return rv + +def convert2unicode(mydict): + for k, v in mydict.iteritems(): + if isinstance(v, str): + mydict[k] = unicode(v, errors='replace') + elif isinstance(v, dict): + convert2unicode(v) diff --git a/setup.py b/setup.py index 498f421..18146bd 100644 --- a/setup.py +++ b/setup.py @@ -28,5 +28,6 @@ "redis", "mockredispy", "mongomock", + "pymongo" ] ) diff --git a/tests/storage_tests.py b/tests/storage_tests.py index 8f5750d..c2137d1 100644 --- a/tests/storage_tests.py +++ b/tests/storage_tests.py @@ -185,6 +185,10 @@ def test_store_zero(self): _, data = bucket[0] self.assertEqual(data, 0) + def test_store_many_vectors(self): + x = numpy.random.randn(100, 10) + self.check_store_many_vectors(x) + if __name__ == '__main__': unittest.main()