Commit
Change the embedding model: all embedding management now goes through the sentence_transformers library
Showing 3 changed files with 30 additions and 143 deletions.
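In short, the commit drops the manual transformers pipeline (AutoTokenizer/AutoModel, mean pooling, L2 normalisation, hand-rolled batching) in favour of SentenceTransformer, which performs those steps internally. Below is a minimal sketch of the pattern the new code relies on; it is not part of the commit, assumes sentence-transformers >= 3.0, and uses the diff's default model name with placeholder texts.

from sentence_transformers import SentenceTransformer, SimilarityFunction

# Load the encoder; tokenization, pooling, normalisation and batching are handled internally.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
model.similarity_fn_name = SimilarityFunction.MANHATTAN  # scoring used by model.similarity()

texts = ["a first chunk of text", "a second chunk of text"]
embeddings = model.encode(texts, batch_size=32, convert_to_tensor=True)  # (2, 384) for all-MiniLM-L6-v2

query = model.encode("a query", convert_to_tensor=True)
scores = model.similarity(query, embeddings)  # 1 x 2 tensor of similarity scores
print(scores)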
@@ -21,32 +21,17 @@
 __email__ = "[email protected]"
 __status__ = "Devel"

-import torch,json,os,re
-from transformers import BertTokenizer, BertModel
-from transformers import AutoTokenizer, AutoModel
-from sentence_transformers import SentenceTransformer, util
+import torch,os,re
+from sentence_transformers import SentenceTransformer, SimilarityFunction
 import numpy as np
 from scipy.spatial.distance import cdist
 from rich import print
 import pandas as pd

-import torch.nn.functional as F
-from sentence_transformers.util import cos_sim
-from tqdm import tqdm
-
-# Load the BERT model and the tokenizer
-# bert-base-uncased
-# bert-large-uncased
-# roberta-base
-
-#tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
-#model = BertModel.from_pretrained('bert-large-uncased')
-
-# FacebookAI/roberta-base
-# sentence-transformers/all-MiniLM-L6-v2
-
-# https://huggingface.co/spaces/mteb/leaderboard
-# mixedbread-ai/mxbai-embed-large-v1
-
 class Singleton(type):
     _instances = {}
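Both hunks touch the ModelEmbeddingManager class, which is instantiated through the Singleton metaclass shown in the context lines (_instances and __call__). For readers unfamiliar with the pattern, here is a hedged sketch of how such a metaclass usually works; the actual __call__ body is not shown in this diff, so treat this as an illustration rather than the project's code.

class Singleton(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        # Create the instance once, then return the cached one on every later call.
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]

class ModelEmbeddingManager(metaclass=Singleton):
    def __init__(self, config):
        self.config = config

m1 = ModelEmbeddingManager({"retention_dir": "/tmp"})
m2 = ModelEmbeddingManager({"retention_dir": "/tmp"})
assert m1 is m2  # same object: the heavy encoder is only loaded once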
@@ -57,128 +42,51 @@ def __call__(cls, *args, **kwargs):


 class ModelEmbeddingManager(metaclass=Singleton):
-    def __init__(self,config):
-        self.config=config
+    def __init__(self, config):
+        self.config = config
         self.retention_dir = config['retention_dir']

-        if 'encodeur' in config:
-            self.model_name = config['encodeur']
-        else:
-            self.model_name = 'sentence-transformers/all-MiniLM-L6-v2'
-
-        #self.model_name = 'mixedbread-ai/mxbai-embed-large-v1'
-        #self.model_name = 'sentence-transformers/all-mpnet-base-v2'
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name,clean_up_tokenization_spaces=True)
-        self.model = AutoModel.from_pretrained(self.model_name)
-
-        if 'batch_size' in config:
-            self.batch_size = config['batch_size']
-        else:
-            self.batch_size=32
-
-        if 'threshold_similarity_tag' in config:
-            self.threshold_similarity_tag = config['threshold_similarity_tag']
-        else:
-            self.threshold_similarity_tag = 0.75
-
-        if 'threshold_similarity_tag_chunk' in config:
-            self.threshold_similarity_tag_chunk = config['threshold_similarity_tag_chunk']
-        else:
-            self.threshold_similarity_tag_chunk = 0.75
-
-        self.model_suffix=self.model_name.split("/").pop()
-        self.model_suffix="all-MiniLM-L6-v2"
+        self.model_name = config.get('encodeur', 'sentence-transformers/all-MiniLM-L6-v2')
+        self.model = SentenceTransformer(self.model_name)
+        self.model.similarity_fn_name = SimilarityFunction.MANHATTAN
+        self.batch_size = config.get('batch_size', 32)
+        self.threshold_similarity_tag = config.get('threshold_similarity_tag', 0.75)
+        self.threshold_similarity_tag_chunk = config.get('threshold_similarity_tag_chunk', 0.75)

         print("------------------------------------")
-        print("endoceur:",self.model_name)
-        print("threshold_similarity_tag:",self.threshold_similarity_tag)
-        print("threshold_similarity_tag_chunk:",self.threshold_similarity_tag_chunk)
-        print("batch_size:",self.batch_size)
+        print("Encoder:", self.model_name)
+        print("Threshold similarity tag:", self.threshold_similarity_tag)
+        print("Threshold similarity tag chunk:", self.threshold_similarity_tag_chunk)
+        print("Batch size:", self.batch_size)
         print("------------------------------------")

-    def get_filename_pth(self,name_embeddings):
-        return f"{self.retention_dir}/{name_embeddings}-{self.model_suffix}.pth"
+    def get_filename_pth(self, name_embeddings):
+        return f"{self.retention_dir}/{name_embeddings}-{self.model_name.split('/')[-1]}.pth"

-    def load_filepth(self,filename_embeddings):
-        return torch.load(filename_embeddings,weights_only=False)
+    def load_filepth(self, filename_embeddings):
+        return torch.load(filename_embeddings,weights_only=True)

-    def load_pth(self,name_embeddings):
+    def load_pth(self, name_embeddings):
         filename = self.get_filename_pth(name_embeddings)

         tag_embeddings = {}

         if os.path.exists(filename):
-            print(f"load embeddings - {filename}")
-            tag_embeddings = torch.load(filename,weights_only=False)
+            print(f"Loading embeddings from {filename}")
+            tag_embeddings = torch.load(filename,weights_only=True)
         return tag_embeddings

-    def save_pth(self,tag_embeddings,name_embeddings):
+    def save_pth(self, tag_embeddings, name_embeddings):
         filename = self.get_filename_pth(name_embeddings)
         torch.save(tag_embeddings, filename)

-    #Mean Pooling - Take attention mask into account for correct averaging
-    def mean_pooling(self,model_output, attention_mask):
-        token_embeddings = model_output[0] #First element of model_output contains all token embeddings
-        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
-        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
-
-    def encode_text_base(self,text):
-        inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True)
-        with torch.no_grad():
-            outputs = self.model(**inputs)
-        return outputs.last_hidden_state.mean(dim=1)
-
-    def encode_text_allMiniLML6V2(self,text):
-        inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True)
-        with torch.no_grad():
-            outputs = self.model(**inputs)
-
-        return F.normalize(self.mean_pooling(outputs, inputs['attention_mask']), p=2, dim=1)
-        #return outputs.last_hidden_state.mean(dim=1)
-
-    def encode_text(self,text):
-        return self.encode_text_allMiniLML6V2(text)
-        #return self.encode_text_base(text)
-
-    def encode_text_batch_allMiniLML6V2(self,texts, batch_size=32):
-        # Switch to evaluation mode
-        self.model.eval()
-
-        all_embeddings = []
-
-        # Process in batches
-        for i in tqdm(range(0, len(texts), batch_size)):
-            batch_texts = texts[i:i+batch_size]
-
-            # Tokenize the batch
-            inputs = self.tokenizer(batch_texts, return_tensors='pt', truncation=True, padding=True)
-
-            # Move the tensors to the GPU if available
-            if torch.cuda.is_available():
-                inputs = {k: v.to('cuda') for k, v in inputs.items()}
-                self.model.to('cuda')
-
-            # Compute the embeddings
-            with torch.no_grad():
-                outputs = self.model(**inputs)
-
-            # Pooling and normalization
-            embeddings = F.normalize(self.mean_pooling(outputs, inputs['attention_mask']), p=2, dim=1)
-
-            # Move the embeddings back to the CPU if necessary
-            if torch.cuda.is_available():
-                embeddings = embeddings.cpu()
-
-            all_embeddings.append(embeddings)
-
-        # Concatenate all the embeddings
-        return torch.cat(all_embeddings, dim=0)
+    def encode_text(self, text):
+        return self.model.encode(text, convert_to_tensor=True)

-    def encode_text_batch(self,texts):
-        return self.encode_text_batch_allMiniLML6V2(texts, self.batch_size)
+    def encode_text_batch(self, texts):
+        return self.model.encode(texts, batch_size=self.batch_size, convert_to_tensor=True)

-    def cosine_similarity(self,a, b):
-        return cos_sim(a, b)[0].item()
-        #return F.cosine_similarity(a, b).item()
+    def cosine_similarity(self, a, b):
+        return self.model.similarity(a, b)

     def best_similarity_for_tag(self,chunks_embedding, tag_embeddings):
         best_similarity = float('-inf')
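One behavioural note on the rewritten cosine_similarity: it now delegates to model.similarity(), and because __init__ sets similarity_fn_name to SimilarityFunction.MANHATTAN, the scores it returns should be negative Manhattan (L1) distances rather than cosine values, despite the method's name (assuming sentence-transformers >= 3.0 semantics). That change of scale may matter wherever the 0.75 thresholds configured above are applied. A small illustrative sketch, with a placeholder pair of texts:

from sentence_transformers import SentenceTransformer, SimilarityFunction

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

a = model.encode("database backup", convert_to_tensor=True)
b = model.encode("restoring a database dump", convert_to_tensor=True)

# model.similarity() follows model.similarity_fn_name:
model.similarity_fn_name = SimilarityFunction.COSINE
print(model.similarity(a, b))   # cosine similarity, in [-1, 1]

model.similarity_fn_name = SimilarityFunction.MANHATTAN
print(model.similarity(a, b))   # negative Manhattan (L1) distance; higher still means closer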