Skip to content

Commit 07ad82b

Browse files
committed
protection
1 parent a0fcc33 commit 07ad82b

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

src/stopping.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -180,14 +180,14 @@ def get_stopping(prompt_type, prompt_dict, tokenizer, device, base_model,
180180
stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
181181
stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
182182
# avoid padding in front of tokens
183-
if tokenizer._pad_token: # use hidden variable to avoid annoying properly logger bug
183+
if hasattr(tokenizer, '_pad_token') and tokenizer._pad_token: # use hidden variable to avoid annoying properly logger bug
184184
stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
185-
if tokenizer._unk_token: # use hidden variable to avoid annoying properly logger bug
185+
if hasattr(tokenizer, '_unk_token') and tokenizer._unk_token: # use hidden variable to avoid annoying properly logger bug
186186
stop_words_ids = [x[1:] if x[0] == tokenizer.unk_token_id and len(x) > 1 else x for x in stop_words_ids]
187187
stop_words_ids = [x[:-1] if x[-1] == tokenizer.unk_token_id and len(x) > 1 else x for x in stop_words_ids]
188-
if tokenizer._eos_token: # use hidden variable to avoid annoying properly logger bug
188+
if hasattr(tokenizer, '_eos_token') and tokenizer._eos_token: # use hidden variable to avoid annoying properly logger bug
189189
stop_words_ids = [x[:-1] if x[-1] == tokenizer.eos_token_id and len(x) > 1 else x for x in stop_words_ids]
190-
if tokenizer._bos_token: # use hidden variable to avoid annoying properly logger bug
190+
if hasattr(tokenizer, '_bos_token') and tokenizer._bos_token: # use hidden variable to avoid annoying properly logger bug
191191
stop_words_ids = [x[1:] if x[0] == tokenizer.bos_token_id and len(x) > 1 else x for x in stop_words_ids]
192192
stop_words_ids = [x[:-1] if x[-1] == tokenizer.bos_token_id and len(x) > 1 else x for x in stop_words_ids]
193193
if base_model and t5_type(base_model) and hasattr(tokenizer, 'vocab'):

src/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "2e830cc79a0bb6a7044e0794fec0ba30f4063f0f"
1+
__version__ = "a0fcc3344d53a834fe3cb5b26265aaeb84993b77"

0 commit comments

Comments
 (0)