Skip to content

Commit

Permalink
Fix linting issues
Browse files (browse the repository at this point in the history)
  • Loading branch information
botirk38 committed Sep 3, 2024
1 parent 2a9f836 commit 4001d3a
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions huggingface_pipelines/text.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -117,8 +117,7 @@ def load_spacy_model(self, lang_code: str) -> Language:
nlp = pipeline.load_spacy_model('en')
"""
if lang_code not in self.SPACY_MODELS:
raise ValueError(
f"No installed model found for language code: {lang_code}")
raise ValueError(f"No installed model found for language code: {lang_code}")
return spacy.load(self.SPACY_MODELS[lang_code])

def segment_text(self, text: Optional[str]) -> List[str]:
Expand Down Expand Up @@ -295,7 +294,7 @@ def __init__(self, config: EmbeddingToTextPipelineConfig):
self.t2t_model = EmbeddingToTextModelPipeline(
decoder=self.config.decoder_model,
tokenizer=self.config.decoder_model,
device=torch.device(self.config.device)
device=torch.device(self.config.device),
)
logger.info("Model initialized.")

Expand Down Expand Up @@ -329,8 +328,7 @@ def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
and not isinstance(item[0], (list, np.ndarray))
for item in embeddings
):
all_embeddings = np.asarray(
embeddings, dtype=self.config.dtype)
all_embeddings = np.asarray(embeddings, dtype=self.config.dtype)
all_decoded_texts = self.decode_embeddings(all_embeddings)
batch[f"{column}_{self.config.output_column_suffix}"] = (
all_decoded_texts
Expand Down Expand Up @@ -382,10 +380,9 @@ def decode_embeddings(self, embeddings: np.ndarray) -> List[str]:
decoded_texts = []

for i in range(0, len(embeddings), self.config.batch_size):
batch_embeddings = embeddings[i: i + self.config.batch_size]
batch_embeddings = embeddings[i : i + self.config.batch_size]
batch_embeddings_tensor = (
torch.from_numpy(batch_embeddings).float().to(
self.config.device)
torch.from_numpy(batch_embeddings).float().to(self.config.device)
)

batch_decoded = self.t2t_model.predict(
Expand Down Expand Up @@ -546,7 +543,7 @@ def encode_texts(self, texts: List[str]) -> np.ndarray:
try:
embeddings: List[np.ndarray] = []
for i in range(0, len(texts), self.config.batch_size):
batch_texts = texts[i: i + self.config.batch_size]
batch_texts = texts[i : i + self.config.batch_size]
batch_embeddings = self.t2vec_model.predict(
batch_texts,
source_lang=self.config.source_lang,
Expand Down

0 comments on commit 4001d3a

Please sign in to comment.