Skip to content

Commit

Permalink
fix output typing
Browse files Browse the repository at this point in the history
  • Loading branch information
Corentin authored Mar 13, 2024
1 parent 69129c8 commit 10ea129
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import ClassVar, Dict, List, Optional
from typing import ClassVar, Dict, List, Optional, Union

import numpy as np

Expand Down Expand Up @@ -81,7 +81,7 @@ def __init__(
):
self.model = SparseTextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=threads)

def embed(self, data: List[List[str]], **kwargs) -> List[Dict[str, np.ndarray]]:
def embed(self, data: List[List[str]], **kwargs) -> List[Dict[str, Union[List[int], List[float]]]]:
# The embed method returns a Iterable[SparseEmbedding], so we convert it to a list of dictionaries.
# Each dict contains an `indices` key containing a list of int and an `values` key containing a list of floats.
sparse_embeddings = [sparse_embedding.as_object() for sparse_embedding in self.model.embed(data, **kwargs)]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
import numpy as np
from haystack import component, default_to_dict

Expand Down Expand Up @@ -94,7 +94,7 @@ def warm_up(self):
model_name=self.model_name, cache_dir=self.cache_dir, threads=self.threads
)

@component.output_types(embedding=List[Dict[str, np.ndarray]])
@component.output_types(embedding=List[Dict[str, Union[List[int], List[float]]]])
def run(self, text: str):
"""
Embeds text using the Fastembed model.
Expand Down

0 comments on commit 10ea129

Please sign in to comment.