Skip to content

Commit

Permalink
Adds template for translation tasks (#391)
Browse files Browse the repository at this point in the history
* implement translation prompt

* add small comment about translation prompt

* change formatting to reformat language-dependent parts

---------

Co-authored-by: Clémentine Fourrier <[email protected]>
  • Loading branch information
2 people authored and Hynek Kydlicek committed Nov 26, 2024
1 parent 6af5280 commit f235abc
Show file tree
Hide file tree
Showing 3 changed files with 286 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/lighteval/tasks/templates/continuation.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def get_continuation_prompt_function(
language: Language,
adapter: Callable[[dict], ContinuationInput | None] | ContinuationDictAdapter,
formulation: Formulation = MCFFormulation(),
fix_formatting: bool = True,
):
"""
Create a templated prompt function for a Continuation task.
Expand Down Expand Up @@ -120,6 +121,7 @@ def get_continuation_prompt_function(
adapter (Callable[[dict], ContinuationInput] | ContinuationDictAdapter): Either a function that takes a dataset row and returns a ContinuationInput, or a dictionary with keys corresponding to the field names in the dataset row.
Note: Both ContinuationDictAdapter and ContinuationInput are TypeDicts, this means that the caller provides dictionary and doesn't initialize any class!
formulation (Formulation, optional): The formulation (MCF/Hybrid/CF) to use for the task. Defaults to MCFFormulation().
fix_formatting (bool, optional): Whether to fix the formatting of the text by capitalizing and fixing punctuation based on language. If False, the text will be used as-is. Defaults to True.
Returns:
Callable: A function that generates Continuation prompt based on the given parameters.
"""
Expand All @@ -134,12 +136,17 @@ def prepare_prompt(line: dict):
instruction_val = cont_input.get("instruction")
instruction = f"{instruction_val}\n" if instruction_val else ""

context = f"{capitalize(fix_ending_punct(cont_input['context'], translation_literals))}"
continuations = cont_input["continuations"]
context = (
f"{capitalize(fix_ending_punct(cont_input['context'], translation_literals))}"
if fix_formatting
else cont_input["context"]
)

continuations = [
fix_capitalization(context, fix_ending_punct(continuation, translation_literals), translation_literals)
for continuation in continuations
if fix_formatting
else continuation
for continuation in cont_input["continuations"]
]

return cont_input, instruction, context, continuations
Expand Down
156 changes: 156 additions & 0 deletions src/lighteval/tasks/templates/translation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Callable

from langcodes import standardize_tag
from typing_extensions import NotRequired, TypedDict

from lighteval.tasks.templates.continuation import get_continuation_prompt_function
from lighteval.tasks.templates.multichoice import create_adapter_from_dict
from lighteval.tasks.templates.utils.formatting_utils import capitalize, fix_ending_punct
from lighteval.tasks.templates.utils.formulation import Formulation, MCFFormulation
from lighteval.tasks.templates.utils.translation_literals import TRANSLATION_LITERALS
from lighteval.utils.language import Language
from lighteval.utils.utils import as_list


# Prompt context template for translation tasks.
# It was chosen to be as language-independent as possible, because it is not clear whether the
# prompt wording should follow the source or the target language.
# It is also the best-performing template according to https://arxiv.org/pdf/2301.07069.
# All placeholders (labels, colon, sentence_space, source_text) are filled in by
# get_translation_prompt_function via str.format.


TRANSLATION_CONTEXT = "{source_label}{colon}{sentence_space}{source_text}{sentence_space}{target_label}{colon}"


# Defined for type hinting only
class TranslationInput(TypedDict):
    """
    Input for the Translation task.

    Args:
        source_text: The source text to be translated
        target_text: The translation(s) of the source text; either a single string or a list of candidate strings
        gold_idx (optional): Index (or list of indices) of the correct translation(s) within the targets.
            When omitted, get_translation_prompt_function marks every target as gold.
        instruction (optional): The instruction of the Translation task (e.g. Translate the following text to Turkish)
    """

    source_text: str
    target_text: str | list[str]
    gold_idx: NotRequired[int | list[int]]
    instruction: NotRequired[str]


class TranslationAdapter(TypedDict):
    """
    Adapter for mapping from the dataset row into the TranslationInput format.

    Args:
        source_text: Column name in the row that contains the source text to be translated
        target_text: Column name in the row that contains the translation(s) of the source text
        gold_idx (optional): Entry describing the gold translation index/indices.
            NOTE(review): annotated as int | list[int] while the other fields are column names (str);
            confirm against create_adapter_from_dict whether a column name is expected here.
        instruction (optional): Column name in the row that contains the instruction of the task (e.g. Translate the following text to Turkish)
    """

    source_text: str
    target_text: str
    gold_idx: NotRequired[int | list[int]]
    instruction: NotRequired[str]


def get_translation_prompt_function(
    source_language: Language,
    target_language: Language,
    adapter: Callable[[dict], TranslationInput | None] | TranslationAdapter,
    formulation: Formulation = MCFFormulation(),
):
    """
    Build a templated prompt function for a Translation task.

    The translation prompt is expressed as a continuation task: the context is
    "<SOURCE TAG>: <source text> <TARGET TAG>:" and the target text(s) are the continuations.

    Example tasks:
        - WMT2016
        - WMT2017

    Format:
        *CF*
        EN: How are you? TR: | Nasılsın?

        *Hybrid*
        EN: How are you? TR:
        A. Nasılsın?
        B. Jak se máš?
        Answer: | Nasılsın?/Jak se máš?

        *MCF*
        EN: How are you? TR:
        A. Nasılsın?
        B. Jak se máš?
        Answer: | A/B

    Args:
        source_language (Language): Language of the source text. Selects the translation literals used to format
            the source text and provides the upper-cased language tag placed before it.
        target_language (Language): Language of the target text(s). Selects the translation literals used to format
            the targets and provides the upper-cased language tag that closes the context.
        adapter (Callable[[dict], TranslationInput] | TranslationAdapter): Either a function that takes a dataset row
            and returns a TranslationInput (or None), or a dictionary with keys corresponding to the field names
            in the dataset row.
            Note: Both TranslationAdapter and TranslationInput are TypedDicts, this means that the caller provides
            a dictionary and doesn't initialize any class!
        formulation (Formulation, optional): The formulation (MCF/Hybrid/CF) to use for the task.
            Defaults to MCFFormulation().

    Returns:
        Callable: A function taking (line, task_name) that returns the result of the underlying continuation
            prompt function, or None when the adapter returns None for the row.
    """
    row_to_input = create_adapter_from_dict(adapter)

    # Source and target texts are formatted below with their own language literals,
    # so the continuation template is asked not to format them again.
    build_continuation_prompt = get_continuation_prompt_function(
        Language.ENGLISH,
        {"context": "context", "continuations": "continuations", "gold_idx": "gold_idx"},
        formulation,
        fix_formatting=False,
    )

    source_literals = TRANSLATION_LITERALS[source_language]
    target_literals = TRANSLATION_LITERALS[target_language]

    # Upper-cased standardized language tags used as labels in the prompt.
    source_tag = standardize_tag(source_language.value).upper()
    target_tag = standardize_tag(target_language.value).upper()

    def _polish(text, literals):
        # Fix the ending punctuation first, then capitalize the result.
        return capitalize(fix_ending_punct(text, literals))

    def translation_prompt(
        line: dict,
        task_name: str,
    ):
        parsed = row_to_input(line)
        if parsed is None:
            return None

        prompt_context = TRANSLATION_CONTEXT.format(
            source_label=source_tag,
            source_text=_polish(parsed["source_text"], source_literals),
            target_label=target_tag,
            colon=":",
            sentence_space=" ",
        )

        # target_text may be a single string or a list of strings.
        candidates = [_polish(candidate, target_literals) for candidate in as_list(parsed["target_text"])]

        # Without an explicit gold_idx every candidate counts as gold.
        every_index = list(range(len(candidates)))

        continuation_row = {
            "instruction": parsed.get("instruction", ""),
            "context": prompt_context,
            "continuations": candidates,
            "gold_idx": parsed.get("gold_idx", every_index),
        }
        return build_continuation_prompt(continuation_row, task_name)

    return translation_prompt
120 changes: 120 additions & 0 deletions tests/tasks/templates/test_translation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


from lighteval.tasks.templates.translation import get_translation_prompt_function
from lighteval.tasks.templates.utils.formulation import CFFormulation, MCFFormulation
from lighteval.utils.language import Language


def test_translation_prompt_cf():
    """
    Checks the prompt built for a single Czech -> French example under the CF formulation.
    """

    def to_translation_input(row):
        # Plain pass-through of the two mandatory fields.
        return {key: row[key] for key in ("source_text", "target_text")}

    build_prompt = get_translation_prompt_function(
        source_language=Language.CZECH,
        target_language=Language.FRENCH,
        adapter=to_translation_input,
        formulation=CFFormulation(),
    )

    sample = {
        "source_text": "Ahoj, jak se máš?",
        "target_text": "Bonjour, comment allez-vous?",
    }
    result = build_prompt(sample, "test_task")

    assert result is not None

    assert result.query == "CS: Ahoj, jak se máš? FR:"
    assert result.unconditioned_query == ""
    assert result.choices == [" Bonjour, comment allez-vous?"]
    assert result.gold_index == [0]


def test_translation_prompt_mcf():
    """
    Tests that translation prompt function works correctly for MCF formulation.

    Two candidate translations are supplied and the adapter marks the first one (index 0) as gold.
    """
    test_input = {
        "source_text": "Ahoj, jak se máš?",
        "target_text": ["Bonjour, comment allez-vous?", "Ciao, come stai?"],
    }

    prompt_fn = get_translation_prompt_function(
        source_language=Language.CZECH,
        target_language=Language.FRENCH,
        adapter=lambda x: {
            "source_text": x["source_text"],
            "target_text": x["target_text"],
            "gold_idx": 0,
        },
        formulation=MCFFormulation(),
    )

    doc = prompt_fn(test_input, "test_task")
    assert doc is not None

    # The query is compared verbatim, line by line; the backslashes suppress the newline
    # right after the opening quotes and right before the closing quotes.
    # NOTE(review): any leading whitespace in front of the option lines is significant here;
    # confirm the literal below against the actual MCF output.
    assert (
        doc.query
        == """\
CS: Ahoj, jak se máš? FR:
A. Bonjour, comment allez-vous?
B. Ciao, come stai?
Answer:\
"""
    )
    assert doc.unconditioned_query == "Answer:"
    # In MCF the choices are the option letters, not the translations themselves.
    assert doc.choices == [" A", " B"]
    assert doc.gold_index == [0]


def test_translation_prompt_cf_formatting():
    """
    Checks the CF-formulation prompt for an English -> Chinese example, including text formatting.
    """

    def to_translation_input(row):
        # Pass the texts through and always point at the first target as gold.
        return {
            "source_text": row["source_text"],
            "target_text": row["target_text"],
            "gold_idx": 0,
        }

    build_prompt = get_translation_prompt_function(
        source_language=Language.ENGLISH,
        target_language=Language.CHINESE,
        adapter=to_translation_input,
        formulation=CFFormulation(),
    )

    sample = {
        "source_text": "How are you?",
        "target_text": ["你好吗?"],
    }
    result = build_prompt(sample, "test_task")

    assert result is not None

    assert result.query == "EN: How are you? ZH:"
    assert result.unconditioned_query == ""
    assert result.choices == [" 你好吗?"]
    assert result.gold_index == [0]

0 comments on commit f235abc

Please sign in to comment.