Add mps support #77

Open
wants to merge 4 commits into master
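
This PR replaces the per-script `device = torch.device("cuda" if torch.cuda.is_available() else "cpu")` definitions with a single shared `device` imported from `vec2text.models.model_utils`, and moves the pretrained models and input embeddings onto that device. The `model_utils` module itself is not shown in this diff; judging from the selection block removed from `vec2text/analyze_utils.py` further down, the shared helper presumably looks roughly like this (a sketch, not the actual file contents):

```python
# vec2text/models/model_utils.py (assumed): pick the best available backend,
# preferring CUDA, then Apple-silicon MPS, then CPU.
import torch

device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
```
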
2 changes: 1 addition & 1 deletion experiment_scripts/misc/dump_embeddings.py
@@ -27,9 +27,9 @@
)
from tokenize_data import tokenize_function
from utils import emb
+from vec2text.models.model_utils import device

num_workers = len(os.sched_getaffinity(0))
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model_and_tokenizers(
3 changes: 1 addition & 2 deletions experiment_scripts/misc/dump_embeddings_binary.py
@@ -14,8 +14,7 @@
import tqdm
import transformers
from data_helpers import NQ_DEV, load_dpr_corpus

-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+from vec2text.models.model_utils import device


def entropy__bits(p: float) -> float:
2 changes: 1 addition & 1 deletion experiment_scripts/misc/emb_analysis.py
@@ -28,7 +28,7 @@

num_workers = len(os.sched_getaffinity(0))
max_seq_length = 128
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+from vec2text.models.model_utils import device


def reorder_words_except_padding(
3 changes: 1 addition & 2 deletions experiment_scripts/misc/sentence_closest_words.py
@@ -11,8 +11,7 @@
import transformers
from models import InversionModel, load_embedder_and_tokenizer, load_encoder_decoder
from utils import embed_all_tokens

-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+from vec2text.models.model_utils import device

# embedder_model_name = "dpr"
embedder_model_name = "gtr_base"
2 changes: 1 addition & 1 deletion experiment_scripts/misc/sentence_pair_similarity.py
@@ -9,7 +9,7 @@
import torch
from models import InversionModel, load_embedder_and_tokenizer, load_encoder_decoder

-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+from vec2text.models.model_utils import device


def main():
8 changes: 1 addition & 7 deletions vec2text/analyze_utils.py
@@ -16,14 +16,8 @@
from vec2text.models.config import InversionConfig
from vec2text.run_args import DataArguments, ModelArguments, TrainingArguments
from vec2text import run_args as run_args
+from vec2text.models.model_utils import device

-device = torch.device(
-    "cuda"
-    if torch.cuda.is_available()
-    else "mps"
-    if torch.backends.mps.is_available()
-    else "cpu"
-)
transformers.logging.set_verbosity_error()

#############################################################################
21 changes: 14 additions & 7 deletions vec2text/api.py
@@ -5,7 +5,7 @@
import transformers

import vec2text
+from vec2text.models.model_utils import device

SUPPORTED_MODELS = ["text-embedding-ada-002", "gtr-base"]

@@ -23,17 +23,17 @@ def load_pretrained_corrector(embedder: str) -> vec2text.trainers.Corrector:
if embedder == "text-embedding-ada-002":
inversion_model = vec2text.models.InversionModel.from_pretrained(
"jxm/vec2text__openai_ada002__msmarco__msl128__hypothesizer"
-)
+).to(device)
model = vec2text.models.CorrectorEncoderModel.from_pretrained(
"jxm/vec2text__openai_ada002__msmarco__msl128__corrector"
-)
+).to(device)
elif embedder == "gtr-base":
inversion_model = vec2text.models.InversionModel.from_pretrained(
"jxm/gtr__nq__32"
-)
+).to(device)
model = vec2text.models.CorrectorEncoderModel.from_pretrained(
"jxm/gtr__nq__32__correct"
-)
+).to(device)
else:
raise NotImplementedError(f"embedder `{embedder}` not implemented")

@@ -82,13 +82,17 @@ def invert_embeddings(
corrector: vec2text.trainers.Corrector,
num_steps: int = None,
sequence_beam_width: int = 0,
+max_length: int = 128,
) -> List[str]:
+# Ensure embeddings are on the correct device
+embeddings = embeddings.to(device)

corrector.inversion_trainer.model.eval()
corrector.model.eval()

gen_kwargs = copy.copy(corrector.gen_kwargs)
gen_kwargs["min_length"] = 1
gen_kwargs["max_length"] = 128
gen_kwargs["max_length"] = max_length

if num_steps is None:
assert (
@@ -124,6 +128,9 @@ def invert_embeddings_and_return_hypotheses(
num_steps: int = None,
sequence_beam_width: int = 0,
) -> List[str]:
+# Ensure embeddings are on the correct device
+embeddings = embeddings.to(device)

corrector.inversion_trainer.model.eval()
corrector.model.eval()

@@ -175,4 +182,4 @@ def invert_strings(
corrector=corrector,
num_steps=num_steps,
sequence_beam_width=sequence_beam_width,
-)
+)
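
Taken together, the `api.py` changes mean callers no longer have to manage device placement themselves: the pretrained models are moved to the shared `device` inside `load_pretrained_corrector`, and input embeddings are moved inside `invert_embeddings`. A minimal usage sketch, under the assumption that these api helpers are re-exported at the package level (the example string and `num_steps` value are arbitrary):

```python
import vec2text
from vec2text.models.model_utils import device  # resolves to cuda, mps, or cpu

# Models are placed on `device` (MPS on Apple silicon) at load time.
corrector = vec2text.load_pretrained_corrector("gtr-base")

# invert_strings tokenizes on CPU, moves the inputs to `device`, embeds them,
# and then runs the corrector to reconstruct the text.
texts = vec2text.invert_strings(
    ["vec2text can now run on Apple-silicon GPUs via the MPS backend."],
    corrector=corrector,
    num_steps=20,
)
print(texts)
```
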
179 changes: 179 additions & 0 deletions vec2text/api_old.py
@@ -0,0 +1,179 @@
import copy
from typing import List

import torch
import transformers

import vec2text
from vec2text.models.model_utils import device

SUPPORTED_MODELS = ["text-embedding-ada-002", "gtr-base"]


def load_pretrained_corrector(embedder: str) -> vec2text.trainers.Corrector:
"""Gets the Corrector object for the given embedder.

For now, we just support inverting OpenAI Ada 002 and gtr-base embeddings; we plan to
expand this support over time.
"""
assert (
embedder in SUPPORTED_MODELS
), f"embedder to invert `{embedder} not in list of supported models: {SUPPORTED_MODELS}`"

if embedder == "text-embedding-ada-002":
inversion_model = vec2text.models.InversionModel.from_pretrained(
"jxm/vec2text__openai_ada002__msmarco__msl128__hypothesizer"
)
model = vec2text.models.CorrectorEncoderModel.from_pretrained(
"jxm/vec2text__openai_ada002__msmarco__msl128__corrector"
)
elif embedder == "gtr-base":
inversion_model = vec2text.models.InversionModel.from_pretrained(
"jxm/gtr__nq__32"
)
model = vec2text.models.CorrectorEncoderModel.from_pretrained(
"jxm/gtr__nq__32__correct"
)
else:
raise NotImplementedError(f"embedder `{embedder}` not implemented")

return load_corrector(inversion_model, model)


def load_corrector(
inversion_model: vec2text.models.InversionModel,
corrector_model: vec2text.models.CorrectorEncoderModel,
) -> vec2text.trainers.Corrector:
"""Load in the inversion and corrector models

Args:
inversion_model (vec2text.models.InversionModel): trained inversion ("hypothesizer") model that maps an embedding to an initial text hypothesis
corrector_model (vec2text.models.CorrectorEncoderModel): trained corrector model that iteratively refines that hypothesis

Returns:
vec2text.trainers.Corrector: Corrector model to invert an embedding back to text
"""

inversion_trainer = vec2text.trainers.InversionTrainer(
model=inversion_model,
train_dataset=None,
eval_dataset=None,
data_collator=transformers.DataCollatorForSeq2Seq(
inversion_model.tokenizer,
label_pad_token_id=-100,
),
)

# backwards compatibility stuff
corrector_model.config.dispatch_batches = None
corrector = vec2text.trainers.Corrector(
model=corrector_model,
inversion_trainer=inversion_trainer,
args=None,
data_collator=vec2text.collator.DataCollatorForCorrection(
tokenizer=inversion_trainer.model.tokenizer
),
)
return corrector


def invert_embeddings(
embeddings: torch.Tensor,
corrector: vec2text.trainers.Corrector,
num_steps: int = None,
sequence_beam_width: int = 0,
max_length: int = 128,  # NOTE: this parameter alone is not sufficient; the corrector and hypothesizer must also be modified to respect it
) -> List[str]:
corrector.inversion_trainer.model.eval()
corrector.model.eval()

gen_kwargs = copy.copy(corrector.gen_kwargs)
gen_kwargs["min_length"] = 1
gen_kwargs["max_length"] = max_length

if num_steps is None:
assert (
sequence_beam_width == 0
), "can't set a nonzero beam width without multiple steps"

regenerated = corrector.inversion_trainer.generate(
inputs={
"frozen_embeddings": embeddings,
},
generation_kwargs=gen_kwargs,
)
else:
corrector.return_best_hypothesis = sequence_beam_width > 0
regenerated = corrector.generate(
inputs={
"frozen_embeddings": embeddings,
},
generation_kwargs=gen_kwargs,
num_recursive_steps=num_steps,
sequence_beam_width=sequence_beam_width,
)

output_strings = corrector.tokenizer.batch_decode(
regenerated, skip_special_tokens=True
)
return output_strings


def invert_embeddings_and_return_hypotheses(
embeddings: torch.Tensor,
corrector: vec2text.trainers.Corrector,
num_steps: int = None,
sequence_beam_width: int = 0,
) -> List[str]:
corrector.inversion_trainer.model.eval()
corrector.model.eval()

gen_kwargs = copy.copy(corrector.gen_kwargs)
gen_kwargs["min_length"] = 1
gen_kwargs["max_length"] = 128

corrector.return_best_hypothesis = sequence_beam_width > 0

regenerated, hypotheses = corrector.generate_with_hypotheses(
inputs={
"frozen_embeddings": embeddings,
},
generation_kwargs=gen_kwargs,
num_recursive_steps=num_steps,
sequence_beam_width=sequence_beam_width,
)

output_strings = []
for hypothesis in regenerated:
output_strings.append(
corrector.tokenizer.batch_decode(hypothesis, skip_special_tokens=True)
)

return output_strings, hypotheses


def invert_strings(
strings: List[str],
corrector: vec2text.trainers.Corrector,
num_steps: int = None,
sequence_beam_width: int = 0,
) -> List[str]:
inputs = corrector.embedder_tokenizer(
strings,
return_tensors="pt",
max_length=128,
truncation=True,
padding="max_length",
)
inputs = inputs.to(device)
with torch.no_grad():
frozen_embeddings = corrector.inversion_trainer.call_embedding_model(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
)
return invert_embeddings(
embeddings=frozen_embeddings,
corrector=corrector,
num_steps=num_steps,
sequence_beam_width=sequence_beam_width,
)
1 change: 1 addition & 0 deletions vec2text/models/corrector_encoder_from_logits.py
@@ -7,6 +7,7 @@
from vec2text.models.config import InversionConfig

from .corrector_encoder import CorrectorEncoderModel
+from vec2text.models.model_utils import device


class CorrectorEncoderFromLogitsModel(CorrectorEncoderModel):
1 change: 1 addition & 0 deletions vec2text/run_args.py
@@ -371,6 +371,7 @@ def __post_init__(self):
)
self.dataloader_pin_memory = True
num_workers = torch.cuda.device_count()

os.environ["RAYON_RS_NUM_CPUS"] = str(
num_workers
) # Sets threads for hf tokenizers
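
Not part of the diff, but a quick way to confirm that the installed PyTorch build exposes the MPS backend (and therefore that the shared `device` will resolve to `mps` on Apple silicon):

```python
import torch

# Both should print True on an Apple-silicon Mac with a recent PyTorch build;
# otherwise the shared `device` selection falls back to CPU on a Mac.
print(torch.backends.mps.is_built())
print(torch.backends.mps.is_available())
```
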