diff --git a/src/parallel_corpus/graph.py b/src/parallel_corpus/graph.py index fdbeb1e..b29ebeb 100644 --- a/src/parallel_corpus/graph.py +++ b/src/parallel_corpus/graph.py @@ -4,7 +4,7 @@ import logging import re from dataclasses import dataclass -from typing import Dict, Iterable, List, Optional, TypeVar +from typing import Dict, Iterable, List, Optional, TypedDict, TypeVar import parallel_corpus.shared.ranges import parallel_corpus.shared.str_map @@ -87,6 +87,12 @@ def init(s: str, *, manual: bool = False) -> Graph: return init_from(token.tokenize(s), manual=manual) +def init_with_source_and_target(source: str, target: str, *, manual: bool = False) -> Graph: + return init_from_source_and_target( + source=token.tokenize(source), target=token.tokenize(target), manual=manual + ) + + def init_from(tokens: List[str], *, manual: bool = False) -> Graph: return align( Graph( @@ -99,10 +105,59 @@ def init_from(tokens: List[str], *, manual: bool = False) -> Graph: ) +def init_from_source_and_target( + source: List[str], target: List[str], *, manual: bool = False +) -> Graph: + source_tokens = token.identify(source, "s") + target_tokens = token.identify(target, "t") + return align( + Graph( + source=source_tokens, + target=target_tokens, + edges=edge_record( + itertools.chain( + (edge([s.id], [], manual=manual) for s in source_tokens), + (edge([t.id], [], manual=manual) for t in target_tokens), + ) + ), + ) + ) + + +class TextLabels(TypedDict): + text: str + labels: List[str] + + +def from_unaligned(st: SourceTarget[List[TextLabels]]) -> Graph: + """Initialize a graph from unaligned tokens""" + edges: Dict[str, Edge] = {} + + def proto_token_to_token(tok: TextLabels, i: int, prefix: str) -> Token: + id_ = f"{prefix}{i}" + e = edge([id_], tok["labels"], manual=False) + edges[id_] = e + return Token(tok["text"], id_) + + def proto_tokens_to_tokens(toks: List[TextLabels], side: Side) -> List[Token]: + return [ + proto_token_to_token(tok, i, "s" if side == Side.source else "t") + for i, tok in enumerate(toks) + ] + + g = map_sides(st, proto_tokens_to_tokens) + + return align(Graph(source=g.source, target=g.target, edges=edges)) + + def modify(g: Graph, from_: int, to: int, text: str, side: Side = Side.target) -> Graph: return align(unaligned_modify(g, from_, to, text, side)) +def set_source(g: Graph, text: str) -> Graph: + return align(unaligned_set_side(g, Side.source, text)) + + def set_target(g: Graph, text: str) -> Graph: return align(unaligned_set_side(g, Side.target, text)) diff --git a/src/parallel_corpus/py.typed b/src/parallel_corpus/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_graph.py b/tests/test_graph.py index 8726e4b..5103be0 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -2,7 +2,7 @@ import pytest from parallel_corpus import graph, token -from parallel_corpus.source_target import Side +from parallel_corpus.source_target import Side, SourceTarget def test_graph_init() -> None: @@ -16,6 +16,44 @@ def test_graph_init() -> None: assert g.edges == edges +def test_init_from_source_and_target_1() -> None: + g = graph.init_with_source_and_target(source="apa", target="apa") + assert g == graph.init("apa") + + +def test_init_from_source_and_target_2() -> None: + g = graph.init_with_source_and_target(source="apa bepa", target="apa") + expected_source = token.identify(token.tokenize("apa bepa"), "s") + expected_target = token.identify(token.tokenize("apa"), "t") + g_expected = graph.Graph( + source=expected_source, + target=expected_target, + edges=graph.edge_record([graph.edge(["s0", "t0"], []), graph.edge(["s1"], [])]), + ) + assert g == g_expected + + +def test_init_from_source_and_target_3() -> None: + g = graph.init_with_source_and_target(source="apa", target="bepa apa") + expected_source = token.identify(token.tokenize("apa"), "s") + expected_target = token.identify(token.tokenize("bepa apa"), "t") + g_expected = graph.Graph( + source=expected_source, + target=expected_target, + edges=graph.edge_record([graph.edge(["s0", "t1"], []), graph.edge(["t0"], [])]), + ) + assert g == g_expected + + +def test_from_unaligned() -> None: + g = graph.from_unaligned( + SourceTarget( + source=[{"text": "apa ", "labels": []}], target=[{"text": "apa ", "labels": []}] + ) + ) + assert g == graph.init("apa") + + def test_graph_case1() -> None: first = "Jonathan saknades , emedan han , med sin vapendragare , redan på annat håll sökt och anträffat fienden ." # noqa: E501 second = "Jonat han saknades , emedan han , med sin vapendragare , redan på annat håll sökt och anträffat fienden ." # noqa: E501 @@ -38,6 +76,17 @@ def test_graph_case2() -> None: assert "e-s0-s1-t20" in gm.edges +def test_set_source() -> None: + source = "Jonat han saknades" + target = "Jonathan saknaes" + + g = graph.init(target) + + gm = graph.set_source(g, source) + print(f"{gm=}") + assert "e-s2-s3-t0" in gm.edges + + def test_unaligned_set_side() -> None: g0 = graph.init("a bc d") print(">>> test_unaligned_set_side")