From 4ec14ff11c9e09afbe60166d95944a0dae179813 Mon Sep 17 00:00:00 2001
From: Kyle Gorman
Date: Mon, 25 Nov 2024 12:46:11 -0500
Subject: [PATCH 01/16] Moves to mapper interface

Borrowing a design element I used in UDTube, I decompose the dataset object
into two pieces:

* a `Mapper` interface which knows how to map between lists of strings and
  tensors (to decode and encode)
* `DataSet`, as before

There was no particular reason for the mapper functions to live inside the
dataset, and this commit simply makes this separation. A subsequent commit
will use this mapper object during prediction.
---
 yoyodyne/data/__init__.py    |   3 +-
 yoyodyne/data/datamodules.py | 116 ++++++++++++++---------
 yoyodyne/data/datasets.py    | 173 +++++------------------------------
 yoyodyne/data/indexes.py     | 104 +++++++++------------
 yoyodyne/data/mappers.py     | 123 +++++++++++++++++++++++++
 5 files changed, 267 insertions(+), 252 deletions(-)
 create mode 100644 yoyodyne/data/mappers.py

diff --git a/yoyodyne/data/__init__.py b/yoyodyne/data/__init__.py
index e089c1d0..3c8304fa 100644
--- a/yoyodyne/data/__init__.py
+++ b/yoyodyne/data/__init__.py
@@ -5,8 +5,9 @@
 from .. import defaults
 from .batches import PaddedBatch, PaddedTensor  # noqa: F401
 from .datamodules import DataModule  # noqa: F401
-from .datasets import Dataset  # noqa: F401
 from .indexes import Index  # noqa: F401
+from .mappers import Mapper  # noqa: F401
+from .parser import TsvParser  # noqa: F401
 
 
 def add_argparse_args(parser: argparse.ArgumentParser) -> None:
diff --git a/yoyodyne/data/datamodules.py b/yoyodyne/data/datamodules.py
index 9a716480..755f0c60 100644
--- a/yoyodyne/data/datamodules.py
+++ b/yoyodyne/data/datamodules.py
@@ -1,6 +1,7 @@
 """Data modules."""
 
-from typing import Iterable, Optional, Set
+import os
+from typing import Iterable, Optional
 
 import lightning
 from torch.utils import data
@@ -10,43 +11,56 @@
 
 
 class DataModule(lightning.LightningDataModule):
-    """Parses, indexes, collates and loads data.
+    """Data module.
+
+    This class is initialized by the LightningCLI interface. It manages all
+    data loading steps.
+
+    Args:
+        model_dir: Path for checkpoints, indexes, and logs.
+        predict
 
-    The batch size tuner is permitted to mutate the `batch_size` argument.
     """
 
+    predict: Optional[str]
+    test: Optional[str]
+    train: Optional[str]
+    val: Optional[str]
     parser: tsv.TsvParser
-    index: indexes.Index
+    separate_features: bool
     batch_size: int
+    index: indexes.Index
     collator: collators.Collator
 
     def __init__(
         self,
         # Paths.
         *,
-        train: Optional[str] = None,
-        val: Optional[str] = None,
-        predict: Optional[str] = None,
-        test: Optional[str] = None,
-        index_path: Optional[str] = None,
-        # TSV parsing arguments.
+        model_dir: str,
+        predict=None,
+        train=None,
+        test=None,
+        val=None,
+        # TSV parsing options.
         source_col: int = defaults.SOURCE_COL,
        features_col: int = defaults.FEATURES_COL,
         target_col: int = defaults.TARGET_COL,
-        # String parsing arguments.
         source_sep: str = defaults.SOURCE_SEP,
         features_sep: str = defaults.FEATURES_SEP,
         target_sep: str = defaults.TARGET_SEP,
+        # Modeling options.
         tie_embeddings: bool = defaults.TIE_EMBEDDINGS,
-        # Collator options.
-        batch_size: int = defaults.BATCH_SIZE,
         separate_features: bool = False,
+        # Other.
+        batch_size: int = defaults.BATCH_SIZE,
         max_source_length: int = defaults.MAX_SOURCE_LENGTH,
         max_target_length: int = defaults.MAX_TARGET_LENGTH,
-        # Indexing.
-        index: Optional[indexes.Index] = None,
     ):
         super().__init__()
+        self.train = train
+        self.val = val
+        self.predict = predict
+        self.test = test
         self.parser = tsv.TsvParser(
             source_col=source_col,
             features_col=features_col,
@@ -56,14 +70,15 @@ def __init__(
             target_sep=target_sep,
             tie_embeddings=tie_embeddings,
         )
-        self.tie_embeddings = tie_embeddings
-        self.train = train
-        self.val = val
-        self.predict = predict
-        self.test = test
-        self.batch_size = batch_size
         self.separate_features = separate_features
-        self.index = index if index is not None else self._make_index()
+        self.batch_size = batch_size
+        # If the training data is specified, it is used to create (or recreate)
+        # the index; if not specified it is read from the model directory.
+        self.index = (
+            self._make_index(model_dir, tie_embeddings)
+            if self.train
+            else indexes.Index.read(model_dir)
+        )
         self.collator = collators.Collator(
             has_features=self.has_features,
             has_target=self.has_target,
@@ -72,11 +87,12 @@ def __init__(
             max_target_length=max_target_length,
         )
 
-    def _make_index(self) -> indexes.Index:
-        # Computes index.
-        source_vocabulary: Set[str] = set()
-        features_vocabulary: Set[str] = set()
-        target_vocabulary: Set[str] = set()
+    def _make_index(
+        self, model_dir: str, tie_embeddings: bool
+    ) -> indexes.Index:
+        source_vocabulary = set()
+        features_vocabulary = set() if self.has_features else None
+        target_vocabulary = set() if self.has_target else None
         if self.has_features:
             if self.has_target:
                 for source, features, target in self.parser.samples(
@@ -96,16 +112,23 @@ def _make_index(
         else:
             for source in self.parser.samples(self.train):
                 source_vocabulary.update(source)
-        return indexes.Index(
-            source_vocabulary=sorted(source_vocabulary),
+        index = indexes.Index(
+            source_vocabulary=source_vocabulary,
             features_vocabulary=(
-                sorted(features_vocabulary) if features_vocabulary else None
+                features_vocabulary if features_vocabulary else None
             ),
-            target_vocabulary=(
-                sorted(target_vocabulary) if target_vocabulary else None
-            ),
-            tie_embeddings=self.tie_embeddings,
+            target_vocabulary=target_vocabulary if target_vocabulary else None,
+            tie_embeddings=tie_embeddings,
         )
+        # Writes it to the model directory.
+        try:
+            os.mkdir(model_dir)
+        except FileExistsError:
+            pass
+        index.write(model_dir)
+        return index
+
+    # Logging.
 
     @staticmethod
     def pprint(vocabulary: Iterable) -> str:
@@ -128,9 +151,7 @@ def log_vocabularies(self) -> None:
             f"{self.pprint(self.index.target_vocabulary)}"
         )
 
-    def write_index(self, model_dir: str, experiment: str) -> None:
-        """Writes the index."""
-        self.index.write(model_dir, experiment)
+    # Properties.
 
     @property
     def has_features(self) -> bool:
@@ -149,13 +170,6 @@ def source_vocab_size(self) -> int:
                 self.index.source_vocab_size + self.index.features_vocab_size
             )
 
-    def _dataset(self, path: str) -> datasets.Dataset:
-        return datasets.Dataset(
-            list(self.parser.samples(path)),
-            self.index,
-            self.parser,
-        )
-
     # Required API.
 
     def train_dataloader(self) -> data.DataLoader:
@@ -166,6 +180,7 @@ def train_dataloader(self) -> data.DataLoader:
             batch_size=self.batch_size,
             shuffle=True,
             num_workers=1,
+            persistent_workers=True,
         )
 
     def val_dataloader(self) -> data.DataLoader:
@@ -174,7 +189,9 @@ def val_dataloader(self) -> data.DataLoader:
             self._dataset(self.val),
             collate_fn=self.collator,
             batch_size=self.batch_size,
+            shuffle=False,
             num_workers=1,
+            persistent_workers=True,
         )
 
     def predict_dataloader(self) -> data.DataLoader:
@@ -183,7 +200,9 @@ def predict_dataloader(self) -> data.DataLoader:
             self._dataset(self.predict),
             collate_fn=self.collator,
             batch_size=self.batch_size,
+            shuffle=False,
             num_workers=1,
+            persistent_workers=True,
         )
 
     def test_dataloader(self) -> data.DataLoader:
@@ -192,5 +211,14 @@ def test_dataloader(self) -> data.DataLoader:
             self._dataset(self.test),
             collate_fn=self.collator,
             batch_size=self.batch_size,
+            shuffle=False,
             num_workers=1,
+            persistent_workers=True,
+        )
+
+    def _dataset(self, path: str) -> datasets.Dataset:
+        return datasets.Dataset(
+            list(self.parser.samples(path)),
+            self.index,
+            self.parser,
         )
diff --git a/yoyodyne/data/datasets.py b/yoyodyne/data/datasets.py
index 0f7c7521..3a6b7049 100644
--- a/yoyodyne/data/datasets.py
+++ b/yoyodyne/data/datasets.py
@@ -1,18 +1,13 @@
-"""Datasets and related utilities.
-
-Anything which has a tensor member should inherit from nn.Module, run the
-superclass constructor, and register the tensor as a buffer. This enables the
-Trainer to move them to the appropriate device."""
+"""Datasets and related utilities."""
 
 import dataclasses
-from typing import Iterator, List, Optional
+from typing import List, Optional
 
 import torch
 from torch import nn
 from torch.utils import data
 
-from .. import special
-from . import indexes, tsv
+from . import mappers, tsv
 
 
 class Item(nn.Module):
@@ -45,140 +40,21 @@ def has_target(self):
         return self.target is not None
 
 
+# TODO: Add an iterable data set object for out-of-core inference.
+
+
 @dataclasses.dataclass
 class Dataset(data.Dataset):
-    """Datatset class."""
+    """Mappable data set.
+
+    This class loads the entire file into memory and is therefore only suitable
+    for in-core data sets.
+    """
 
     samples: List[tsv.SampleType]
-    index: indexes.Index  # Usually copied from the DataModule.
+    mapper: mappers.Mapper
     parser: tsv.TsvParser  # Ditto.
 
-    @property
-    def has_features(self) -> bool:
-        return self.parser.has_features
-
-    @property
-    def has_target(self) -> bool:
-        return self.parser.has_target
-
-    def _encode(
-        self,
-        symbols: List[str],
-    ) -> torch.Tensor:
-        """Encodes a sequence as a tensor of indices with string boundary IDs.
-
-        Args:
-            symbols (List[str]): symbols to be encoded.
-
-        Returns:
-            torch.Tensor: the encoded tensor.
-        """
-        return torch.tensor([self.index(symbol) for symbol in symbols])
-
-    def encode_source(self, symbols: List[str]) -> torch.Tensor:
-        """Encodes a source string, padding with start and end tags.
-
-        Args:
-            symbols (List[str]).
-
-        Returns:
-            torch.Tensor.
-        """
-        wrapped = [special.START]
-        wrapped.extend(symbols)
-        wrapped.append(special.END)
-        return self._encode(wrapped)
-
-    def encode_features(self, symbols: List[str]) -> torch.Tensor:
-        """Encodes a features string.
-
-        Args:
-            symbols (List[str]).
-
-        Returns:
-            torch.Tensor.
-        """
-        return self._encode(symbols)
-
-    def encode_target(self, symbols: List[str]) -> torch.Tensor:
-        """Encodes a features string, padding with end tags.
-
-        Args:
-            symbols (List[str]).
-
-        Returns:
-            torch.Tensor.
- """ - wrapped = symbols.copy() - wrapped.append(special.END) - return self._encode(wrapped) - - # Decoding. - - def _decode( - self, - indices: torch.Tensor, - ) -> Iterator[List[str]]: - """Decodes the tensor of indices into lists of symbols. - - Args: - indices (torch.Tensor): 2d tensor of indices. - - Yields: - List[str]: Decoded symbols. - """ - for idx in indices.cpu(): - yield [ - self.index.get_symbol(c) - for c in idx - if not special.isspecial(c) - ] - - def decode_source( - self, - indices: torch.Tensor, - ) -> Iterator[str]: - """Decodes a source tensor. - - Args: - indices (torch.Tensor): 2d tensor of indices. - - Yields: - str: Decoded source strings. - """ - for symbols in self._decode(indices): - yield self.parser.source_string(symbols) - - def decode_features( - self, - indices: torch.Tensor, - ) -> Iterator[str]: - """Decodes a features tensor. - - Args: - indices (torch.Tensor): 2d tensor of indices. - - Yields: - str: Decoded features strings. - """ - for symbols in self._decode(indices): - yield self.parser.feature_string(symbols) - - def decode_target( - self, - indices: torch.Tensor, - ) -> Iterator[str]: - """Decodes a target tensor. - - Args: - indices (torch.Tensor): 2d tensor of indices. - - Yields: - str: Decoded target strings. - """ - for symbols in self._decode(indices): - yield self.parser.target_string(symbols) - # Required API. def __len__(self) -> int: @@ -193,26 +69,27 @@ def __getitem__(self, idx: int) -> Item: Returns: Item. """ + row = self.samples[idx] if self.has_features: if self.has_target: - source, features, target = self.samples[idx] + source, features, target = row return Item( - source=self.encode_source(source), - features=self.encode_features(features), - target=self.encode_target(target), + source=self.mapper.encode_source(source), + features=self.mapper.encode_features(features), + target=self.mapper.encode_target(target), ) else: - source, features = self.samples[idx] + source, features = row return Item( - source=self.encode_source(source), - features=self.encode_features(features), + source=self.mapper.encode_source(source), + features=self.mapper.encode_features(features), ) elif self.has_target: - source, target = self.samples[idx] + source, target = row return Item( - source=self.encode_source(source), - target=self.encode_target(target), + source=self.mapper.encode_source(source), + target=self.mapper.encode_target(target), ) else: - source = self.samples[idx] - return Item(source=self.encode_source(source)) + source = row + return Item(source=self.mapper.encode_source(source)) diff --git a/yoyodyne/data/indexes.py b/yoyodyne/data/indexes.py index c9b70f14..94431fe4 100644 --- a/yoyodyne/data/indexes.py +++ b/yoyodyne/data/indexes.py @@ -1,6 +1,7 @@ """Symbol index.""" -import os +from __future__ import annotations + import pickle from typing import Dict, Iterable, List, Optional @@ -39,7 +40,12 @@ def __init__( self.tie_embeddings = tie_embeddings # We store vocabularies separately for logging purposes. self.source_vocabulary = sorted(source_vocabulary) - self.target_vocabulary = sorted(target_vocabulary) + self.features_vocabulary = ( + sorted(features_vocabulary) if features_vocabulary else None + ) + self.target_vocabulary = ( + sorted(target_vocabulary) if target_vocabulary else None + ) if self.tie_embeddings: # Vocabulary is the union of source and target. vocabulary = sorted( @@ -47,14 +53,11 @@ def __init__( ) else: # Vocabulary consists of target symbols followed by source symbols. 
-            vocabulary = sorted(target_vocabulary) + sorted(source_vocabulary)
+            vocabulary = self.target_vocabulary + self.source_vocabulary
         # FeatureInvariantTransformer assumes that features_vocabulary is at
         # the end of the vocabulary.
         if features_vocabulary is not None:
-            self.features_vocabulary = sorted(features_vocabulary)
             vocabulary.extend(self.features_vocabulary)
-        else:
-            self.features_vocabulary = None
         # Keeps special.SPECIAL first to maintain overlap with features.
         self._index2symbol = special.SPECIAL + vocabulary
         self._symbol2index = {c: i for i, c in enumerate(self._index2symbol)}
@@ -62,6 +65,8 @@ def __len__(self) -> int:
         return len(self._index2symbol)
 
+    # Lookup.
+
     def __call__(self, lookup: str) -> int:
         """Looks up index by symbol.
 
@@ -84,64 +89,12 @@ def get_symbol(self, index: int) -> str:
         """
         return self._index2symbol[index]
 
-    # Serialization support.
-
-    @classmethod
-    def read(cls, model_dir: str, experiment: str) -> "Index":
-        """Loads index.
-
-        Args:
-            model_dir (str).
-            experiment (str).
-
-        Returns:
-            Index.
-        """
-        index = cls.__new__(cls)
-        path = index.index_path(model_dir, experiment)
-        with open(path, "rb") as source:
-            dictionary = pickle.load(source)
-        for key, value in dictionary.items():
-            setattr(index, key, value)
-        return index
-
-    @staticmethod
-    def index_path(model_dir: str, experiment: str) -> str:
-        """Computes path for the index file.
-
-        Args:
-            model_dir (str).
-            experiment (str).
-
-        Returns:
-            str.
-        """
-        return f"{model_dir}/{experiment}/index.pkl"
-
-    def write(self, model_dir: str, experiment: str) -> None:
-        """Writes index.
-
-        Args:
-            model_dir (str).
-            experiment (str).
-        """
-        path = self.index_path(model_dir, experiment)
-        os.makedirs(os.path.dirname(path), exist_ok=True)
-        with open(path, "wb") as sink:
-            pickle.dump(vars(self), sink)
+    # Properties.
 
     @property
     def symbols(self) -> List[str]:
         return list(self._symbol2index.keys())
 
-    @property
-    def has_features(self) -> bool:
-        return self.features_vocab_size > 0
-
-    @property
-    def has_target(self) -> bool:
-        return self.target_vocab_size > 0
-
     @property
     def vocab_size(self) -> int:
         return len(self._symbol2index)
@@ -167,6 +120,7 @@ def features_vocab_size(self) -> int:
         return len(self.features_vocabulary) if self.features_vocabulary else 0
 
     # These are also recorded in the `special` module.
+    # TODO: are these still needed?
 
     @property
     def pad_idx(self) -> int:
@@ -183,3 +137,35 @@ def unk_idx(self) -> int:
         return self._symbol2index[special.UNK]
+
+    # Serialization.
+
+    @classmethod
+    def read(cls, model_dir: str, experiment: str) -> Index:
+        """Loads index.
+
+        Args:
+            model_dir (str).
+            experiment (str).
+
+        Returns:
+            Index.
+        """
+        index = cls.__new__(cls)
+        with open(cls.path(model_dir), "rb") as source:
+            for key, value in pickle.load(source).items():
+                setattr(index, key, value)
+        return index
+
+    def write(self, model_dir: str) -> None:
+        """Writes index.
+
+        Args:
+            model_dir (str).
+        """
+        with open(self.index_path(model_dir), "wb") as sink:
+            pickle.dump(vars(self), sink)
+
+    @staticmethod
+    def path(model_dir: str) -> str:
+        return f"{model_dir}/index.pkl"
diff --git a/yoyodyne/data/mappers.py b/yoyodyne/data/mappers.py
new file mode 100644
index 00000000..ed3e2bb0
--- /dev/null
+++ b/yoyodyne/data/mappers.py
@@ -0,0 +1,123 @@
+"""Encodes and decodes tensors."""
+
+from __future__ import annotations
+
+import dataclasses
+
+from typing import Iterable, List
+
+import torch
+
+from . import indexes
+from .. import special
+
+
+@dataclasses.dataclass
+class Mapper:
+    """Handles mapping between strings and tensors."""
+
+    index: indexes.Index  # Usually copied from the DataModule.
+
+    @classmethod
+    def read(cls, model_dir: str) -> Mapper:
+        """Loads mapper from an index.
+
+        Args:
+            model_dir (str).
+
+        Returns:
+            Mapper.
+        """
+        return cls(indexes.Index.read(model_dir))
+
+    # Encoding.
+
+    def _encode(self, symbols: Iterable[str]):
+        """Encodes a tensor.
+
+        Args:
+            symbols (Iterable[str]).
+
+        Returns:
+            torch.Tensor: the encoded tensor.
+        """
+        return torch.tensor([self.index(symbol) for symbol in symbols])
+
+    def encode_source(self, symbols: Iterable[str]) -> torch.Tensor:
+        """Encodes a source string, padding with start and end tags.
+
+        Args:
+            symbols (Iterable[str]).
+
+        Returns:
+            torch.Tensor.
+        """
+        wrapped = [special.START]
+        wrapped.extend(symbols)
+        wrapped.append(special.END)
+        return self._encode(wrapped)
+
+    def encode_features(self, symbols: Iterable[str]) -> torch.Tensor:
+        """Encodes a features string.
+
+        Args:
+            symbols (Iterable[str]).
+
+        Returns:
+            torch.Tensor.
+        """
+        return self._encode(symbols)
+
+    def encode_target(self, symbols: Iterable[str]) -> torch.Tensor:
+        """Encodes a target string, padding with end tags.
+
+        Args:
+            symbols (Iterable[str]).
+
+        Returns:
+            torch.Tensor.
+        """
+        wrapped = list(symbols)
+        wrapped.append(special.END)
+        return self._encode(wrapped)
+
+    # Decoding.
+
+    def _decode(
+        self,
+        indices: torch.Tensor,
+    ) -> List[str]:
+        """Decodes a tensor.
+
+        Args:
+            indices (torch.Tensor): 1d tensor of indices.
+
+        Returns:
+            List[str]: Decoded symbols.
+        """
+        return [
+            self.index.get_symbol(c)
+            for c in indices
+            if c not in self.index.special_idx
+        ]
+
+    # These are just here for compatibility but they all have
+    # the same implementation.
+
+    def decode_source(
+        self,
+        indices: torch.Tensor,
+    ) -> List[str]:
+        return self._decode(indices)
+
+    def decode_features(
+        self,
+        indices: torch.Tensor,
+    ) -> List[str]:
+        return self._decode(indices)
+
+    def decode_target(
+        self,
+        indices: torch.Tensor,
+    ) -> List[str]:
+        return self._decode(indices)

From 83b98f93ec51ae357a3b42f38018b6d5c985fe5f Mon Sep 17 00:00:00 2001
From: Kyle Gorman
Date: Mon, 25 Nov 2024 13:16:11 -0500
Subject: [PATCH 02/16] Many bugfixes.
---
 yoyodyne/data/__init__.py    |  3 ++-
 yoyodyne/data/datamodules.py |  6 +++---
 yoyodyne/data/datasets.py    | 20 ++++++++++++++-----
 yoyodyne/data/indexes.py     |  7 +++++--
 yoyodyne/models/expert.py    | 37 +++++++++++++++++++++---------------
 yoyodyne/train.py            |  7 ++++---
 6 files changed, 51 insertions(+), 29 deletions(-)

diff --git a/yoyodyne/data/__init__.py b/yoyodyne/data/__init__.py
index 3c8304fa..8c57e10f 100644
--- a/yoyodyne/data/__init__.py
+++ b/yoyodyne/data/__init__.py
@@ -5,9 +5,10 @@
 from .. import defaults
 from .batches import PaddedBatch, PaddedTensor  # noqa: F401
 from .datamodules import DataModule  # noqa: F401
+from .datasets import Dataset  # noqa: F401
 from .indexes import Index  # noqa: F401
 from .mappers import Mapper  # noqa: F401
-from .parser import TsvParser  # noqa: F401
+from .tsv import TsvParser  # noqa: F401
 
 
 def add_argparse_args(parser: argparse.ArgumentParser) -> None:
diff --git a/yoyodyne/data/datamodules.py b/yoyodyne/data/datamodules.py
index 755f0c60..f4e5e5fa 100644
--- a/yoyodyne/data/datamodules.py
+++ b/yoyodyne/data/datamodules.py
@@ -7,7 +7,7 @@
 from torch.utils import data
 
 from .. import defaults, util
-from . import collators, datasets, indexes, tsv
+from . import collators, datasets, indexes, mappers, tsv
 
 
 class DataModule(lightning.LightningDataModule):
@@ -49,8 +49,8 @@ def __init__(
         features_sep: str = defaults.FEATURES_SEP,
         target_sep: str = defaults.TARGET_SEP,
         # Modeling options.
-        tie_embeddings: bool = defaults.TIE_EMBEDDINGS,
         separate_features: bool = False,
+        tie_embeddings: bool = defaults.TIE_EMBEDDINGS,
         # Other.
         batch_size: int = defaults.BATCH_SIZE,
         max_source_length: int = defaults.MAX_SOURCE_LENGTH,
         max_target_length: int = defaults.MAX_TARGET_LENGTH,
@@ -219,6 +219,6 @@ def test_dataloader(self) -> data.DataLoader:
     def _dataset(self, path: str) -> datasets.Dataset:
         return datasets.Dataset(
             list(self.parser.samples(path)),
-            self.index,
+            mappers.Mapper(self.index),
             self.parser,
         )
diff --git a/yoyodyne/data/datasets.py b/yoyodyne/data/datasets.py
index 3a6b7049..f13184e7 100644
--- a/yoyodyne/data/datasets.py
+++ b/yoyodyne/data/datasets.py
@@ -55,6 +55,16 @@ class Dataset(data.Dataset):
     mapper: mappers.Mapper
     parser: tsv.TsvParser  # Ditto.
 
+    # Properties.
+
+    @property
+    def has_features(self) -> bool:
+        return self.parser.has_features
+
+    @property
+    def has_target(self) -> bool:
+        return self.parser.has_target
+
     # Required API.
 
     def __len__(self) -> int:
@@ -69,27 +79,27 @@ def __getitem__(self, idx: int) -> Item:
         Returns:
             Item.
         """
-        row = self.samples[idx]
+        sample = self.samples[idx]
         if self.has_features:
             if self.has_target:
-                source, features, target = row
+                source, features, target = sample
                 return Item(
                     source=self.mapper.encode_source(source),
                     features=self.mapper.encode_features(features),
                     target=self.mapper.encode_target(target),
                 )
             else:
-                source, features = row
+                source, features = sample
                 return Item(
                     source=self.mapper.encode_source(source),
                     features=self.mapper.encode_features(features),
                 )
         elif self.has_target:
-            source, target = row
+            source, target = sample
             return Item(
                 source=self.mapper.encode_source(source),
                 target=self.mapper.encode_target(target),
             )
         else:
-            source = row
+            source = sample
             return Item(source=self.mapper.encode_source(source))
diff --git a/yoyodyne/data/indexes.py b/yoyodyne/data/indexes.py
index 94431fe4..0da5d610 100644
--- a/yoyodyne/data/indexes.py
+++ b/yoyodyne/data/indexes.py
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import itertools
 import pickle
 
 from typing import Dict, Iterable, List, Optional
@@ -49,7 +50,9 @@ def __init__(
         if self.tie_embeddings:
             # Vocabulary is the union of source and target.
             vocabulary = sorted(
-                frozenset(source_vocabulary + target_vocabulary)
+                frozenset(
+                    itertools.chain(source_vocabulary, target_vocabulary)
+                )
             )
         else:
             # Vocabulary consists of target symbols followed by source symbols.
@@ -163,7 +166,7 @@ def write(self, model_dir: str) -> None:
         Args:
             model_dir (str).
         """
-        with open(self.index_path(model_dir), "wb") as sink:
+        with open(self.path(model_dir), "wb") as sink:
             pickle.dump(vars(self), sink)
 
     @staticmethod
diff --git a/yoyodyne/models/expert.py b/yoyodyne/models/expert.py
index 00662fd6..007a665e 100644
--- a/yoyodyne/models/expert.py
+++ b/yoyodyne/models/expert.py
@@ -27,11 +27,15 @@
 from .. import data, defaults
 
 
-class ActionError(Exception):
+class Error(Exception):
     pass
 
 
-class AlignerError(Exception):
+class ActionError(Error):
+    pass
+
+
+class AlignerError(Error):
     pass
@@ -430,6 +434,7 @@ def find_prefixes(
 
 def get_expert(
     train_data: data.Dataset,
+    index: data.Index,
     epochs: int = defaults.ORACLE_EM_EPOCHS,
     oracle_factor: int = defaults.ORACLE_FACTOR,
     sed_params_path: str = None,
@@ -439,6 +444,7 @@ def get_expert(
 
     Args:
         data (data.Dataset): dataset for generating expert vocabulary.
+        index (data.Index): index for mapping symbols to indices.
         epochs (int): number of EM epochs.
         oracle_factor (float): scaling factor to determine rate of
             expert rollout sampling.
@@ -453,27 +459,28 @@ def get_expert(
 
     def _generate_data(
         data: data.Dataset,
+        index: data.Index,
     ) -> Iterator[Tuple[List[int], List[int]]]:
         """Helper function to manage data encoding for SED."
 
-        We want encodings without BOS or EOS tokens. This
-        encodes only raw source-target text for the Maxwell library.
+        We want encodings without BOS or EOS tokens. This encodes only raw
+        source-target text for the Maxwell library.
 
         Args:
-            data (data.Dataset): Dataset to iterate over.
+            data (data.Dataset): dataset for generating expert vocabulary.
+            index (data.Index): index for mapping symbols to indices.
 
-        Returns:
-            Iterator[Tuple[List[int], List[int]]]: Iterator that
-                yields list version of source and target entries
-                in dataset.
+        Yields:
+            Tuple[List[int], List[int]]: lists of source and target entries.
         """
-        assert data.has_target, """Passed dataset with no target to expert
-        module, cannot perform SED"""
+        if not data.has_target:
+            raise Error("Dataset has no target")
         for sample in data.samples:
-            source, target = sample[0], sample[-1]
-            yield [data.index(s) for s in source], [
-                data.index(t) for t in target
-            ]
+            source, *_, target = sample
+            yield (
+                [index(symbol) for symbol in source],
+                [index(symbol) for symbol in target],
+            )
 
     actions = ActionVocabulary(train_data.index)
     if read_from_file:
diff --git a/yoyodyne/train.py b/yoyodyne/train.py
index 08fd5f1d..eb21935a 100644
--- a/yoyodyne/train.py
+++ b/yoyodyne/train.py
@@ -150,9 +150,9 @@ def get_datamodule_from_argparse_args(
         "transducer_lstm",
     ]
     datamodule = data.DataModule(
+        model_dir=args.model_dir,
         train=args.train,
         val=args.val,
-        batch_size=args.batch_size,
         source_col=args.source_col,
         features_col=args.features_col,
         target_col=args.target_col,
@@ -160,13 +160,13 @@ def get_datamodule_from_argparse_args(
         features_sep=args.features_sep,
         target_sep=args.target_sep,
         separate_features=separate_features,
+        tie_embeddings=args.tie_embeddings,
+        batch_size=args.batch_size,
         max_source_length=args.max_source_length,
         max_target_length=args.max_target_length,
-        tie_embeddings=args.tie_embeddings,
     )
     if not datamodule.has_target:
         raise Error("No target column specified")
-    datamodule.index.write(args.model_dir, args.experiment)
     datamodule.log_vocabularies()
     return datamodule
@@ -213,6 +213,7 @@ def get_model_from_argparse_args(
     )
     expert = models.expert.get_expert(
         datamodule.train_dataloader().dataset,
+        datamodule.index,
         epochs=args.oracle_em_epochs,
         oracle_factor=args.oracle_factor,
         sed_params_path=sed_params_paths,

From 18ced9c1817359ec929891a205b8fd119c1358dc Mon Sep 17 00:00:00 2001
From: Kyle Gorman
Date: Mon, 25 Nov 2024 13:46:54 -0500
Subject: [PATCH 03/16] Removes `experiment` layer

You can just simulate this by appending an additional string onto the name
of the model_dir if needed.
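
For example (hypothetical flag values), where one previously namespaced a run
with both flags:

    yoyodyne-train --model_dir models --experiment big_run ...

one can now fold the experiment name into the directory path instead:

    yoyodyne-train --model_dir models/big_run ...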
---
 yoyodyne/data/indexes.py |  3 +--
 yoyodyne/predict.py      |  5 +----
 yoyodyne/train.py        | 16 +++++-----------
 3 files changed, 7 insertions(+), 17 deletions(-)

diff --git a/yoyodyne/data/indexes.py b/yoyodyne/data/indexes.py
index 0da5d610..019edf66 100644
--- a/yoyodyne/data/indexes.py
+++ b/yoyodyne/data/indexes.py
@@ -144,12 +144,11 @@ def unk_idx(self) -> int:
     # Serialization.
 
     @classmethod
-    def read(cls, model_dir: str, experiment: str) -> Index:
+    def read(cls, model_dir: str) -> Index:
         """Loads index.
 
         Args:
             model_dir (str).
-            experiment (str).
 
         Returns:
             Index.
diff --git a/yoyodyne/predict.py b/yoyodyne/predict.py
index 4b057b68..1e25041b 100644
--- a/yoyodyne/predict.py
+++ b/yoyodyne/predict.py
@@ -40,7 +40,7 @@ def get_datamodule_from_argparse_args(
         "pointer_generator_transformer",
         "transducer",
     ]
-    index = data.Index.read(args.model_dir, args.experiment)
+    index = data.Index.read(args.model_dir)
     return data.DataModule(
         predict=args.predict,
         batch_size=args.batch_size,
@@ -140,9 +140,6 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None:
         required=True,
         help="Path to output model directory.",
     )
-    parser.add_argument(
-        "--experiment", required=True, help="Name of experiment."
-    )
     parser.add_argument(
         "--predict",
         required=True,
diff --git a/yoyodyne/train.py b/yoyodyne/train.py
index eb21935a..86f12be1 100644
--- a/yoyodyne/train.py
+++ b/yoyodyne/train.py
@@ -24,20 +24,19 @@ class Error(Exception):
     pass
 
 
-def _get_loggers(experiment: str, model_dir: str, log_wandb: bool) -> List:
+def _get_loggers(model_dir: str, log_wandb: bool) -> List:
     """Creates the logger(s).
 
     Args:
-        experiment (str).
         model_dir (str).
         log_wandb (bool).
 
     Returns:
         List: logger.
     """
-    trainer_loggers = [loggers.CSVLogger(model_dir, name=experiment)]
+    trainer_loggers = [loggers.CSVLogger(model_dir)]
     if log_wandb:
-        trainer_loggers.append(loggers.WandbLogger(project=experiment))
+        trainer_loggers.append(loggers.WandbLogger())
         # Logs the path to local artifacts made by PTL.
         wandb.config["local_run_dir"] = trainer_loggers[0].log_dir
     return trainer_loggers
@@ -125,7 +124,7 @@ def get_trainer_from_argparse_args(
         ),
         default_root_dir=args.model_dir,
         enable_checkpointing=True,
-        logger=_get_loggers(args.experiment, args.model_dir, args.log_wandb),
+        logger=_get_loggers(args.model_dir, args.log_wandb),
     )
@@ -207,9 +206,7 @@ def get_model_from_argparse_args(
     expert = None
     if args.arch in ["transducer_gru", "transducer_lstm"]:
         sed_params_paths = (
-            args.sed_params
-            if args.sed_params
-            else f"{args.model_dir}/{args.experiment}/sed.pkl"
+            args.sed_params if args.sed_params else f"{args.model_dir}/sed.pkl"
         )
@@ -343,9 +340,6 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None:
         required=True,
         help="Path to output model directory.",
     )
-    parser.add_argument(
-        "--experiment", required=True, help="Name of experiment."
-    )
     parser.add_argument(
         "--train",
         required=True,

From 37cc30e68b6b9723dd079a6b8eeab202065bd1a4 Mon Sep 17 00:00:00 2001
From: Kyle Gorman
Date: Mon, 25 Nov 2024 13:51:24 -0500
Subject: [PATCH 04/16] Updates README to reflect last commit
---
 README.md | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index 63ed3747..7bd5ca0a 100644
--- a/README.md
+++ b/README.md
@@ -79,7 +79,6 @@
 Training is performed by the [`yoyodyne-train`](yoyodyne/train.py) script.
 One must specify the following required arguments:
 
 -   `--model_dir`: path for model metadata and checkpoints
--   `--experiment`: name of experiment (pick something unique)
 -   `--train`: path to TSV file containing training data
 -   `--val`: path to TSV file containing validation data
@@ -108,7 +107,6 @@
 One must specify the following required arguments:
 
 -   `--arch`: architecture, matching the one used for training
 -   `--model_dir`: path for model metadata
--   `--experiment`: name of experiment
 -   `--checkpoint`: path to checkpoint
 -   `--predict`: path to file containing data to be predicted
 -   `--output`: path for predictions
@@ -162,12 +160,15 @@
 provide any symbols of the form `<...>`, `[...]`, or `{...}`.
 
 Checkpointing is handled by
 [Lightning](https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing_basic.html).
 The path for model information, including checkpoints, is specified by a
-combination of `--model_dir` and `--experiment`, such that we build the path
-`model_dir/experiment/version_n`, where each run of an experiment with the same
-`model_dir` and `experiment` is namespaced with a new version number. A version
-stores everything needed to reload the model, including the hyperparameters
-(`model_dir/experiment_name/version_n/hparams.yaml`) and the checkpoints
-directory (`model_dir/experiment_name/version_n/checkpoints`).
+combination of `--model_dir` and an automatic version number, such that we
+build the path `model_dir/lightning_logs/version_n`, where each run with the
+same `model_dir` is namespaced with a new version number. A version stores
+everything needed to reload the model, including:
+
+-   the index (`model_dir/index.pkl`),
+-   the hyperparameters (`model_dir/lightning_logs/version_n/hparams.yaml`),
+-   the metrics (`model_dir/lightning_logs/version_n/metrics.csv`), and
+-   the checkpoints (`model_dir/lightning_logs/version_n/checkpoints`).
 
 By default, each run initializes a new model from scratch, unless the
 `--train_from` argument is specified. To continue training from a specific

From b38e96852fcd8893927af361306d184aca53d804 Mon Sep 17 00:00:00 2001
From: Kyle Gorman
Date: Mon, 25 Nov 2024 16:54:12 -0500
Subject: [PATCH 05/16] Updates and adds prediction support.
---
 yoyodyne/data/datamodules.py |  5 +---
 yoyodyne/data/indexes.py     | 19 -------------------
 yoyodyne/data/mappers.py     |  2 +-
 yoyodyne/models/__init__.py  |  1 +
 yoyodyne/models/expert.py    | 16 +++++------
 yoyodyne/predict.py          | 54 ++++++++++++++++++++++--------------
 6 files changed, 44 insertions(+), 53 deletions(-)

diff --git a/yoyodyne/data/datamodules.py b/yoyodyne/data/datamodules.py
index f4e5e5fa..5c404eb5 100644
--- a/yoyodyne/data/datamodules.py
+++ b/yoyodyne/data/datamodules.py
@@ -121,10 +121,7 @@ def _make_index(
             tie_embeddings=tie_embeddings,
         )
         # Writes it to the model directory.
-        try:
-            os.mkdir(model_dir)
-        except FileExistsError:
-            pass
+        os.makedirs(model_dir, exist_ok=True)
         index.write(model_dir)
         return index
diff --git a/yoyodyne/data/indexes.py b/yoyodyne/data/indexes.py
index 019edf66..ad3cc387 100644
--- a/yoyodyne/data/indexes.py
+++ b/yoyodyne/data/indexes.py
@@ -122,25 +122,6 @@ def target_vocab_size(self) -> int:
     def features_vocab_size(self) -> int:
         return len(self.features_vocabulary) if self.features_vocabulary else 0
 
-    # These are also recorded in the `special` module.
-    # TODO: are these still needed?
-
-    @property
-    def pad_idx(self) -> int:
-        return self._symbol2index[special.PAD]
-
-    @property
-    def start_idx(self) -> int:
-        return self._symbol2index[special.START]
-
-    @property
-    def end_idx(self) -> int:
-        return self._symbol2index[special.END]
-
-    @property
-    def unk_idx(self) -> int:
-        return self._symbol2index[special.UNK]
-
     # Serialization.
diff --git a/yoyodyne/data/mappers.py b/yoyodyne/data/mappers.py
index ed3e2bb0..49d4c468 100644
--- a/yoyodyne/data/mappers.py
+++ b/yoyodyne/data/mappers.py
@@ -98,7 +98,7 @@ def _decode(
         return [
             self.index.get_symbol(c)
             for c in indices
-            if c not in self.index.special_idx
+            if not special.isspecial(c)
         ]
diff --git a/yoyodyne/models/__init__.py b/yoyodyne/models/__init__.py
index 09b64b8f..01c1de35 100644
--- a/yoyodyne/models/__init__.py
+++ b/yoyodyne/models/__init__.py
@@ -20,6 +20,7 @@
 from .transducer import TransducerGRUModel, TransducerLSTMModel  # noqa: F401
 from .transformer import TransformerModel
 
+
 _model_fac = {
     "attentive_gru": AttentiveGRUModel,
     "attentive_lstm": AttentiveLSTMModel,
diff --git a/yoyodyne/models/expert.py b/yoyodyne/models/expert.py
index 007a665e..c4ba637f 100644
--- a/yoyodyne/models/expert.py
+++ b/yoyodyne/models/expert.py
@@ -433,7 +433,7 @@ def find_prefixes(
 
 def get_expert(
-    train_data: data.Dataset,
+    dataset: data.Dataset,
     index: data.Index,
     epochs: int = defaults.ORACLE_EM_EPOCHS,
     oracle_factor: int = defaults.ORACLE_FACTOR,
@@ -443,7 +443,7 @@ def get_expert(
     """Generates expert object for training transducer.
 
     Args:
-        data (data.Dataset): dataset for generating expert vocabulary.
+        dataset (data.Dataset): dataset for generating expert vocabulary.
         index (data.Index): index for mapping symbols to indices.
         epochs (int): number of EM epochs.
@@ -458,7 +458,7 @@ def get_expert(
     """
 
     def _generate_data(
-        data: data.Dataset,
+        dataset: data.Dataset,
         index: data.Index,
     ) -> Iterator[Tuple[List[int], List[int]]]:
         """Helper function to manage data encoding for SED."
 
         We want encodings without BOS or EOS tokens. This encodes only raw
         source-target text for the Maxwell library.
 
         Args:
-            data (data.Dataset): dataset for generating expert vocabulary.
+            dataset (data.Dataset): dataset for generating expert vocabulary.
             index (data.Index): index for mapping symbols to indices.
 
         Yields:
             Tuple[List[int], List[int]]: lists of source and target entries.
         """
-        if not data.has_target:
+        if not dataset.has_target:
             raise Error("Dataset has no target")
-        for sample in data.samples:
+        for sample in dataset.samples:
             source, *_, target = sample
             yield (
                 [index(symbol) for symbol in source],
                 [index(symbol) for symbol in target],
             )
 
-    actions = ActionVocabulary(train_data.index)
+    actions = ActionVocabulary(index)
     if read_from_file:
         sed_params = sed.ParamDict.read_params(sed_params_path)
         sed_aligner = sed.StochasticEditDistance(sed_params)
     else:
         sed_aligner = sed.StochasticEditDistance.fit_from_data(
-            _generate_data(train_data),
+            _generate_data(dataset, index),
             epochs=epochs,
         )
         sed_aligner.params.write_params(sed_params_path)
diff --git a/yoyodyne/predict.py b/yoyodyne/predict.py
index 1e25041b..b1502d1c 100644
--- a/yoyodyne/predict.py
+++ b/yoyodyne/predict.py
@@ -36,12 +36,16 @@ def get_datamodule_from_argparse_args(
         data.DataModule.
""" separate_features = args.features_col != 0 and args.arch in [ - "pointer_generator_rnn", + "hard_attention_gru", + "hard_attention_lstm", + "pointer_generator_gru", + "pointer_generator_lstm", "pointer_generator_transformer", - "transducer", + "transducer_grm", + "transducer_lstm", ] - index = data.Index.read(args.model_dir) return data.DataModule( + model_dir=args.model_dir, predict=args.predict, batch_size=args.batch_size, source_col=args.source_col, @@ -53,7 +57,6 @@ def get_datamodule_from_argparse_args( separate_features=separate_features, max_source_length=args.max_source_length, max_target_length=args.max_target_length, - index=index, ) @@ -76,17 +79,6 @@ def get_model_from_argparse_args( return model_cls.load_from_checkpoint(args.checkpoint, **kwargs) -def _mkdir(output: str) -> None: - """Creates directory for output file if necessary. - - Args: - output (str): output to output file. - """ - dirname = os.path.dirname(output) - if dirname: - os.makedirs(dirname, exist_ok=True) - - def predict( trainer: lightning.Trainer, model: models.BaseModel, @@ -104,23 +96,43 @@ def predict( util.log_info(f"Writing to {output}") _mkdir(output) loader = datamodule.predict_dataloader() + parser = datamodule.parser + mapper = data.Mapper(datamodule.index) with open(output, "w", encoding=defaults.ENCODING) as sink: if model.beam_width > 1: # Beam search. tsv_writer = csv.writer(sink, delimiter="\t") for predictions, scores in trainer.predict(model, loader): predictions = util.pad_tensor_after_eos(predictions) - decoded_predictions = loader.dataset.decode_target(predictions) - row = itertools.chain( - *zip(decoded_predictions, scores.tolist()) + # TODO: beam search requires singleton batches and this + # assumes that. Revise if that restriction is ever lifted. + targets = [ + parser.target_string(mapper.decode_target(target)) + for target in predictions + ] + row = itertools.chain.from_iterable( + zip(targets, scores.tolist()) ) tsv_writer.writerow(row) else: # Greedy search. for predictions, _ in trainer.predict(model, loader): predictions = util.pad_tensor_after_eos(predictions) - for prediction in loader.dataset.decode_target(predictions): - print(prediction, file=sink) + # Unpacks each element in the batch. + for target in predictions: + symbols = mapper.decode_target(target) + print(parser.target_string(symbols), file=sink) + + +def _mkdir(output: str) -> None: + """Creates directory for output file if necessary. + + Args: + output (str): output to output file. + """ + dirname = os.path.dirname(output) + if dirname: + os.makedirs(dirname, exist_ok=True) def add_argparse_args(parser: argparse.ArgumentParser) -> None: @@ -154,7 +166,7 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None: data.add_argparse_args(parser) # Architecture arguments; the architecture-specific ones are not needed. 
     models.add_argparse_args(parser)
-    models.BaseEncoderDecoder.add_predict_argparse_args(parser)
+    models.BaseModel.add_predict_argparse_args(parser)
     # Among the things this adds, the following are likely to be useful:
     # --accelerator ("gpu" for GPU)
     # --devices (for multiple device support)

From 3dcfe12c6732ffebb3a7a5a7d579123eb0cb1fc2 Mon Sep 17 00:00:00 2001
From: Kyle Gorman
Date: Mon, 25 Nov 2024 19:15:14 -0500
Subject: [PATCH 06/16] remove unused
---
 yoyodyne/data/indexes.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/yoyodyne/data/indexes.py b/yoyodyne/data/indexes.py
index f2d575c4..f115e6bf 100644
--- a/yoyodyne/data/indexes.py
+++ b/yoyodyne/data/indexes.py
@@ -4,7 +4,6 @@
 
 import itertools
 import pickle
-import os
 
 from typing import Dict, Iterable, List, Optional

From 1aedf8649450a3b2153428bd6d2ec4e8850670a2 Mon Sep 17 00:00:00 2001
From: Kyle Gorman
Date: Sat, 30 Nov 2024 13:00:23 -0500
Subject: [PATCH 07/16] Cleanup.
---
 yoyodyne/data/collators.py   | 1 +
 yoyodyne/data/datamodules.py | 1 +
 yoyodyne/data/datasets.py    | 1 +
 yoyodyne/data/tsv.py         | 1 +
 4 files changed, 4 insertions(+)

diff --git a/yoyodyne/data/collators.py b/yoyodyne/data/collators.py
index 43babc33..3cb23a7c 100644
--- a/yoyodyne/data/collators.py
+++ b/yoyodyne/data/collators.py
@@ -2,6 +2,7 @@
 
 import argparse
 import dataclasses
+
 from typing import List
 
 import torch
diff --git a/yoyodyne/data/datamodules.py b/yoyodyne/data/datamodules.py
index 5c404eb5..94bc441b 100644
--- a/yoyodyne/data/datamodules.py
+++ b/yoyodyne/data/datamodules.py
@@ -1,6 +1,7 @@
 """Data modules."""
 
 import os
+
 from typing import Iterable, Optional
 
 import lightning
diff --git a/yoyodyne/data/datasets.py b/yoyodyne/data/datasets.py
index f13184e7..28384e97 100644
--- a/yoyodyne/data/datasets.py
+++ b/yoyodyne/data/datasets.py
@@ -1,6 +1,7 @@
 """Datasets and related utilities."""
 
 import dataclasses
+
 from typing import List, Optional
 
 import torch
diff --git a/yoyodyne/data/tsv.py b/yoyodyne/data/tsv.py
index 2b5dc19a..d0eedf7a 100644
--- a/yoyodyne/data/tsv.py
+++ b/yoyodyne/data/tsv.py
@@ -6,6 +6,7 @@
 
 import csv
 import dataclasses
+
 from typing import Iterator, List, Tuple, Union
 
 from .. import defaults

From 40d76b5151c3fd9eca2197caf82efbe4487a12a0 Mon Sep 17 00:00:00 2001
From: Kyle Gorman
Date: Sat, 30 Nov 2024 13:03:56 -0500
Subject: [PATCH 08/16] Indentation
---
 yoyodyne/data/indexes.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/yoyodyne/data/indexes.py b/yoyodyne/data/indexes.py
index 58507325..d14e5d94 100644
--- a/yoyodyne/data/indexes.py
+++ b/yoyodyne/data/indexes.py
@@ -114,7 +114,7 @@ def read(cls, model_dir: str) -> Index:
         setattr(index, key, value)
         return index
 
-    def write(self, model_dir: str) -> None: 
+    def write(self, model_dir: str) -> None:
         """Writes index.
 
         Args:
             model_dir (str).
         """
         path = self.index_path(model_dir)
         os.makedirs(os.path.dirname(path), exist_ok=True)
         with open(path, "wb") as sink:
             pickle.dump(vars(self), sink)
-    
+
     @staticmethod
     def path(model_dir: str) -> str:
         """Computes path for the index file.
 
         Args:
             model_dir (str).
 
         Returns:
             str.
         """
         return f"{model_dir}/index.pkl"
 
     # Properties.
-    
+
     @property
     def symbols(self) -> List[str]:
         return list(self._symbol2index.keys())

From 56a106c78300ffd1ac98dedd18795e3291b6c13d Mon Sep 17 00:00:00 2001
From: Kyle Gorman
Date: Mon, 2 Dec 2024 10:35:14 -0500
Subject: [PATCH 09/16] adds comment
---
 yoyodyne/data/datamodules.py | 3 ---
 yoyodyne/predict.py          | 1 +
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/yoyodyne/data/datamodules.py b/yoyodyne/data/datamodules.py
index 94bc441b..de2a19fc 100644
--- a/yoyodyne/data/datamodules.py
+++ b/yoyodyne/data/datamodules.py
@@ -1,7 +1,5 @@
 """Data modules."""
 
-import os
-
 from typing import Iterable, Optional
 
 import lightning
@@ -122,7 +120,6 @@ def _make_index(
             tie_embeddings=tie_embeddings,
         )
         # Writes it to the model directory.
-        os.makedirs(model_dir, exist_ok=True)
         index.write(model_dir)
         return index
diff --git a/yoyodyne/predict.py b/yoyodyne/predict.py
index b1502d1c..1c577674 100644
--- a/yoyodyne/predict.py
+++ b/yoyodyne/predict.py
@@ -110,6 +110,7 @@ def predict(
                     parser.target_string(mapper.decode_target(target))
                     for target in predictions
                 ]
+                # Collates target strings and their scores.
                 row = itertools.chain.from_iterable(
                     zip(targets, scores.tolist())
                 )

From a1272af02c60a7dd01819148e26807be9b75c267 Mon Sep 17 00:00:00 2001
From: Kyle Gorman
Date: Mon, 2 Dec 2024 10:56:51 -0500
Subject: [PATCH 10/16] Last-minute
---
 yoyodyne/data/indexes.py | 2 +-
 yoyodyne/predict.py      | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/yoyodyne/data/indexes.py b/yoyodyne/data/indexes.py
index d14e5d94..815adbad 100644
--- a/yoyodyne/data/indexes.py
+++ b/yoyodyne/data/indexes.py
@@ -61,7 +61,7 @@ def __init__(
             vocabulary = self.target_vocabulary + self.source_vocabulary
         # FeatureInvariantTransformer assumes that features_vocabulary is at
         # the end of the vocabulary.
-        if features_vocabulary is not None:
+        if self.features_vocabulary is not None:
             vocabulary.extend(self.features_vocabulary)
         # Keeps special.SPECIAL first to maintain overlap with features.
diff --git a/yoyodyne/predict.py b/yoyodyne/predict.py
index 1c577674..deaf717f 100644
--- a/yoyodyne/predict.py
+++ b/yoyodyne/predict.py
@@ -27,7 +27,7 @@ def get_trainer_from_argparse_args(
 def get_datamodule_from_argparse_args(
     args: argparse.Namespace,
 ) -> data.DataModule:
-    """Creates the dataset from CLI arguments.
+    """Creates the datamodule from CLI arguments.

From b879b3280057485ad8f5a1ac97cc3f407bcc72f2 Mon Sep 17 00:00:00 2001
From: Kyle Gorman
Date: Mon, 2 Dec 2024 11:05:14 -0500
Subject: [PATCH 11/16] black/flake8 updates
---
 yoyodyne/data/datamodules.py | 13 ++-----------
 yoyodyne/data/indexes.py     |  1 -
 yoyodyne/predict.py          | 11 -----------
 3 files changed, 2 insertions(+), 23 deletions(-)

diff --git a/yoyodyne/data/datamodules.py b/yoyodyne/data/datamodules.py
index d845005b..c911f22d 100644
--- a/yoyodyne/data/datamodules.py
+++ b/yoyodyne/data/datamodules.py
@@ -1,6 +1,6 @@
 """Data modules."""
 
-from typing import Iterable, Optional
+from typing import Iterable, Optional, Set
 
 import lightning
 from torch.utils import data
@@ -116,7 +116,7 @@ def _make_index(
                 features_vocabulary if features_vocabulary else None
             ),
-            target_vocabulary=target_vocabulary if target_vocabulary else None, 
+            target_vocabulary=target_vocabulary if target_vocabulary else None,
             tie_embeddings=tie_embeddings,
         )
@@ -155,15 +155,6 @@ def has_features(self) -> bool:
     def has_target(self) -> bool:
         return self.parser.has_target
 
-    @property
-    def source_vocab_size(self) -> int:
-        if self.separate_features:
-            return self.index.source_vocab_size
-        else:
-            return (
-                self.index.source_vocab_size + self.index.features_vocab_size
-            )
-
     # Required API.
diff --git a/yoyodyne/data/indexes.py b/yoyodyne/data/indexes.py
index 298660d9..3ea5d128 100644
--- a/yoyodyne/data/indexes.py
+++ b/yoyodyne/data/indexes.py
@@ -4,7 +4,6 @@
 
 import itertools
 import pickle
-import os
 
 from typing import Dict, Iterable, List, Optional
diff --git a/yoyodyne/predict.py b/yoyodyne/predict.py
index c5d8dbb7..1ab8ddeb 100644
--- a/yoyodyne/predict.py
+++ b/yoyodyne/predict.py
@@ -125,17 +125,6 @@ def predict(
                     print(parser.target_string(symbols), file=sink)
 
 
-def _mkdir(output: str) -> None:
-    """Creates directory for output file if necessary.
-
-    Args:
-        output (str): output to output file.
-    """
-    dirname = os.path.dirname(output)
-    if dirname:
-        os.makedirs(dirname, exist_ok=True)
-
-
 def add_argparse_args(parser: argparse.ArgumentParser) -> None:

From ddb6989bce6102d6ee622e7d6ba19eacd17f9ec4 Mon Sep 17 00:00:00 2001
From: Kyle Gorman
Date: Mon, 2 Dec 2024 11:15:00 -0500
Subject: [PATCH 12/16] Updates version number further
---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 997df497..a3339f7f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,7 +11,7 @@ exclude = ["examples*"]
 
 [project]
 name = "yoyodyne"
-version = "0.2.16"
+version = "0.2.17"
 description = "Small-vocabulary neural sequence-to-sequence models"
 readme = "README.md"
 requires-python = ">= 3.9"

From f00d3f9aa603d61c9b1db506748141a101395155 Mon Sep 17 00:00:00 2001
From: Kyle Gorman
Date: Mon, 2 Dec 2024 11:23:44 -0500
Subject: [PATCH 13/16] Remove redundant instance variable.
---
 yoyodyne/data/datamodules.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/yoyodyne/data/datamodules.py b/yoyodyne/data/datamodules.py
index c911f22d..efe0c91c 100644
--- a/yoyodyne/data/datamodules.py
+++ b/yoyodyne/data/datamodules.py
@@ -68,7 +68,6 @@ def __init__(
             target_sep=target_sep,
             tie_embeddings=tie_embeddings,
         )
-        self.separate_features = separate_features
         self.batch_size = batch_size

From 7b7b32af3c83d1583c85f5df791f9d77525261bc Mon Sep 17 00:00:00 2001
From: Kyle Gorman
Date: Mon, 2 Dec 2024 11:32:44 -0500
Subject: [PATCH 14/16] Docs cleanup
---
 yoyodyne/data/datamodules.py | 32 ++++++++++++++++++++++++++------
 1 file changed, 26 insertions(+), 6 deletions(-)

diff --git a/yoyodyne/data/datamodules.py b/yoyodyne/data/datamodules.py
index efe0c91c..340275cb 100644
--- a/yoyodyne/data/datamodules.py
+++ b/yoyodyne/data/datamodules.py
@@ -12,13 +12,32 @@
 class DataModule(lightning.LightningDataModule):
     """Data module.
 
-    This class is initialized by the LightningCLI interface. It manages all
-    data loading steps.
-
     Args:
         model_dir: Path for checkpoints, indexes, and logs.
-        predict
-
+        train: Path for training data TSV.
+        val: Path for validation data TSV.
+        predict: Path for prediction data TSV.
+        test: Path for test data TSV.
+        source_col: 1-indexed column in TSV containing source strings.
+        features_col: 1-indexed column in TSV containing features strings.
+        target_col: 1-indexed column in TSV containing target strings.
+        source_sep: String used to split source string into symbols; an empty
+            string indicates that each Unicode codepoint is its own symbol.
+        features_sep: String used to split features string into symbols; an
+            empty string indicates that each Unicode codepoint is its own symbol.
+        target_sep: String used to split target string into symbols; an empty
+            string indicates that each Unicode codepoint is its own symbol.
+        separate_features: Whether or not a separate encoder should be used
+            for features.
+        tie_embeddings: Whether or not source and target embeddings are tied. If
+            not, then source symbols are wrapped in {...}.
+        batch_size: Desired batch size.
+        max_source_length: The maximum length of a source string; this includes
+            concatenated feature strings if not using separate features. An
+            error will be raised if any source exceeds this limit.
+        max_target_length: The maximum length of a target string. A warning
+            will be raised and the target strings will be truncated if any
+            target exceeds this limit.
     """
@@ -106,6 +125,7 @@ def __init__(
     def _make_index(
         self, model_dir: str, tie_embeddings: bool
     ) -> indexes.Index:
+        """Creates the index from a training set."""
         source_vocabulary: Set[str] = set()
         features_vocabulary: Set[str] = set()
         target_vocabulary: Set[str] = set()
@@ -125,7 +145,7 @@ def _make_index(
 
     @staticmethod
     def pprint(vocabulary: Iterable) -> str:
-        """Prints the vocabulary for debugging adn logging purposes."""
+        """Prints the vocabulary for debugging and logging purposes."""
         return ", ".join(f"{symbol!r}" for symbol in vocabulary)

From 3543096ce4d2a22a4723f6ef9c4383ec551f1758 Mon Sep 17 00:00:00 2001
From: Kyle Gorman
Date: Mon, 2 Dec 2024 11:37:28 -0500
Subject: [PATCH 15/16] wrap
---
 yoyodyne/data/datamodules.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/yoyodyne/data/datamodules.py b/yoyodyne/data/datamodules.py
index 340275cb..55475d7c 100644
--- a/yoyodyne/data/datamodules.py
+++ b/yoyodyne/data/datamodules.py
@@ -12,6 +12,9 @@
 class DataModule(lightning.LightningDataModule):
     """Data module.
 
+    This is responsible for indexing the data, collating/padding, and
+    generating datasets.
+
     Args:
         model_dir: Path for checkpoints, indexes, and logs.
         train: Path for training data TSV.
@@ -24,13 +27,14 @@ class DataModule(lightning.LightningDataModule):
         source_sep: String used to split source string into symbols; an empty
             string indicates that each Unicode codepoint is its own symbol.
         features_sep: String used to split features string into symbols; an
-            empty string indicates that each Unicode codepoint is its own symbol.
+            empty string indicates that each Unicode codepoint is its own
+            symbol.
         target_sep: String used to split target string into symbols; an empty
             string indicates that each Unicode codepoint is its own symbol.
         separate_features: Whether or not a separate encoder should be used
             for features.
-        tie_embeddings: Whether or not source and target embeddings are tied. If
-            not, then source symbols are wrapped in {...}.
+        tie_embeddings: Whether or not source and target embeddings are tied.
+            If not, then source symbols are wrapped in {...}.
         batch_size: Desired batch size.
         max_source_length: The maximum length of a source string; this includes
             concatenated feature strings if not using separate features. An
An From 8bc012714949a338267fcb780f41473fe2f1b0dd Mon Sep 17 00:00:00 2001 From: Kyle Gorman Date: Mon, 2 Dec 2024 11:43:44 -0500 Subject: [PATCH 16/16] typos --- yoyodyne/models/expert.py | 6 +++--- yoyodyne/predict.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/yoyodyne/models/expert.py b/yoyodyne/models/expert.py index d7a7bab8..5dc1e2cf 100644 --- a/yoyodyne/models/expert.py +++ b/yoyodyne/models/expert.py @@ -461,10 +461,10 @@ def _generate_data( dataset: data.Dataset, index: data.Index, ) -> Iterator[Tuple[List[int], List[int]]]: - """Helper function to manage data encoding for SED." + """Helper function to manage data encoding for SED. - We want encodings without BOS or EOS tokens. This encodes only raw - source-target text for the Maxwell library. + We want encodings without padding. This encodes only raw source-target + text for the Maxwell library. Args: dataset (data.Dataset): dataset for generating expert vocabulary. diff --git a/yoyodyne/predict.py b/yoyodyne/predict.py index 41712b09..b8c38734 100644 --- a/yoyodyne/predict.py +++ b/yoyodyne/predict.py @@ -103,7 +103,7 @@ def predict( # Beam search. tsv_writer = csv.writer(sink, delimiter="\t") for predictions, scores in trainer.predict(model, loader): - predictions = util.pad_tensor_after_eos(predictions) + predictions = util.pad_tensor_after_end(predictions) # TODO: beam search requires singleton batches and this # assumes that. Revise if that restriction is ever lifted. targets = [