Skip to content

Commit

Permalink
fixes to make examples/qasper work
Browse files Browse the repository at this point in the history
  • Loading branch information
soldni committed Sep 22, 2023
1 parent 5bebfe0 commit e263874
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 22 deletions.
24 changes: 13 additions & 11 deletions examples/qasper.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,23 @@ def transform(self, data: TransformElementType) -> TransformElementType:
def main():
dataset = load_dataset("qasper", split="validation")

template = """
{{title}}{{abs_sep}}
{{abstract}}{{abs_sep}}
{% for i in range(full_text['section_name'] | length) %}
{{full_text['section_name'][i]}}{{title_sep}}
{% for paragraph in full_text['paragraphs'][i] %}
{{paragraph}}{{para_sep}}
{% endfor %}
{{sec_sep}}
{% endfor %}
"""

pipeline = (
# concatenate the full text into a single string; use
# title_sep, para_sep, sec_sep, and abs_sep to manage separators
sm.JinjaMapper(
jinja=(
"{{title}}{{abs_sep}}"
"{{abstract}}{{abs_sep}}"
"{% for i in range(full_text['section_name'] | length) %}"
"{{full_text['section_name'][i]}}{{title_sep}}"
"{% for paragraph in full_text['paragraphs'][i] %}"
"{{paragraph}}{{para_sep}}"
"{% endfor %}"
"{{sec_sep}}"
"{% endfor %}"
),
jinja=template,
source_field_name="context",
extra_variables={
"title_sep": "\n",
Expand Down
22 changes: 11 additions & 11 deletions src/smashed/mappers/promptsource.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
Sequence,
Set,
Tuple,
Type,
Union,
cast,
)
Expand Down Expand Up @@ -49,13 +48,11 @@ class JinjaEnvironment:
_env: Optional["Environment"] = None

@classmethod
def env(cls, loader: Optional[Type["BaseLoader"]] = None) -> "Environment":
def env(cls, loader: Optional["BaseLoader"] = None) -> "Environment":
if cls._env is not None:
return cls._env

cls._env = Environment(
loader=(loader or BaseLoader) # pyright: ignore
)
cls._env = Environment(loader=loader)
return cls._env

@classmethod
Expand Down Expand Up @@ -143,13 +140,16 @@ def approx_input_fields(self) -> Tuple[Set[str], ...]:
since we can't parse out cases where for loops or if statements
are used, nor cases where members of a variable are accessed.
"""

# we compute variables first because some might
all_variables = set(
field
for field in self.get_vars_from_txt(self.template)
if field not in self.extra_vars
)
return tuple(
set(
field
for field in self.get_vars_from_txt(t)
if field not in self.extra_vars
)
for t in self.template.split("|||")
{v for v in all_variables if v in fragment}
for fragment in self.template.split("|||")
)

@property
Expand Down

0 comments on commit e263874

Please sign in to comment.