Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
soldni committed May 24, 2024
1 parent 43afaa5 commit 0cc1e47
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
38 changes: 25 additions & 13 deletions src/smashed/mappers/promptsource.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,21 +141,33 @@ def approx_input_fields(self) -> Tuple[Set[str], ...]:
are used, nor cases where members of a variable are accessed.
"""

# we compute variables first because some might
all_variables = set(
field
for field in self.get_vars_from_txt(self.template)
if field not in self.extra_vars
)
out = tuple(
{
output = tuple(
set(
field
for field in all_variables
if (field in fragment and field not in self.extra_vars)
}
for fragment in self.template.split("|||")
for field in self.get_vars_from_txt(t)
if field not in self.extra_vars
)
for t in self.template.split("|||")
)
return out
print(output)

return output

# # we compute variables first because some might
# all_variables = set(
# field
# for field in self.get_vars_from_txt(self.template)
# if field not in self.extra_vars
# )
# out = tuple(
# {
# field
# for field in all_variables
# if (field in fragment and field not in self.extra_vars)
# }
# for fragment in self.template.split("|||")
# )
# return out

@property
def template_text(self) -> Tuple[str, ...]:
Expand Down
12 changes: 0 additions & 12 deletions tests/test_promptsource_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,18 +106,6 @@ def test_few_shot_truncation(self):
# The fact that the prompt is a bit different from the template
# is totally fine: T5 removes multiple spaces, turns newlines into
# spaces, and decoding strips the trailing spaces.

print('\n' + self.tokenizer.decode(mapped_dataset[0]["input_ids"]))
print('-----------')
print(
f"Q: {FEW_SHOT_DATASET[0]['question'][:14].rstrip()} "
f"A: {FEW_SHOT_DATASET[0]['answer'][:14].rstrip()} "
f"Q: {FEW_SHOT_DATASET[1]['question'][:14].rstrip()} "
f"A: {FEW_SHOT_DATASET[1]['answer'][:14].rstrip()} "
f"Q: {FEW_SHOT_DATASET[2]['question'][:14].rstrip()} "
"A:</s>"
)
return
self.assertEqual(
self.tokenizer.decode(mapped_dataset[0]["input_ids"]),
(
Expand Down

0 comments on commit 0cc1e47

Please sign in to comment.