diff --git a/sentence_transformers/models/StaticEmbedding.py b/sentence_transformers/models/StaticEmbedding.py
index fe0b8c987..339893b96 100644
--- a/sentence_transformers/models/StaticEmbedding.py
+++ b/sentence_transformers/models/StaticEmbedding.py
@@ -19,7 +19,7 @@ class StaticEmbedding(nn.Module):
     def __init__(
         self,
         tokenizer: Tokenizer | PreTrainedTokenizerFast,
-        embedding_weights: np.array | torch.Tensor | None = None,
+        embedding_weights: np.ndarray | torch.Tensor | None = None,
         embedding_dim: int | None = None,
         **kwargs,
     ) -> None:
@@ -30,7 +30,7 @@ def __init__(
         Args:
             tokenizer (Tokenizer | PreTrainedTokenizerFast): The tokenizer to be used. Must be a fast tokenizer
                 from ``transformers`` or ``tokenizers``.
-            embedding_weights (np.array | torch.Tensor | None, optional): Pre-trained embedding weights.
+            embedding_weights (np.ndarray | torch.Tensor | None, optional): Pre-trained embedding weights.
                 Defaults to None.
             embedding_dim (int | None, optional): Dimension of the embeddings. Required if embedding_weights
                 is not provided. Defaults to None.