diff --git a/pgvector/utils/sparsevec.py b/pgvector/utils/sparsevec.py index c21e0d3..a451fc6 100644 --- a/pgvector/utils/sparsevec.py +++ b/pgvector/utils/sparsevec.py @@ -1,18 +1,23 @@ import numpy as np from struct import pack, unpack_from +NO_DEFAULT = object() + class SparseVector: - def __init__(self, value, dimensions=None, /): + def __init__(self, value, dimensions=NO_DEFAULT, /): if value.__class__.__module__ == 'scipy.sparse._arrays': - if dimensions is not None: + if dimensions is not NO_DEFAULT: raise ValueError('dimensions not allowed') self._from_sparse(value) elif isinstance(value, dict): + if dimensions is NO_DEFAULT: + raise ValueError('dimensions required') + self._from_dict(value, dimensions) else: - if dimensions is not None: + if dimensions is not NO_DEFAULT: raise ValueError('dimensions not allowed') self._from_dense(value) @@ -56,9 +61,6 @@ def to_binary(self): return pack(f'>iii{nnz}i{nnz}f', self._dim, nnz, 0, *self._indices, *self._values) def _from_dict(self, d, dim): - if dim is None: - raise ValueError('dimensions required') - elements = [(i, v) for i, v in d.items() if v != 0] elements.sort()