
Commit

Fixed ignore not working
soldni authored Jan 1, 2025
1 parent a824220 commit 229d5a2
Showing 1 changed file with 42 additions and 6 deletions.
48 changes: 42 additions & 6 deletions python/dolma/core/runtime.py
@@ -27,8 +27,17 @@
     TaggerOutputDictType,
 )
 from .errors import DolmaFatalError, DolmaRetryableFailure, DolmaShardError
+from .loggers import get_logger
 from .parallel import BaseParallelProcessor, QueueType
-from .paths import delete_dir, join_path, make_relative, mkdir_p, split_glob, split_path
+from .paths import (
+    delete_dir,
+    exists,
+    join_path,
+    make_relative,
+    mkdir_p,
+    split_glob,
+    split_path,
+)
 from .registry import TaggerRegistry
 from .utils import import_modules, make_variable_name
 
@@ -178,10 +187,10 @@ def _make_output_streams(
                 mkdir_p(parent)
 
                 # open a new file and create a new encoder
-                io = stack.enter_context(smart_open.open(loc.path, **open_kwargs))
+                io_ = stack.enter_context(smart_open.open(loc.path, **open_kwargs))
                 encoder = msgspec.json.Encoder()
                 opened[loc.path] = TaggerOutputIO(
-                    exp=loc.exp, taggers=set(), path=loc.path, io=io, encoder=encoder
+                    exp=loc.exp, taggers=set(), path=loc.path, io=io_, encoder=encoder
                 )
 
                 # keep track of which taggers are writing to this path
@@ -223,7 +232,7 @@ def _write_sample_to_streams(
 
 class TaggerProcessor(BaseParallelProcessor):
     @classmethod
-    def increment_progressbar(  # type: ignore
+    def increment_progressbar(  # type: ignore # pylint: disable=arguments-differ
         cls,
         queue: QueueType,  # queue must be the first argument, and it should be a positional-only argument
         /,
@@ -245,6 +254,10 @@ def process_single(
         **kwargs,
     ):
         """Let's run the taggers! We will use the destination path to save each tagger output."""
+
+        # get a logger
+        logger = get_logger(cls.__name__)
+
         # import tagger modules
         taggers_modules = kwargs.get("taggers_modules", None)
         if taggers_modules is not None:
@@ -264,7 +277,9 @@ def process_single(
 
         # this is the dictionary that will hold the output of each tagger
         taggers_paths = _determine_output_paths_for_taggers(
-            experiment_name=experiment_name, destination=destination_path, taggers=taggers
+            experiment_name=experiment_name,
+            destination=destination_path,
+            taggers=taggers,
         )
 
         # skip on failure
@@ -283,6 +298,27 @@ def process_single(
         # total number of documents processed
         total_docs_cnt = 0
 
+        if not kwargs.get("ignore_existing", False):
+            # we group taggers by their path (this is for cases when two taggers write to the same file)
+            # and then remove any tagger whose output path already exists
+            _taggers_by_path: Dict[str, list[str]] = {}
+            for tagger_name, tagger_path in taggers_paths.items():
+                _taggers_by_path.setdefault(tagger_path.path, []).append(tagger_name)
+
+            # actually take care of the removal here
+            for tagger_path, tagger_names in _taggers_by_path.items():
+                if exists(tagger_path):
+                    for tagger_name in tagger_names:
+                        logger.info("Skipping %s because %s already exists.", tagger_name, tagger_path)
+                        taggers.pop(tagger_name)
+                        taggers_paths.pop(tagger_name)
+
+        if not taggers:
+            # if all taggers have been removed, we return early
+            cls.increment_progressbar(queue, files=1)
+            logger.info("All taggers for %s have been skipped.", source_path)
+            return
+
         # creating a dedicated decoder speeds up the process
         # if any of the taggers require metadata, we use a decoder that can handle it
         # otherwise, we use a decoder that does not parse metadata, which is faster
@@ -327,7 +363,7 @@ def process_single(
                     # double the update interval if the queue is full
                     update_interval *= 2
 
-        except Exception as exp:
+        except Exception as exp:  # pylint: disable=broad-except
            # handle any exception that might have occurred
             msg = f"Failed to process {source_path} due to {exp.__class__.__name__}: {' '.join(exp.args)}"
             if exp.__class__.__name__ == "IncompleteReadError":
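For context, the core of this fix is the early-skip logic added to process_single: taggers are grouped by their output path, and any tagger whose output file already exists is dropped before any documents are read, unless ignore_existing is set. Below is a minimal standalone sketch of that behavior; Location, drop_finished_taggers, and the local exists() helper are hypothetical stand-ins for dolma's TaggerOutputLocation, the inline logic in process_single, and dolma.core.paths.exists (which also handles remote paths).

# Minimal sketch of the skip logic introduced by this commit.
# "Location", "drop_finished_taggers", and the local exists() are
# hypothetical stand-ins, not dolma's actual API.
import os
from typing import Dict, List, NamedTuple


class Location(NamedTuple):
    path: str


def exists(path: str) -> bool:
    # stand-in: dolma.core.paths.exists also understands remote schemes like s3://
    return os.path.exists(path)


def drop_finished_taggers(
    taggers: Dict[str, object],
    taggers_paths: Dict[str, Location],
    ignore_existing: bool = False,
) -> None:
    if ignore_existing:
        return

    # group taggers by destination, since two taggers may share one output file
    by_path: Dict[str, List[str]] = {}
    for name, loc in taggers_paths.items():
        by_path.setdefault(loc.path, []).append(name)

    # drop every tagger whose destination file already exists
    for path, names in by_path.items():
        if exists(path):
            for name in names:
                taggers.pop(name)
                taggers_paths.pop(name)

If this empties the taggers dictionary, the caller can return early; that is exactly what process_single now does, after bumping the progress bar by one file so the overall count stays accurate.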
