From c4ca039425e200c13a24c22f3e5590f2b799ac24 Mon Sep 17 00:00:00 2001
From: Jean Mercat
Date: Wed, 24 Jul 2024 16:11:42 -0700
Subject: [PATCH 1/7] loading checkpoint with _orig_mod. name prefix

---
 eval/eval_openlm_ckpt.py             |  1 +
 open_lm/utils/llm_foundry_wrapper.py | 18 ++++++++++++++++--
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/eval/eval_openlm_ckpt.py b/eval/eval_openlm_ckpt.py
index 0bfe7e06..1ce519d3 100644
--- a/eval/eval_openlm_ckpt.py
+++ b/eval/eval_openlm_ckpt.py
@@ -123,6 +123,7 @@ def main():
     state_dict = checkpoint["state_dict"]
     state_dict = {x.replace("module.", ""): y for x, y in state_dict.items()}
+    state_dict = {x.replace("_orig_mod.", ""): y for x, y in state_dict.items()}

     open_lm.model.load_state_dict(state_dict)
     open_lm.model.eval()

diff --git a/open_lm/utils/llm_foundry_wrapper.py b/open_lm/utils/llm_foundry_wrapper.py
index f4f14e79..62bb6adc 100644
--- a/open_lm/utils/llm_foundry_wrapper.py
+++ b/open_lm/utils/llm_foundry_wrapper.py
@@ -3,7 +3,7 @@
 """Implements a Hugging Causal LM wrapped inside a :class:`.ComposerModel`."""

-from typing import Mapping, Union
+from typing import Mapping, Union, List
 from llmfoundry.eval.metrics.nlp import (
     InContextLearningLMAccuracy,
     InContextLearningLMExpectedCalibrationError,
@@ -16,6 +16,8 @@
     LanguagePerplexity,
 )
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from transformers import StoppingCriteria, StoppingCriteriaList
+import torch

 from composer.models.huggingface import HuggingFaceModel

@@ -38,6 +40,13 @@
     InContextLearningMCExpectedCalibrationError(),
 ]

+class CustomStopTokensCriteria(StoppingCriteria):
+    def __init__(self, stop_tokens: List[str]) -> None:
+        self.stop_tokens = stop_tokens
+
+    def __call__(self, generated_tokens: torch.Tensor, *args, **kwargs) -> bool:
+        return any(token in self.stop_tokens for token in generated_tokens.flatten())
+

 class SimpleComposerOpenLMCausalLM(HuggingFaceModel):
     def __init__(self, model, tokenizer):
@@ -50,4 +59,9 @@ def __init__(self, model, tokenizer):
         )

     def generate(self, input_ids=None, inputs_embeds=None, **kwargs):
-        return super().generate(input_ids=input_ids, **kwargs)
+        stop_token = self.tokenizer.eos_token_id
+        stop_criteria = CustomStopTokensCriteria([stop_token])
+        stop_criteria_list = StoppingCriteriaList([stop_criteria])
+        if "stopping_criteria" in kwargs:
+            stop_criteria_list += kwargs.pop("stopping_criteria")
+        return super().generate(input_ids=input_ids, stopping_criteria=stop_criteria_list, **kwargs)

From def4675aa64c6e059cf4a154b2bd3141788764ef Mon Sep 17 00:00:00 2001
From: Jean Mercat
Date: Wed, 24 Jul 2024 16:12:26 -0700
Subject: [PATCH 2/7] added open_lm_1b_swiglutorch json

---
 open_lm/model_configs/open_lm_1b_swiglutorch.json | 14 ++++++++++++++
 1 file changed, 14 insertions(+)
 create mode 100644 open_lm/model_configs/open_lm_1b_swiglutorch.json

diff --git a/open_lm/model_configs/open_lm_1b_swiglutorch.json b/open_lm/model_configs/open_lm_1b_swiglutorch.json
new file mode 100644
index 00000000..ec98b795
--- /dev/null
+++ b/open_lm/model_configs/open_lm_1b_swiglutorch.json
@@ -0,0 +1,14 @@
+{
+    "hidden_dim": 2048,
+    "n_layers": 24,
+    "n_heads": 16,
+    "seq_len": 2048,
+    "vocab_size": 50432,
+    "post_embed_norm": false,
+    "weight_tying": false,
+    "model_norm": "gain_only_lp_layer_norm",
+    "norm_type": "gain_only_lp_layer_norm",
+    "ffn_type": "swiglu_torch",
+    "qk_norm": true,
+    "positional_embedding_type": "rotary"
+}
\ No newline at end of file

From 0180973fa6b66a9cfa87e8d3f54cf8f99824c437 Mon Sep 17 00:00:00 2001
From: Jean Mercat
Date: Tue, 30 Jul 2024 19:34:23 -0700
Subject: [PATCH 3/7] hacky removal of appended 0s in OpenLMforCausalLM

---
 eval/eval_openlm_ckpt.py               |  3 ++-
 open_lm/utils/llm_foundry_wrapper.py   | 18 ++----------------
 open_lm/utils/transformers/hf_model.py | 22 ++++++++++++++++++++++
 3 files changed, 26 insertions(+), 17 deletions(-)

diff --git a/eval/eval_openlm_ckpt.py b/eval/eval_openlm_ckpt.py
index 1ce519d3..478cd28b 100644
--- a/eval/eval_openlm_ckpt.py
+++ b/eval/eval_openlm_ckpt.py
@@ -46,8 +46,9 @@ def evaluate(model, tokenizer, cfg):
     composer_model = SimpleComposerOpenLMCausalLM(model, tokenizer)

+    cfg_icl_tasks = [dict(i) for i in cfg.icl_tasks]
     evaluators, logger_keys = build_icl_evaluators(
-        cfg.icl_tasks, tokenizer, cfg.max_seq_len, cfg.device_eval_batch_size
+        cfg_icl_tasks, tokenizer, cfg.max_seq_len, cfg.device_eval_batch_size
     )

     in_memory_logger = InMemoryLogger()  # track metrics in the in_memory_logger

diff --git a/open_lm/utils/llm_foundry_wrapper.py b/open_lm/utils/llm_foundry_wrapper.py
index 62bb6adc..a78cb350 100644
--- a/open_lm/utils/llm_foundry_wrapper.py
+++ b/open_lm/utils/llm_foundry_wrapper.py
@@ -3,7 +3,7 @@
 """Implements a Hugging Causal LM wrapped inside a :class:`.ComposerModel`."""

-from typing import Mapping, Union, List
+from typing import Union
 from llmfoundry.eval.metrics.nlp import (
     InContextLearningLMAccuracy,
     InContextLearningLMExpectedCalibrationError,
@@ -16,8 +16,6 @@
     LanguagePerplexity,
 )
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
-from transformers import StoppingCriteria, StoppingCriteriaList
-import torch

 from composer.models.huggingface import HuggingFaceModel

@@ -40,13 +38,6 @@
     InContextLearningMCExpectedCalibrationError(),
 ]

-class CustomStopTokensCriteria(StoppingCriteria):
-    def __init__(self, stop_tokens: List[str]) -> None:
-        self.stop_tokens = stop_tokens
-
-    def __call__(self, generated_tokens: torch.Tensor, *args, **kwargs) -> bool:
-        return any(token in self.stop_tokens for token in generated_tokens.flatten())
-

 class SimpleComposerOpenLMCausalLM(HuggingFaceModel):
     def __init__(self, model, tokenizer):
@@ -59,9 +50,4 @@ def __init__(self, model, tokenizer):
         )

     def generate(self, input_ids=None, inputs_embeds=None, **kwargs):
-        stop_token = self.tokenizer.eos_token_id
-        stop_criteria = CustomStopTokensCriteria([stop_token])
-        stop_criteria_list = StoppingCriteriaList([stop_criteria])
-        if "stopping_criteria" in kwargs:
-            stop_criteria_list += kwargs.pop("stopping_criteria")
-        return super().generate(input_ids=input_ids, stopping_criteria=stop_criteria_list, **kwargs)
+        return super().generate(input_ids=input_ids, **kwargs)

diff --git a/open_lm/utils/transformers/hf_model.py b/open_lm/utils/transformers/hf_model.py
index 83353a19..75c0a585 100644
--- a/open_lm/utils/transformers/hf_model.py
+++ b/open_lm/utils/transformers/hf_model.py
@@ -98,6 +98,20 @@ def forward(
         ```"""
         assert position_ids is None, "Position IDs are not supported"
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+
+        # If input_ids are all 0 after a certain point, we can skip that
+        num_zeros = 0
+        if input_ids is not None:
+            # Find indices of the last non-zero element in each row
+            last_nonzero_indices = torch.max(torch.nonzero(input_ids, as_tuple=True)[1])
+            if last_nonzero_indices < input_ids.shape[1] - 1:
+                num_zeros = input_ids.shape[1] - last_nonzero_indices - 1
+                input_ids = input_ids[:, : last_nonzero_indices + 1]
+                if attention_mask is not None:
+                    attention_mask = attention_mask[:, : last_nonzero_indices + 1]
+                if labels is not None:
+                    labels = labels[:, : last_nonzero_indices + 1]
+
         logits, _, past_key_values = self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
@@ -105,6 +119,7 @@ def forward(
             use_cache=use_cache,
             attention_mask=attention_mask,
         )
+
         loss = None
         if labels is not None:
             shift_logits = logits[..., :-1, :].contiguous()
@@ -115,6 +130,13 @@ def forward(
             loss = loss_fct(shift_logits, shift_labels)

         output = CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values, loss=loss)
+
+        # Add back the 0 that we removed
+        if num_zeros:
+            fake_logits = torch.zeros(logits.shape[0], num_zeros, logits.shape[2], device=logits.device)
+            fake_logits[:, :, 0] = 1
+            output.logits = torch.cat([output.logits, fake_logits], dim=1)
+
         return output

     def prepare_inputs_for_generation(

From 8c21c1c6cb0124964faa6f7450ec5fc7f6148c10 Mon Sep 17 00:00:00 2001
From: Jean Mercat
Date: Tue, 30 Jul 2024 20:18:58 -0700
Subject: [PATCH 4/7] remove zeros in llm_foundry_wrapper to accelerate eval

---
 eval/in_memory_hf_eval.yaml             |  2 +-
 open_lm/utils/llm_foundry_wrapper.py    | 99 +++++++++++++++++++++++++-
 open_lm/utils/transformers/hf_model.py  | 19 -----
 3 files changed, 99 insertions(+), 21 deletions(-)

diff --git a/eval/in_memory_hf_eval.yaml b/eval/in_memory_hf_eval.yaml
index 1b6e6f78..11819ad8 100644
--- a/eval/in_memory_hf_eval.yaml
+++ b/eval/in_memory_hf_eval.yaml
@@ -31,7 +31,7 @@ fsdp_config:

 icl_tasks:
 - label: mmlu
-  dataset_uri: local_data/mmlu.jsonl # ADD YOUR OWN DATASET URI
+  dataset_uri: eval/local_data/mmlu.jsonl # ADD YOUR OWN DATASET URI
   num_fewshot: [0]
   icl_task_type: multiple_choice
   continuation_delimiter: 'Answer: ' # this separates questions from answers

diff --git a/open_lm/utils/llm_foundry_wrapper.py b/open_lm/utils/llm_foundry_wrapper.py
index a78cb350..d4094ce3 100644
--- a/open_lm/utils/llm_foundry_wrapper.py
+++ b/open_lm/utils/llm_foundry_wrapper.py
@@ -3,7 +3,7 @@
 """Implements a Hugging Causal LM wrapped inside a :class:`.ComposerModel`."""

-from typing import Union
+from typing import Union, Optional, Any
 from llmfoundry.eval.metrics.nlp import (
     InContextLearningLMAccuracy,
     InContextLearningLMExpectedCalibrationError,
@@ -16,6 +16,8 @@
     LanguagePerplexity,
 )
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+import torch
+from torch import dist

 from composer.models.huggingface import HuggingFaceModel
@@ -51,3 +53,98 @@ def __init__(self, model, tokenizer):

     def generate(self, input_ids=None, inputs_embeds=None, **kwargs):
         return super().generate(input_ids=input_ids, **kwargs)
+
+    def eval_forward(self, batch, outputs: Optional[Any] = None):
+
+        # If input_ids are all 0 after a certain point, we can skip that
+        num_zeros = 0
+        if batch['input_ids'] is not None:
+            # Find indices of the last non-zero element in each row
+            last_nonzero_indices = torch.max(torch.nonzero(batch['input_ids'], as_tuple=True)[1])
+            if last_nonzero_indices < batch['input_ids'].shape[1] - 1:
+                num_zeros = batch['input_ids'].shape[1] - last_nonzero_indices - 1
+                batch['input_ids'] = batch['input_ids'][:, : last_nonzero_indices + 1]
+                if batch['attention_mask'] is not None:
+                    batch['attention_mask'] = batch['attention_mask'][:, : last_nonzero_indices + 1]
+                if batch['labels'] is not None:
+                    batch['labels'] = batch['labels'][:, : last_nonzero_indices + 1]
+
+        # If the batch mode is generate, we will generate a requested number of tokens using the underlying
+        # model's generate function. Extra generation kwargs can be passed in via the batch. Strings will
+        # be returned from eval_forward
+        if batch.get('mode', None) == 'generate':
+            if self.tokenizer is None:
+                raise ValueError(
+                    'Generation eval cannot be used without providing a tokenizer to the model constructor.',
+                )
+
+            self.labels = batch.pop('labels')
+            generation = self.generate(
+                batch['input_ids'],
+                attention_mask=batch['attention_mask'],
+                synced_gpus=dist.get_world_size() > 1,
+                **batch.get('generation_kwargs', {}),
+            )
+
+            # don't remove prefix space to sentencepiece models
+            if len(
+                self.tokenizer(' a', add_special_tokens=False)['input_ids'],  # pyright: ignore[reportGeneralTypeIssues]
+            ) == 1:
+                return self.tokenizer.batch_decode(
+                    generation[:, batch['input_ids'].shape[1]:],
+                    skip_special_tokens=True,
+                )
+            else:
+                return [
+                    ' ' + generation for generation in
+                    self.tokenizer.batch_decode(generation[:, batch['input_ids'].shape[1]:], skip_special_tokens=True)
+                ]
+
+        if self.use_logits or batch.get('mode', None) == 'icl_task':
+            # pop labels first to avoid computing loss
+            self.labels = batch.pop('labels')
+
+            # HF encoder decoder models like T5 expect either decoder_input_ids or labels,
+            # so we add decoder_input_ids to the batch if it is missing
+            if self.config.is_encoder_decoder and 'decoder_input_ids' not in batch:
+                if hasattr(self.model, 'prepare_decoder_input_ids_from_labels'):
+                    batch['decoder_input_ids'] = self.model.prepare_decoder_input_ids_from_labels(labels=self.labels)
+                else:
+                    raise RuntimeError(
+                        'Encoder decoder models require that either decoder_input_ids is present in the batch'
+                        ' or that the model has a prepare_decoder_input_ids_from_labels method.',
+                    )
+
+            if self.shift_labels or batch.get('mode', None) == 'icl_task':
+                assert self.labels is not None
+                # HF CausalLM models internally shift labels before computing loss, so we do the same here
+                self.labels[:, :-1] = self.labels[:, 1:].clone()
+                self.labels[:, -1] = -100
+
+            output = outputs if outputs else self.forward(batch)
+
+            if self.config.use_return_dict:
+                output = output['logits']
+            else:
+                # if loss was computed (cached outputs from forward), loss is at index 0 and logits are at index 1
+                # if loss was not computed (no cached outputs during eval), loss is not present and logits are at index 0
+                output = output[1] if len(output[0].shape) == 0 else output[0]
+
+            # if we are in the single class case, then remove the classes dimension
+            if output.ndim == 2 and output.shape[1] == 1:
+                output = output.squeeze(dim=1)
+        else:
+            output = outputs if outputs else self.forward(batch)
+
+        # Add back the 0 that we removed
+        if num_zeros:
+            if hasattr(output, 'logits') or isinstance(output, dict):
+                fake_logits = torch.zeros(output['logits'].shape[0], num_zeros, output['logits'].shape[2], device=output['logits'].device)
+                fake_logits[:, :, 0] = 1
+                output['logits'] = torch.cat([output['logits'], fake_logits], dim=1)
+            else:
+                fake_logits = torch.zeros(output.shape[0], num_zeros, output.shape[2], device=output.device)
+                fake_logits[:, :, 0] = 1
+                output = torch.cat([output, fake_logits], dim=1)
+
+        return output

diff --git a/open_lm/utils/transformers/hf_model.py b/open_lm/utils/transformers/hf_model.py
index 75c0a585..a2beb9ad 100644
--- a/open_lm/utils/transformers/hf_model.py
+++ b/open_lm/utils/transformers/hf_model.py
@@ -99,19 +99,6 @@ def forward(
         assert position_ids is None, "Position IDs are not supported"
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-
-        # If input_ids are all 0 after a certain point, we can skip that
-        num_zeros = 0
-        if input_ids is not None:
-            # Find indices of the last non-zero element in each row
-            last_nonzero_indices = torch.max(torch.nonzero(input_ids, as_tuple=True)[1])
-            if last_nonzero_indices < input_ids.shape[1] - 1:
-                num_zeros = input_ids.shape[1] - last_nonzero_indices - 1
-                input_ids = input_ids[:, : last_nonzero_indices + 1]
-                if attention_mask is not None:
-                    attention_mask = attention_mask[:, : last_nonzero_indices + 1]
-                if labels is not None:
-                    labels = labels[:, : last_nonzero_indices + 1]
-
         logits, _, past_key_values = self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
@@ -131,12 +118,6 @@ def forward(
         output = CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values, loss=loss)
-
-        # Add back the 0 that we removed
-        if num_zeros:
-            fake_logits = torch.zeros(logits.shape[0], num_zeros, logits.shape[2], device=logits.device)
-            fake_logits[:, :, 0] = 1
-            output.logits = torch.cat([output.logits, fake_logits], dim=1)
-
         return output

     def prepare_inputs_for_generation(

From a3896c5b91b802f7ec0d7d5b109b1715e124e187 Mon Sep 17 00:00:00 2001
From: Jean Mercat
Date: Tue, 30 Jul 2024 20:22:34 -0700
Subject: [PATCH 5/7] use super instead of copy paste

---
 open_lm/utils/llm_foundry_wrapper.py | 68 +---------------------------
 1 file changed, 1 insertion(+), 67 deletions(-)

diff --git a/open_lm/utils/llm_foundry_wrapper.py b/open_lm/utils/llm_foundry_wrapper.py
index d4094ce3..fff6b200 100644
--- a/open_lm/utils/llm_foundry_wrapper.py
+++ b/open_lm/utils/llm_foundry_wrapper.py
@@ -55,7 +55,6 @@ def generate(self, input_ids=None, inputs_embeds=None, **kwargs):
         return super().generate(input_ids=input_ids, **kwargs)

     def eval_forward(self, batch, outputs: Optional[Any] = None):
-
         # If input_ids are all 0 after a certain point, we can skip that
         num_zeros = 0
         if batch['input_ids'] is not None:
@@ -69,72 +68,7 @@ def eval_forward(self, batch, outputs: Optional[Any] = None):
                 if batch['labels'] is not None:
                     batch['labels'] = batch['labels'][:, : last_nonzero_indices + 1]

-        # If the batch mode is generate, we will generate a requested number of tokens using the underlying
-        # model's generate function. Extra generation kwargs can be passed in via the batch. Strings will
-        # be returned from eval_forward
-        if batch.get('mode', None) == 'generate':
-            if self.tokenizer is None:
-                raise ValueError(
-                    'Generation eval cannot be used without providing a tokenizer to the model constructor.',
-                )
-
-            self.labels = batch.pop('labels')
-            generation = self.generate(
-                batch['input_ids'],
-                attention_mask=batch['attention_mask'],
-                synced_gpus=dist.get_world_size() > 1,
-                **batch.get('generation_kwargs', {}),
-            )
-
-            # don't remove prefix space to sentencepiece models
-            if len(
-                self.tokenizer(' a', add_special_tokens=False)['input_ids'],  # pyright: ignore[reportGeneralTypeIssues]
-            ) == 1:
-                return self.tokenizer.batch_decode(
-                    generation[:, batch['input_ids'].shape[1]:],
-                    skip_special_tokens=True,
-                )
-            else:
-                return [
-                    ' ' + generation for generation in
-                    self.tokenizer.batch_decode(generation[:, batch['input_ids'].shape[1]:], skip_special_tokens=True)
-                ]
-
-        if self.use_logits or batch.get('mode', None) == 'icl_task':
-            # pop labels first to avoid computing loss
-            self.labels = batch.pop('labels')
-
-            # HF encoder decoder models like T5 expect either decoder_input_ids or labels,
-            # so we add decoder_input_ids to the batch if it is missing
-            if self.config.is_encoder_decoder and 'decoder_input_ids' not in batch:
-                if hasattr(self.model, 'prepare_decoder_input_ids_from_labels'):
-                    batch['decoder_input_ids'] = self.model.prepare_decoder_input_ids_from_labels(labels=self.labels)
-                else:
-                    raise RuntimeError(
-                        'Encoder decoder models require that either decoder_input_ids is present in the batch'
-                        ' or that the model has a prepare_decoder_input_ids_from_labels method.',
-                    )
-
-            if self.shift_labels or batch.get('mode', None) == 'icl_task':
-                assert self.labels is not None
-                # HF CausalLM models internally shift labels before computing loss, so we do the same here
-                self.labels[:, :-1] = self.labels[:, 1:].clone()
-                self.labels[:, -1] = -100
-
-            output = outputs if outputs else self.forward(batch)
-
-            if self.config.use_return_dict:
-                output = output['logits']
-            else:
-                # if loss was computed (cached outputs from forward), loss is at index 0 and logits are at index 1
-                # if loss was not computed (no cached outputs during eval), loss is not present and logits are at index 0
-                output = output[1] if len(output[0].shape) == 0 else output[0]
-
-            # if we are in the single class case, then remove the classes dimension
-            if output.ndim == 2 and output.shape[1] == 1:
-                output = output.squeeze(dim=1)
-        else:
-            output = outputs if outputs else self.forward(batch)
+        output = super().eval_forward(batch, outputs)

         # Add back the 0 that we removed
         if num_zeros:

From 0164bacb0ed1a1cab48725be731653b27d329ae7 Mon Sep 17 00:00:00 2001
From: Jean Mercat
Date: Wed, 31 Jul 2024 10:56:47 -0700
Subject: [PATCH 6/7] linted

---
 open_lm/utils/llm_foundry_wrapper.py | 30 +++++++++++++++-------------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/open_lm/utils/llm_foundry_wrapper.py b/open_lm/utils/llm_foundry_wrapper.py
index fff6b200..e21ed94b 100644
--- a/open_lm/utils/llm_foundry_wrapper.py
+++ b/open_lm/utils/llm_foundry_wrapper.py
@@ -53,29 +53,31 @@ def __init__(self, model, tokenizer):

     def generate(self, input_ids=None, inputs_embeds=None, **kwargs):
         return super().generate(input_ids=input_ids, **kwargs)
-
+
     def eval_forward(self, batch, outputs: Optional[Any] = None):
         # If input_ids are all 0 after a certain point, we can skip that
         num_zeros = 0
         if batch["input_ids"] is not None:
             # Find indices of the last non-zero element in each row
-            last_nonzero_indices = torch.max(torch.nonzero(batch['input_ids'], as_tuple=True)[1])
-            if last_nonzero_indices < batch['input_ids'].shape[1] - 1:
-                num_zeros = batch['input_ids'].shape[1] - last_nonzero_indices - 1
-                batch['input_ids'] = batch['input_ids'][:, : last_nonzero_indices + 1]
-                if batch['attention_mask'] is not None:
-                    batch['attention_mask'] = batch['attention_mask'][:, : last_nonzero_indices + 1]
-                if batch['labels'] is not None:
-                    batch['labels'] = batch['labels'][:, : last_nonzero_indices + 1]
+            last_nonzero_indices = torch.max(torch.nonzero(batch["input_ids"], as_tuple=True)[1])
+            if last_nonzero_indices < batch["input_ids"].shape[1] - 1:
+                num_zeros = batch["input_ids"].shape[1] - last_nonzero_indices - 1
+                batch["input_ids"] = batch["input_ids"][:, : last_nonzero_indices + 1]
+                if batch["attention_mask"] is not None:
+                    batch["attention_mask"] = batch["attention_mask"][:, : last_nonzero_indices + 1]
+                if batch["labels"] is not None:
+                    batch["labels"] = batch["labels"][:, : last_nonzero_indices + 1]

         output = super().eval_forward(batch, outputs)

         # Add back the 0 that we removed
-        if num_zeros:
-            if hasattr(output, 'logits') or isinstance(output, dict):
-                fake_logits = torch.zeros(output['logits'].shape[0], num_zeros, output['logits'].shape[2], device=output['logits'].device)
+        if num_zeros:
+            if hasattr(output, "logits") or isinstance(output, dict):
+                fake_logits = torch.zeros(
+                    output["logits"].shape[0], num_zeros, output["logits"].shape[2], device=output["logits"].device
+                )
                 fake_logits[:, :, 0] = 1
-                output['logits'] = torch.cat([output['logits'], fake_logits], dim=1)
+                output["logits"] = torch.cat([output["logits"], fake_logits], dim=1)
             else:
                 fake_logits = torch.zeros(output.shape[0], num_zeros, output.shape[2], device=output.device)
                 fake_logits[:, :, 0] = 1

From e094417e020bf1ca83a574881ecbf0a54ec689cc Mon Sep 17 00:00:00 2001
From: jmercat
Date: Sat, 3 Aug 2024 14:53:15 -0700
Subject: [PATCH 7/7] Update in_memory_hf_eval.yaml

---
 eval/in_memory_hf_eval.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/eval/in_memory_hf_eval.yaml b/eval/in_memory_hf_eval.yaml
index 11819ad8..1b6e6f78 100644
--- a/eval/in_memory_hf_eval.yaml
+++ b/eval/in_memory_hf_eval.yaml
@@ -31,7 +31,7 @@ fsdp_config:

 icl_tasks:
 - label: mmlu
-  dataset_uri: eval/local_data/mmlu.jsonl # ADD YOUR OWN DATASET URI
+  dataset_uri: local_data/mmlu.jsonl # ADD YOUR OWN DATASET URI
   num_fewshot: [0]
   icl_task_type: multiple_choice
   continuation_delimiter: 'Answer: ' # this separates questions from answers
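
The zero-trimming pattern that patches 4 through 6 add to eval_forward can be illustrated in isolation. The sketch below is not part of the patches: trim_and_restore and forward_fn are illustrative names, and it only assumes, as the patches do, that padding tokens have id 0 and sit at the end of each row. It drops trailing all-zero columns before the forward pass and then appends placeholder logits so the output keeps the original sequence length expected by the ICL metric code.

import torch


def trim_and_restore(input_ids: torch.Tensor, forward_fn):
    """Run forward_fn on input_ids with trailing zero-padding columns removed.

    forward_fn maps (batch, seq) token ids to (batch, seq, vocab) logits.
    """
    seq_len = input_ids.shape[1]
    nonzero_cols = torch.nonzero(input_ids, as_tuple=True)[1]
    # Last column holding a non-zero token anywhere in the batch: a single
    # scalar cut point, mirroring the torch.max(...) call in the patches.
    last = int(nonzero_cols.max().item()) if nonzero_cols.numel() > 0 else seq_len - 1
    num_zeros = seq_len - last - 1

    logits = forward_fn(input_ids[:, : last + 1])

    if num_zeros > 0:
        # Placeholder logits for the trimmed positions: all probability mass on
        # token id 0, matching the fake_logits construction in the patches.
        pad = torch.zeros(logits.shape[0], num_zeros, logits.shape[2], device=logits.device)
        pad[:, :, 0] = 1
        logits = torch.cat([logits, pad], dim=1)
    return logits

Because the cut point is the batch-wide maximum rather than per row, a single long example keeps the padding for every row in its batch; the speedup comes from ICL batches that are padded well beyond their longest example.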