Skip to content

Commit

Permalink
feat: dd method queryid to check id exists (#63)
Browse files Browse the repository at this point in the history
* add num_docs

add num_docs

* Update test_inmemory_vectordb.py

* Update test_hnswlib_vectordb.py

* add

* add

* Update test_inmemory_vectordb.py

* change method name

* commit some miss files

* blank spaces change

* Update test_inmemory_vectordb.py

* Update test_inmemory_vectordb.py
  • Loading branch information
0x376h authored Oct 17, 2023
1 parent a430808 commit 5f8fc99
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 6 deletions.
13 changes: 11 additions & 2 deletions tests/unit/test_hnswlib_vectordb.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,5 +175,14 @@ def test_hnswlib_num_dos(tmpdir):
db = HNSWVectorDB[MyDoc](workspace=str(tmpdir))
doc_list = [MyDoc(text=f'toy doc {i}', embedding=np.random.rand(128)) for i in range(1000)]
db.index(inputs=DocList[MyDoc](doc_list))
x=db.num_docs()
assert x['num_docs']==1000
x = db.num_docs()
assert x['num_docs'] == 1000

def test_hnswlib_query_id(tmpdir):
db = HNSWVectorDB[MyDoc](workspace=str(tmpdir))
doc_list = [MyDoc(id='test_1',text=f'test', embedding=np.random.rand(128)) ]
db.index(inputs=DocList[MyDoc](doc_list))
queryobjtest1 = db.get_by_id('test_1')
queryobjtest2 = db.get_by_id('test_2')
assert queryobjtest2 is None
assert queryobjtest1.id == 'test_1'
13 changes: 11 additions & 2 deletions tests/unit/test_inmemory_vectordb.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,5 +177,14 @@ def test_inmemory_num_dos(tmpdir):
db = InMemoryExactNNVectorDB[MyDoc](workspace=str(tmpdir))
doc_list = [MyDoc(text=f'toy doc {i}', embedding=np.random.rand(128)) for i in range(1000)]
db.index(inputs=DocList[MyDoc](doc_list))
x=db.num_docs()
assert x['num_docs']==1000
x = db.num_docs()
assert x['num_docs'] == 1000

def test_inmemory_query_id(tmpdir):
db = InMemoryExactNNVectorDB[MyDoc](workspace=str(tmpdir))
doc_list = [MyDoc(id='test_1',text=f'test', embedding=np.random.rand(128)) ]
db.index(inputs=DocList[MyDoc](doc_list))
queryobjtest1 = db.get_by_id('test_1')
queryobjtest2 = db.get_by_id('test_2')
assert queryobjtest2 is None
assert queryobjtest1.id == 'test_1'
8 changes: 8 additions & 0 deletions vectordb/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,14 @@ async def _deploy():

def num_docs(self, **kwargs):
return self._executor.num_docs()

def get_by_id(self,info_id, **kwargs):
ret = None
try:
ret = self._executor.get_by_id(info_id)
except KeyError:
pass
return ret

@pass_kwargs_as_params
@unify_input_output
Expand Down
5 changes: 4 additions & 1 deletion vectordb/db/executors/hnsw_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ async def async_update(self, docs, *args, **kwargs):

def num_docs(self, **kwargs):
return {'num_docs': self._indexer.num_docs()}


def get_by_id(self,info_id,**kwargs):
return self._indexer[info_id]

def snapshot(self, snapshot_dir):
# TODO: Maybe copy the work_dir to workspace if `handle` is False
raise NotImplementedError('Act as not implemented')
Expand Down
5 changes: 4 additions & 1 deletion vectordb/db/executors/inmemory_exact_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ def update(self, docs, *args, **kwargs):

def num_docs(self, *args, **kwargs):
return {'num_docs': self._indexer.num_docs()}


def get_by_id(self,info_id,**kwargs):
return self._indexer[info_id]

def snapshot(self, snapshot_dir):
snapshot_file = f'{snapshot_dir}/index.bin'
self._indexer.persist(snapshot_file)
Expand Down

0 comments on commit 5f8fc99

Please sign in to comment.