diff --git a/experiment_scripts/misc/dump_embeddings.py b/experiment_scripts/misc/dump_embeddings.py
index cc2d0bd9..8d21f487 100644
--- a/experiment_scripts/misc/dump_embeddings.py
+++ b/experiment_scripts/misc/dump_embeddings.py
@@ -27,9 +27,9 @@
 )
 from tokenize_data import tokenize_function
 from utils import emb
+from vec2text.models.model_utils import device
 
 num_workers = len(os.sched_getaffinity(0))
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 def load_model_and_tokenizers(
diff --git a/experiment_scripts/misc/dump_embeddings_binary.py b/experiment_scripts/misc/dump_embeddings_binary.py
index 65984ec2..0b00c340 100644
--- a/experiment_scripts/misc/dump_embeddings_binary.py
+++ b/experiment_scripts/misc/dump_embeddings_binary.py
@@ -14,8 +14,7 @@
 import tqdm
 import transformers
 from data_helpers import NQ_DEV, load_dpr_corpus
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+from vec2text.models.model_utils import device
 
 
 def entropy__bits(p: float) -> float:
diff --git a/experiment_scripts/misc/emb_analysis.py b/experiment_scripts/misc/emb_analysis.py
index 6edb2282..9067d1f0 100644
--- a/experiment_scripts/misc/emb_analysis.py
+++ b/experiment_scripts/misc/emb_analysis.py
@@ -28,7 +28,7 @@
 
 num_workers = len(os.sched_getaffinity(0))
 max_seq_length = 128
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+from vec2text.models.model_utils import device
 
 
 def reorder_words_except_padding(
diff --git a/experiment_scripts/misc/sentence_closest_words.py b/experiment_scripts/misc/sentence_closest_words.py
index 7a03dfdb..b772e87b 100644
--- a/experiment_scripts/misc/sentence_closest_words.py
+++ b/experiment_scripts/misc/sentence_closest_words.py
@@ -11,8 +11,7 @@
 import transformers
 from models import InversionModel, load_embedder_and_tokenizer, load_encoder_decoder
 from utils import embed_all_tokens
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+from vec2text.models.model_utils import device
 
 # embedder_model_name = "dpr"
 embedder_model_name = "gtr_base"
diff --git a/experiment_scripts/misc/sentence_pair_similarity.py b/experiment_scripts/misc/sentence_pair_similarity.py
index e9e83807..cce004c8 100644
--- a/experiment_scripts/misc/sentence_pair_similarity.py
+++ b/experiment_scripts/misc/sentence_pair_similarity.py
@@ -9,7 +9,7 @@
 import torch
 
 from models import InversionModel, load_embedder_and_tokenizer, load_encoder_decoder
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+from vec2text.models.model_utils import device
 
 
 def main():
diff --git a/vec2text/analyze_utils.py b/vec2text/analyze_utils.py
index ebaaba40..bdb2b3b8 100644
--- a/vec2text/analyze_utils.py
+++ b/vec2text/analyze_utils.py
@@ -16,14 +16,8 @@
 from vec2text.models.config import InversionConfig
 from vec2text.run_args import DataArguments, ModelArguments, TrainingArguments
 from vec2text import run_args as run_args
+from vec2text.models.model_utils import device
 
-device = torch.device(
-    "cuda"
-    if torch.cuda.is_available()
-    else "mps"
-    if torch.backends.mps.is_available()
-    else "cpu"
-)
 transformers.logging.set_verbosity_error()
 
 #############################################################################
diff --git a/vec2text/api.py b/vec2text/api.py
index b5e6c60f..29f4824a 100644
--- a/vec2text/api.py
+++ b/vec2text/api.py
@@ -5,7 +5,7 @@
 import transformers
 
 import vec2text
-from vec2text.models.model_utils import device
+from vec2text.models.model_utils import device
 
 SUPPORTED_MODELS = ["text-embedding-ada-002", "gtr-base"]
 
@@ -23,17 +23,17 @@ def load_pretrained_corrector(embedder: str) -> vec2text.trainers.Corrector:
     if embedder == "text-embedding-ada-002":
         inversion_model = vec2text.models.InversionModel.from_pretrained(
             "jxm/vec2text__openai_ada002__msmarco__msl128__hypothesizer"
-        )
+        ).to(device)
         model = vec2text.models.CorrectorEncoderModel.from_pretrained(
             "jxm/vec2text__openai_ada002__msmarco__msl128__corrector"
-        )
+        ).to(device)
     elif embedder == "gtr-base":
         inversion_model = vec2text.models.InversionModel.from_pretrained(
             "jxm/gtr__nq__32"
-        )
+        ).to(device)
         model = vec2text.models.CorrectorEncoderModel.from_pretrained(
             "jxm/gtr__nq__32__correct"
-        )
+        ).to(device)
     else:
         raise NotImplementedError(f"embedder `{embedder}` not implemented")
 
@@ -82,13 +82,17 @@ def invert_embeddings(
     corrector: vec2text.trainers.Corrector,
     num_steps: int = None,
     sequence_beam_width: int = 0,
+    max_length: int = 128,
 ) -> List[str]:
+    # Ensure embeddings are on the correct device
+    embeddings = embeddings.to(device)
+
     corrector.inversion_trainer.model.eval()
     corrector.model.eval()
 
     gen_kwargs = copy.copy(corrector.gen_kwargs)
     gen_kwargs["min_length"] = 1
-    gen_kwargs["max_length"] = 128
+    gen_kwargs["max_length"] = max_length
 
     if num_steps is None:
         assert (
@@ -124,6 +128,9 @@ def invert_embeddings_and_return_hypotheses(
     num_steps: int = None,
     sequence_beam_width: int = 0,
 ) -> List[str]:
+    # Ensure embeddings are on the correct device
+    embeddings = embeddings.to(device)
+
     corrector.inversion_trainer.model.eval()
     corrector.model.eval()
 
@@ -175,4 +182,4 @@ def invert_strings(
         corrector=corrector,
         num_steps=num_steps,
         sequence_beam_width=sequence_beam_width,
-    )
+    )
\ No newline at end of file
diff --git a/vec2text/api_old.py b/vec2text/api_old.py
new file mode 100644
index 00000000..baa4456a
--- /dev/null
+++ b/vec2text/api_old.py
@@ -0,0 +1,179 @@
+import copy
+from typing import List
+
+import torch
+import transformers
+
+import vec2text
+from vec2text.models.model_utils import device
+
+SUPPORTED_MODELS = ["text-embedding-ada-002", "gtr-base"]
+
+
+def load_pretrained_corrector(embedder: str) -> vec2text.trainers.Corrector:
+    """Gets the Corrector object for the given embedder.
+
+    For now, we just support inverting OpenAI Ada 002 and gtr-base embeddings; we plan to
+    expand this support over time.
+    """
+    assert (
+        embedder in SUPPORTED_MODELS
+    ), f"embedder to invert `{embedder} not in list of supported models: {SUPPORTED_MODELS}`"
+
+    if embedder == "text-embedding-ada-002":
+        inversion_model = vec2text.models.InversionModel.from_pretrained(
+            "jxm/vec2text__openai_ada002__msmarco__msl128__hypothesizer"
+        )
+        model = vec2text.models.CorrectorEncoderModel.from_pretrained(
+            "jxm/vec2text__openai_ada002__msmarco__msl128__corrector"
+        )
+    elif embedder == "gtr-base":
+        inversion_model = vec2text.models.InversionModel.from_pretrained(
+            "jxm/gtr__nq__32"
+        )
+        model = vec2text.models.CorrectorEncoderModel.from_pretrained(
+            "jxm/gtr__nq__32__correct"
+        )
+    else:
+        raise NotImplementedError(f"embedder `{embedder}` not implemented")
+
+    return load_corrector(inversion_model, model)
+
+
+def load_corrector(
+    inversion_model: vec2text.models.InversionModel,
+    corrector_model: vec2text.models.CorrectorEncoderModel,
+) -> vec2text.trainers.Corrector:
+    """Load in the inversion and corrector models
+
+    Args:
+        inversion_model (vec2text.models.InversionModel): _description_
+        corrector_model (vec2text.models.CorrectorEncoderModel): _description_
+
+    Returns:
+        vec2text.trainers.Corrector: Corrector model to invert an embedding back to text
+    """
+
+    inversion_trainer = vec2text.trainers.InversionTrainer(
+        model=inversion_model,
+        train_dataset=None,
+        eval_dataset=None,
+        data_collator=transformers.DataCollatorForSeq2Seq(
+            inversion_model.tokenizer,
+            label_pad_token_id=-100,
+        ),
+    )
+
+    # backwards compatibility stuff
+    corrector_model.config.dispatch_batches = None
+    corrector = vec2text.trainers.Corrector(
+        model=corrector_model,
+        inversion_trainer=inversion_trainer,
+        args=None,
+        data_collator=vec2text.collator.DataCollatorForCorrection(
+            tokenizer=inversion_trainer.model.tokenizer
+        ),
+    )
+    return corrector
+
+
+def invert_embeddings(
+    embeddings: torch.Tensor,
+    corrector: vec2text.trainers.Corrector,
+    num_steps: int = None,
+    sequence_beam_width: int = 0,
+    max_length: int = 128, #this is not sufficient to add functionality: corrector and hypothesizer also must be modified
+) -> List[str]:
+    corrector.inversion_trainer.model.eval()
+    corrector.model.eval()
+
+    gen_kwargs = copy.copy(corrector.gen_kwargs)
+    gen_kwargs["min_length"] = 1
+    gen_kwargs["max_length"] = max_length
+
+    if num_steps is None:
+        assert (
+            sequence_beam_width == 0
+        ), "can't set a nonzero beam width without multiple steps"
+
+        regenerated = corrector.inversion_trainer.generate(
+            inputs={
+                "frozen_embeddings": embeddings,
+            },
+            generation_kwargs=gen_kwargs,
+        )
+    else:
+        corrector.return_best_hypothesis = sequence_beam_width > 0
+        regenerated = corrector.generate(
+            inputs={
+                "frozen_embeddings": embeddings,
+            },
+            generation_kwargs=gen_kwargs,
+            num_recursive_steps=num_steps,
+            sequence_beam_width=sequence_beam_width,
+        )
+
+    output_strings = corrector.tokenizer.batch_decode(
+        regenerated, skip_special_tokens=True
+    )
+    return output_strings
+
+
+def invert_embeddings_and_return_hypotheses(
+    embeddings: torch.Tensor,
+    corrector: vec2text.trainers.Corrector,
+    num_steps: int = None,
+    sequence_beam_width: int = 0,
+) -> List[str]:
+    corrector.inversion_trainer.model.eval()
+    corrector.model.eval()
+
+    gen_kwargs = copy.copy(corrector.gen_kwargs)
+    gen_kwargs["min_length"] = 1
+    gen_kwargs["max_length"] = 128
+
+    corrector.return_best_hypothesis = sequence_beam_width > 0
+
+    regenerated, hypotheses = corrector.generate_with_hypotheses(
+        inputs={
+            "frozen_embeddings": embeddings,
+        },
+        generation_kwargs=gen_kwargs,
+        num_recursive_steps=num_steps,
+        sequence_beam_width=sequence_beam_width,
+    )
+
+    output_strings = []
+    for hypothesis in regenerated:
+        output_strings.append(
+            corrector.tokenizer.batch_decode(hypothesis, skip_special_tokens=True)
+        )
+
+    return output_strings, hypotheses
+
+
+def invert_strings(
+    strings: List[str],
+    corrector: vec2text.trainers.Corrector,
+    num_steps: int = None,
+    sequence_beam_width: int = 0,
+) -> List[str]:
+    inputs = corrector.embedder_tokenizer(
+        strings,
+        return_tensors="pt",
+        max_length=128,
+        truncation=True,
+        padding="max_length",
+    )
+    inputs = inputs.to(device)
+    with torch.no_grad():
+        frozen_embeddings = corrector.inversion_trainer.call_embedding_model(
+            input_ids=inputs.input_ids,
+            attention_mask=inputs.attention_mask,
+        )
+    return invert_embeddings(
+        embeddings=frozen_embeddings,
+        corrector=corrector,
+        num_steps=num_steps,
+        sequence_beam_width=sequence_beam_width,
+    )
diff --git a/vec2text/models/corrector_encoder_from_logits.py b/vec2text/models/corrector_encoder_from_logits.py
index cdabcebd..c05f7caa 100644
--- a/vec2text/models/corrector_encoder_from_logits.py
+++ b/vec2text/models/corrector_encoder_from_logits.py
@@ -7,6 +7,7 @@
 from vec2text.models.config import InversionConfig
 
 from .corrector_encoder import CorrectorEncoderModel
+from vec2text.models.model_utils import device
 
 
 class CorrectorEncoderFromLogitsModel(CorrectorEncoderModel):
diff --git a/vec2text/run_args.py b/vec2text/run_args.py
index 198ea4e6..6b663d89 100644
--- a/vec2text/run_args.py
+++ b/vec2text/run_args.py
@@ -371,6 +371,7 @@ def __post_init__(self):
             )
         self.dataloader_pin_memory = True
         num_workers = torch.cuda.device_count()
+
        os.environ["RAYON_RS_NUM_CPUS"] = str(
            num_workers
        )  # Sets threads for hf tokenizers
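
Note for reviewers: the scripts above now import `device` from `vec2text.models.model_utils` instead of defining it locally. That module is not touched by this patch, so its contents are not shown; presumably it mirrors the CUDA/MPS/CPU selection deleted from `analyze_utils.py`. A minimal sketch, assuming exactly that logic:

```python
# vec2text/models/model_utils.py -- sketch only, not part of this patch.
# Assumed to mirror the device-selection logic removed from analyze_utils.py
# above: prefer CUDA, then Apple MPS, then CPU.
import torch

device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
```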
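
A hedged usage sketch of the updated `vec2text/api.py` surface: `load_pretrained_corrector` now moves both models to `device`, and `invert_embeddings` accepts a `max_length` keyword and moves incoming embeddings to `device` itself. The call pattern follows `invert_strings` from the diff; the input text is illustrative, and, as the comment in `api_old.py` notes, `max_length` alone may not fully control output length without further changes to the corrector and hypothesizer.

```python
import torch

from vec2text.api import invert_embeddings, load_pretrained_corrector
from vec2text.models.model_utils import device

# The patched loader places the hypothesizer and corrector models on `device`.
corrector = load_pretrained_corrector("gtr-base")

# Embed a string with the corrector's own embedder, mirroring invert_strings().
inputs = corrector.embedder_tokenizer(
    ["vec2text inverts dense embeddings back into text."],
    return_tensors="pt",
    max_length=128,
    truncation=True,
    padding="max_length",
).to(device)

with torch.no_grad():
    frozen_embeddings = corrector.inversion_trainer.call_embedding_model(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
    )

recovered = invert_embeddings(
    embeddings=frozen_embeddings,  # now moved to `device` inside the function
    corrector=corrector,
    num_steps=20,
    sequence_beam_width=4,
    max_length=64,  # new keyword introduced by this patch
)
print(recovered)
```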