Skip to content

Commit

Permalink
🐛 Add model-eval to smxm and gliner
Browse files Browse the repository at this point in the history
Signed-off-by: Marcos Martinez <[email protected]>
  • Loading branch information
Marcos Martinez authored and Marcos Martinez committed Nov 24, 2024
1 parent f5a2a9a commit 2268aaa
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 1 deletion.
4 changes: 3 additions & 1 deletion zshot/evaluation/zshot_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,10 @@ def make_table(data: Dict, title: str = ""):
return table

tables = []
mode = evaluation.pop('evaluation_mode')
mode = evaluation.get('evaluation_mode')
for component in evaluation:
if component == 'evaluation_mode':
continue
# General evaluation
t_repr = make_table(evaluation[component], f"{component} - {name} \n General - {mode}-based").get_string()
tables.append(fix_table_title(t_repr))
Expand Down
1 change: 1 addition & 0 deletions zshot/linker/linker_gliner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def load_models(self):
""" Load GLINER model """
if self.model is None:
self.model = GLiNER.from_pretrained(self.model_name, cache_dir=MODELS_CACHE_PATH).to(self.device)
self.model.eval()

def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[Span]]:
"""
Expand Down
1 change: 1 addition & 0 deletions zshot/linker/linker_smxm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def load_models(self):
self.model = BertTaggerMultiClass.from_pretrained(
self.model_name, output_hidden_states=True, cache_dir=MODELS_CACHE_PATH
).to(self.device)
self.model.eval()

def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[Span]]:
"""
Expand Down
1 change: 1 addition & 0 deletions zshot/mentions_extractor/mentions_extractor_gliner.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def load_models(self):
""" Load GLINER model """
if self.model is None:
self.model = GLiNER.from_pretrained(self.model_name, cache_dir=MODELS_CACHE_PATH).to(self.device)
self.model.eval()

def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[Span]]:
"""
Expand Down
1 change: 1 addition & 0 deletions zshot/mentions_extractor/mentions_extractor_smxm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def load_models(self):
self.model = BertTaggerMultiClass.from_pretrained(
self.model_name, output_hidden_states=True, cache_dir=MODELS_CACHE_PATH
).to(self.device)
self.model.eval()

def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[Span]]:
"""
Expand Down

0 comments on commit 2268aaa

Please sign in to comment.