diff --git a/examples/qasper.py b/examples/qasper.py index 3bbe844..1cc34ee 100644 --- a/examples/qasper.py +++ b/examples/qasper.py @@ -45,21 +45,23 @@ def transform(self, data: TransformElementType) -> TransformElementType: def main(): dataset = load_dataset("qasper", split="validation") + template = """ + {{title}}{{abs_sep}} + {{abstract}}{{abs_sep}} + {% for i in range(full_text['section_name'] | length) %} + {{full_text['section_name'][i]}}{{title_sep}} + {% for paragraph in full_text['paragraphs'][i] %} + {{paragraph}}{{para_sep}} + {% endfor %} + {{sec_sep}} + {% endfor %} + """ + pipeline = ( # concatenate the full text into a single string; use # title_sep, para_sep, sec_sep, and abs_sep to manage separators sm.JinjaMapper( - jinja=( - "{{title}}{{abs_sep}}" - "{{abstract}}{{abs_sep}}" - "{% for i in range(full_text['section_name'] | length) %}" - "{{full_text['section_name'][i]}}{{title_sep}}" - "{% for paragraph in full_text['paragraphs'][i] %}" - "{{paragraph}}{{para_sep}}" - "{% endfor %}" - "{{sec_sep}}" - "{% endfor %}" - ), + jinja=template, source_field_name="context", extra_variables={ "title_sep": "\n", diff --git a/src/smashed/mappers/promptsource.py b/src/smashed/mappers/promptsource.py index 38cadc9..0a77f04 100644 --- a/src/smashed/mappers/promptsource.py +++ b/src/smashed/mappers/promptsource.py @@ -11,7 +11,6 @@ Sequence, Set, Tuple, - Type, Union, cast, ) @@ -49,13 +48,11 @@ class JinjaEnvironment: _env: Optional["Environment"] = None @classmethod - def env(cls, loader: Optional[Type["BaseLoader"]] = None) -> "Environment": + def env(cls, loader: Optional["BaseLoader"] = None) -> "Environment": if cls._env is not None: return cls._env - cls._env = Environment( - loader=(loader or BaseLoader) # pyright: ignore - ) + cls._env = Environment(loader=loader) return cls._env @classmethod @@ -143,13 +140,16 @@ def approx_input_fields(self) -> Tuple[Set[str], ...]: since we can't parse out cases where for loops or if statements are used, nor cases 
where members of a variable are accessed. """ + + # we compute variables once over the whole template first, then keep for each "|||" fragment those whose name appears in it (presumably because a single fragment might not parse on its own -- TODO confirm) + all_variables = set( + field + for field in self.get_vars_from_txt(self.template) + if field not in self.extra_vars + ) return tuple( - set( - field - for field in self.get_vars_from_txt(t) - if field not in self.extra_vars - ) - for t in self.template.split("|||") + {v for v in all_variables if v in fragment} + for fragment in self.template.split("|||") ) @property