Skip to content

Commit

Permalink
Improved SparseVectorField forms for Django
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Jun 12, 2024
1 parent 1e2006a commit 8d6da74
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
11 changes: 11 additions & 0 deletions pgvector/django/sparsevec.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from django import forms
from django.db.models import Field
from ..utils import SparseVector

Expand Down Expand Up @@ -33,3 +34,13 @@ def get_prep_value(self, value):

def value_to_string(self, obj):
return self.get_prep_value(self.value_from_object(obj))

def formfield(self, **kwargs):
return super().formfield(form_class=SparseVectorFormField, **kwargs)


class SparseVectorFormField(forms.CharField):
def to_python(self, value):
if isinstance(value, str) and value == '':
return None
return super().to_python(value)
8 changes: 8 additions & 0 deletions tests/test_django.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,14 @@ def test_sparsevec_form_save(self):
assert form.save()
assert [4, 5, 6] == Item.objects.get(pk=1).sparse_embedding.to_list()

def test_sparesevec_form_save_missing(self):
Item(id=1).save()
item = Item.objects.get(pk=1)
form = SparseVectorForm(instance=item, data={'sparse_embedding': ''})
assert form.is_valid()
assert form.save()
assert Item.objects.get(pk=1).sparse_embedding is None

def test_clean(self):
item = Item(id=1, embedding=[1, 2, 3], half_embedding=[1, 2, 3], binary_embedding='101', sparse_embedding=SparseVector.from_dense([1, 2, 3]))
item.full_clean()
Expand Down

0 comments on commit 8d6da74

Please sign in to comment.