diff --git a/src/smashed/mappers/promptsource.py b/src/smashed/mappers/promptsource.py index 5240f5e..65a57fe 100644 --- a/src/smashed/mappers/promptsource.py +++ b/src/smashed/mappers/promptsource.py @@ -141,21 +141,33 @@ def approx_input_fields(self) -> Tuple[Set[str], ...]: are used, nor cases where members of a variable are accessed. """ - # we compute variables first because some might - all_variables = set( - field - for field in self.get_vars_from_txt(self.template) - if field not in self.extra_vars - ) - out = tuple( - { + output = tuple( + set( field - for field in all_variables - if (field in fragment and field not in self.extra_vars) - } - for fragment in self.template.split("|||") + for field in self.get_vars_from_txt(t) + if field not in self.extra_vars + ) + for t in self.template.split("|||") ) - return out + print(output) + + return output + + # # we compute variables first because some might + # all_variables = set( + # field + # for field in self.get_vars_from_txt(self.template) + # if field not in self.extra_vars + # ) + # out = tuple( + # { + # field + # for field in all_variables + # if (field in fragment and field not in self.extra_vars) + # } + # for fragment in self.template.split("|||") + # ) + # return out @property def template_text(self) -> Tuple[str, ...]: diff --git a/tests/test_promptsource_recipe.py b/tests/test_promptsource_recipe.py index 6a07527..6f0e62c 100644 --- a/tests/test_promptsource_recipe.py +++ b/tests/test_promptsource_recipe.py @@ -106,18 +106,6 @@ def test_few_shot_truncation(self): # The fact that the prompt is a bit different from the template # is totally fine: T5 removes multiple spaces, turns newlines into # spaces, and decoding strips the trailing spaces. - - print('\n' + self.tokenizer.decode(mapped_dataset[0]["input_ids"])) - print('-----------') - print( - f"Q: {FEW_SHOT_DATASET[0]['question'][:14].rstrip()} " - f"A: {FEW_SHOT_DATASET[0]['answer'][:14].rstrip()} " - f"Q: {FEW_SHOT_DATASET[1]['question'][:14].rstrip()} " - f"A: {FEW_SHOT_DATASET[1]['answer'][:14].rstrip()} " - f"Q: {FEW_SHOT_DATASET[2]['question'][:14].rstrip()} " - "A:" - ) - return self.assertEqual( self.tokenizer.decode(mapped_dataset[0]["input_ids"]), (