From 229d5a2afcea474d4ee125293631939f49a6d592 Mon Sep 17 00:00:00 2001
From: Luca Soldaini
Date: Tue, 31 Dec 2024 16:30:32 -0800
Subject: [PATCH] Fixed ignore not working

---
 python/dolma/core/runtime.py | 48 +++++++++++++++++++++++++++++++-----
 1 file changed, 42 insertions(+), 6 deletions(-)

diff --git a/python/dolma/core/runtime.py b/python/dolma/core/runtime.py
index ac5e2a23..d14a4cc0 100644
--- a/python/dolma/core/runtime.py
+++ b/python/dolma/core/runtime.py
@@ -27,8 +27,17 @@
     TaggerOutputDictType,
 )
 from .errors import DolmaFatalError, DolmaRetryableFailure, DolmaShardError
+from .loggers import get_logger
 from .parallel import BaseParallelProcessor, QueueType
-from .paths import delete_dir, join_path, make_relative, mkdir_p, split_glob, split_path
+from .paths import (
+    delete_dir,
+    exists,
+    join_path,
+    make_relative,
+    mkdir_p,
+    split_glob,
+    split_path,
+)
 from .registry import TaggerRegistry
 from .utils import import_modules, make_variable_name
 
@@ -178,10 +187,10 @@ def _make_output_streams(
             mkdir_p(parent)
 
             # open a new file and create a new encoder
-            io = stack.enter_context(smart_open.open(loc.path, **open_kwargs))
+            io_ = stack.enter_context(smart_open.open(loc.path, **open_kwargs))
             encoder = msgspec.json.Encoder()
             opened[loc.path] = TaggerOutputIO(
-                exp=loc.exp, taggers=set(), path=loc.path, io=io, encoder=encoder
+                exp=loc.exp, taggers=set(), path=loc.path, io=io_, encoder=encoder
             )
 
         # keep track of which taggers are writing to this paths
@@ -223,7 +232,7 @@ def _write_sample_to_streams(
 
 class TaggerProcessor(BaseParallelProcessor):
     @classmethod
-    def increment_progressbar(  # type: ignore
+    def increment_progressbar(  # type: ignore  # pylint: disable=arguments-differ
         cls,
         queue: QueueType,  # queue must be the first argument, and it should be a positional-only argument
         /,
@@ -245,6 +254,10 @@ def process_single(
         **kwargs,
     ):
         """Lets count run the taggers! We will use the destination path to save each tagger output."""
+
+        # get a logger
+        logger = get_logger(cls.__name__)
+
         # import tagger modules
         taggers_modules = kwargs.get("taggers_modules", None)
         if taggers_modules is not None:
@@ -264,7 +277,9 @@ def process_single(
 
         # this is the dictionary that will hold the output of each tagger
         taggers_paths = _determine_output_paths_for_taggers(
-            experiment_name=experiment_name, destination=destination_path, taggers=taggers
+            experiment_name=experiment_name,
+            destination=destination_path,
+            taggers=taggers,
         )
 
         # skip on failure
@@ -283,6 +298,27 @@ def process_single(
         # total number of documents processed
         total_docs_cnt = 0
 
+        if not kwargs.get("ignore_existing", False):
+            # group taggers by their output path (two taggers may write to the same file),
+            # then drop every tagger whose output already exists, since ignore_existing is False here
+            _taggers_by_path: Dict[str, list[str]] = {}
+            for tagger_name, tagger_path in taggers_paths.items():
+                _taggers_by_path.setdefault(tagger_path.path, []).append(tagger_name)
+
+            # actually take care of removal here
+            for tagger_path, tagger_names in _taggers_by_path.items():
+                if exists(tagger_path):
+                    for tagger_name in tagger_names:
+                        logger.info("Skipping %s because %s already exists.", tagger_name, tagger_path)
+                        taggers.pop(tagger_name)
+                        taggers_paths.pop(tagger_name)
+
+            if not taggers:
+                # if all taggers have been removed, we return early
+                cls.increment_progressbar(queue, files=1)
+                logger.info("All taggers for %s have been skipped.", source_path)
+                return
+
         # creating dedicated decoder speeds up the process
         # if any of the taggers require metadata, we use a decoder that can handle it
         # otherwise, we use a decoder that does not parse metadata, which is faster
@@ -327,7 +363,7 @@ def process_single(
                             # double the update interval if the queue is full
                             update_interval *= 2
 
-        except Exception as exp:
+        except Exception as exp:  # pylint: disable=broad-except
            # handle any exception that might have occurred
             msg = f"Failed to process {source_path} due to {exp.__class__.__name__}: {' '.join(exp.args)}"
             if exp.__class__.__name__ == "IncompleteReadError":
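
The fix above boils down to a group-by-path / remove-if-exists pass before any
documents are read. The sketch below isolates that pattern so it can be run and
tested on its own; `skip_existing_outputs` and its `exists` callback are
hypothetical stand-ins for illustration, not part of the dolma API.

    from typing import Callable, Dict, List

    def skip_existing_outputs(
        taggers: Dict[str, object],
        output_paths: Dict[str, str],
        ignore_existing: bool,
        exists: Callable[[str], bool],
    ) -> None:
        """Drop taggers whose output already exists, unless ignore_existing is set.

        Minimal sketch of the pattern in the @@ -283,6 +298,27 @@ hunk above.
        """
        if ignore_existing:
            return
        # group tagger names by output path: two taggers may share one file
        by_path: Dict[str, List[str]] = {}
        for name, path in output_paths.items():
            by_path.setdefault(path, []).append(name)
        # remove every tagger bound to a path that already exists
        for path, names in by_path.items():
            if exists(path):
                for name in names:
                    taggers.pop(name)
                    output_paths.pop(name)

    # example: only "new_tagger" survives, since its output is missing
    taggers = {"old_tagger": object(), "new_tagger": object()}
    paths = {"old_tagger": "out/old.jsonl.gz", "new_tagger": "out/new.jsonl.gz"}
    skip_existing_outputs(taggers, paths, ignore_existing=False,
                          exists=lambda p: p == "out/old.jsonl.gz")
    assert list(taggers) == ["new_tagger"]

Keeping the grouping step matters: if two taggers write to one file, popping by
path rather than by tagger name ensures both are skipped together, which is the
behavior the real hunk implements with _taggers_by_path.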