diff --git a/spanking/main.py b/spanking/main.py index f6c224c..98cfb3b 100644 --- a/spanking/main.py +++ b/spanking/main.py @@ -20,7 +20,7 @@ def add_texts(self, texts): def delete_text(self, index): if 0 <= index < len(self.texts): self.texts.pop(index) - self.embeddings = self.embeddings.at[index].delete() + self.embeddings = jnp.delete(self.embeddings, index, axis=0) else: raise IndexError("Invalid index")