Fixed ignore_existing flag not working as expected. #224

Open · wants to merge 7 commits into base: main
Changes from 6 commits
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default.

4 changes: 3 additions & 1 deletion pyproject.toml
@@ -112,6 +112,8 @@ dev = [
"isort>=5.10.1",
"mypy>=0.971",
"pytest>=5.2",
"types-PyYAML",
"types-dateparser"
]
# extension to process code
code = ["detect-secrets==1.4.0", "beautifulsoup4>=4", "pygments", "regex"]
@@ -227,7 +229,6 @@ aggressive = 3
[tool.mypy]
python_version = "3.9"
ignore_missing_imports = true
no_site_packages = true
allow_redefinition = false
warn_unused_configs = true
warn_unused_ignores = true
@@ -238,5 +239,6 @@ show_error_codes = true
pretty = true
plugins = ["numpy.typing.mypy_plugin"]


[tool.mypy-tests]
strict_optional = false
4 changes: 2 additions & 2 deletions python/dolma/cli/__init__.py
@@ -144,7 +144,7 @@ def namespace_to_nested_omegaconf(args: Namespace, structured: Type[T], config:

untyped_config: DictConfig = om.merge(
om.create(config or {}), om.create(nested_config_dict)
) # pyright: ignore (pylance is confused because om.create might return a DictConfig or a ListConfig)
) # type: ignore # (pylance is confused because om.create might return a DictConfig or a ListConfig)

base_structured_config: DictConfig = om.structured(structured)
merged_config = om.merge(base_structured_config, untyped_config)
@@ -159,7 +159,7 @@ def namespace_to_nested_omegaconf(args: Namespace, structured: Type[T], config:
except OmegaConfBaseException as ex:
raise DolmaConfigError(f"Invalid error while parsing key `{ex.full_key}`: {type(ex).__name__}") from ex

return merged_config # pyright: ignore
return merged_config # type: ignore # (pylance because same error as above)


def print_config(config: Any, console: Optional[Console] = None) -> None:
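The `# type: ignore` comments above exist because omegaconf types `om.create` and `om.merge` as returning either a `DictConfig` or a `ListConfig`. An alternative that records the assumption explicitly is a `typing.cast`; a minimal sketch with hypothetical values (not what the PR does):

```python
from typing import cast

from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om

# om.create / om.merge are typed as DictConfig | ListConfig, so assigning the result
# to a DictConfig-annotated name upsets the type checker. Casting records the
# assumption that a mapping input always yields a DictConfig.
nested_config_dict = {"processes": 4, "destination": "/tmp/out"}  # hypothetical values
untyped_config = cast(DictConfig, om.merge(om.create({}), om.create(nested_config_dict)))
assert isinstance(untyped_config, DictConfig) and not isinstance(untyped_config, ListConfig)
```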
6 changes: 4 additions & 2 deletions python/dolma/core/data_types.py
@@ -75,7 +75,7 @@ def __init__(self, *args, metadata: Optional[Dict[str, Any]] = None, **kwargs) -
self.metadata = metadata or {}

@classmethod
def from_spec(cls, spec: InputSpecWithMetadata) -> "DocumentWithMetadata":
def from_spec(cls, spec: InputSpecWithMetadata) -> "DocumentWithMetadata": # type: ignore[override]
return DocumentWithMetadata(
source=spec.source,
version=spec.version,
@@ -125,7 +125,9 @@ def __init__(
self.attributes = attributes or {}

@classmethod
def from_spec(cls, spec: InputSpecWithMetadataAndAttributes) -> "DocumentWithMetadataAndAttributes":
def from_spec( # type: ignore[override]
cls, spec: InputSpecWithMetadataAndAttributes
) -> "DocumentWithMetadataAndAttributes":
return DocumentWithMetadataAndAttributes(
source=spec.source,
version=spec.version,
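The `# type: ignore[override]` markers are needed because both `from_spec` overrides narrow the `spec` parameter, which mypy flags as a Liskov violation (parameter types should be contravariant). A self-contained sketch with made-up class names that reproduces the same error:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Spec:
    text: str


@dataclass
class SpecWithMetadata(Spec):
    metadata: Optional[dict] = None


class Doc:
    @classmethod
    def from_spec(cls, spec: Spec) -> "Doc":
        return cls()


class DocWithMetadata(Doc):
    # Narrowing the parameter from Spec to SpecWithMetadata means DocWithMetadata
    # can no longer stand in for Doc everywhere, so mypy reports an [override]
    # error here unless it is explicitly silenced.
    @classmethod
    def from_spec(cls, spec: SpecWithMetadata) -> "DocWithMetadata":  # type: ignore[override]
        return cls()
```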
48 changes: 42 additions & 6 deletions python/dolma/core/runtime.py
@@ -27,8 +27,17 @@
TaggerOutputDictType,
)
from .errors import DolmaFatalError, DolmaRetryableFailure, DolmaShardError
from .loggers import get_logger
from .parallel import BaseParallelProcessor, QueueType
from .paths import delete_dir, join_path, make_relative, mkdir_p, split_glob, split_path
from .paths import (
delete_dir,
exists,
join_path,
make_relative,
mkdir_p,
split_glob,
split_path,
)
from .registry import TaggerRegistry
from .utils import import_modules, make_variable_name

@@ -178,10 +187,10 @@ def _make_output_streams(
mkdir_p(parent)

# open a new file and create a new encoder
io = stack.enter_context(smart_open.open(loc.path, **open_kwargs))
io_ = stack.enter_context(smart_open.open(loc.path, **open_kwargs))
encoder = msgspec.json.Encoder()
opened[loc.path] = TaggerOutputIO(
exp=loc.exp, taggers=set(), path=loc.path, io=io, encoder=encoder
exp=loc.exp, taggers=set(), path=loc.path, io=io_, encoder=encoder
)

# keep track of which taggers are writing to this path
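Renaming `io` to `io_` avoids shadowing the standard-library `io` module. The `stack.enter_context(smart_open.open(...))` call is the usual `contextlib.ExitStack` pattern for keeping an arbitrary number of output files open at once; a small sketch with hypothetical paths:

```python
import contextlib
import os

import smart_open

paths = ["out/exp_a.jsonl.gz", "out/exp_b.jsonl.gz"]  # hypothetical output locations
os.makedirs("out", exist_ok=True)

with contextlib.ExitStack() as stack:
    # every stream entered on the stack stays open until the with-block exits,
    # then ExitStack closes them all in reverse order
    streams = {p: stack.enter_context(smart_open.open(p, mode="wt")) for p in paths}
    for stream in streams.values():
        stream.write('{"id": "doc-0"}\n')
```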
@@ -223,7 +232,7 @@ def _write_sample_to_streams(

class TaggerProcessor(BaseParallelProcessor):
@classmethod
def increment_progressbar( # type: ignore
def increment_progressbar( # type: ignore # pylint: disable=arguments-differ
cls,
queue: QueueType, # queue must be the first argument, and it should be a positional-only argument
/,
@@ -245,6 +254,10 @@ def process_single(
**kwargs,
):
"""Lets count run the taggers! We will use the destination path to save each tagger output."""

# get a logger
logger = get_logger(cls.__name__)

# import tagger modules
taggers_modules = kwargs.get("taggers_modules", None)
if taggers_modules is not None:
@@ -264,7 +277,9 @@ def process_single(

# this is the dictionary that will hold the output of each tagger
taggers_paths = _determine_output_paths_for_taggers(
experiment_name=experiment_name, destination=destination_path, taggers=taggers
experiment_name=experiment_name,
destination=destination_path,
taggers=taggers,
)

# skip on failure
@@ -283,6 +298,27 @@ def process_single(
# total number of documents processed
total_docs_cnt = 0

if not kwargs.get("ignore_existing", False):
# we group taggers by their path (this is for cases when two taggers are going to same file)
# and then remove all taggers if any of the paths exists and ignore_existing is True
Contributor:
As written, isn't this block only entered if ignore_existing isn't True? Also I'd consider using 'skip' existing instead of 'ignore' because to me 'ignore' feels a bit ambiguous - are you ignoring that taggers exist for the document, then overwriting them, or are you ignoring the documents that have existing taggers?

Member Author:
thanks! fixed both.

_taggers_by_path: Dict[str, list[str]] = {}
for tagger_name, tagger_location in taggers_paths.items():
_taggers_by_path.setdefault(tagger_location.path, []).append(tagger_name)

# actually take care of removal here
for tagger_path, tagger_names in _taggers_by_path.items():
if exists(tagger_path):
for tagger_name in tagger_names:
logger.info("Skipping %s because %s already exists.", tagger_name, tagger_path)
taggers.pop(tagger_name)
taggers_paths.pop(tagger_name)

if not taggers:
# if all taggers have been removed, we return early
cls.increment_progressbar(queue, files=1)
logger.info("All taggers for %s have been skipped.", source_path)
return

# creating dedicated decoder speeds up the process
# if any of the taggers require metadata, we use a decoder that can handle it
# otherwise, we use a decoder that does not parse metadata, which is faster
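This hunk is the heart of the PR: tagger outputs are grouped by destination path, and any tagger whose destination already exists is dropped before processing starts. A standalone sketch of that logic, using `os.path.exists` and made-up names instead of dolma's helpers, with the condition written so the check only runs when skipping is requested (the point raised in the review comment above):

```python
import logging
import os
from typing import Dict, List

logger = logging.getLogger("skip_existing_sketch")

# hypothetical mapping of tagger name -> destination file for its attributes
taggers_paths: Dict[str, str] = {
    "gopher_v1": "attributes/exp/gopher/shard-0000.jsonl.gz",
    "pii_v2": "attributes/exp/pii/shard-0000.jsonl.gz",
}
taggers = {name: object() for name in taggers_paths}  # stand-ins for tagger instances

skip_existing = True  # the flag under discussion
if skip_existing:
    # group taggers by destination, since several taggers can share one output file
    by_path: Dict[str, List[str]] = {}
    for name, path in taggers_paths.items():
        by_path.setdefault(path, []).append(name)

    # drop every tagger whose destination already exists
    for path, names in by_path.items():
        if os.path.exists(path):
            for name in names:
                logger.info("Skipping %s because %s already exists.", name, path)
                taggers.pop(name)
                taggers_paths.pop(name)

if not taggers:
    logger.info("All taggers have been skipped; nothing left to do.")
```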
@@ -327,7 +363,7 @@ def process_single(
# double the update interval if the queue is full
update_interval *= 2

except Exception as exp:
except Exception as exp: # pylint: disable=broad-except
# handle any exception that might have occurred
msg = f"Failed to process {source_path} due to {exp.__class__.__name__}: {' '.join(exp.args)}"
if exp.__class__.__name__ == "IncompleteReadError":
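Earlier in this file, `increment_progressbar` is overridden with `queue` as a positional-only first parameter, which is why the `# type: ignore` and the pylint `arguments-differ` suppression are attached to it. A runnable sketch of that override pattern, with a stand-in base class and made-up counter names:

```python
from queue import Queue


class ParallelProcessorBase:
    """Stand-in for dolma's BaseParallelProcessor, only here to make the sketch runnable."""

    @classmethod
    def increment_progressbar(cls, queue, /, **fields: int) -> dict:
        queue.put(fields)
        return fields


class MyProcessor(ParallelProcessorBase):
    @classmethod
    def increment_progressbar(cls, queue, /, files: int = 0, documents: int = 0) -> dict:
        # `/` keeps queue positional-only, matching the base-class contract;
        # every keyword argument after it becomes a counter on the progress bar
        return super().increment_progressbar(queue, files=files, documents=documents)


q: "Queue[dict]" = Queue()
MyProcessor.increment_progressbar(q, files=1, documents=100)
```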
2 changes: 1 addition & 1 deletion python/dolma/core/taggers.py
@@ -62,7 +62,7 @@ class BaseTaggerWithMetadata(BaseTagger):
def predict(self, doc: DocumentWithMetadata) -> DocResult: # type: ignore
raise NotImplementedError

def tag(self, row: InputSpecWithMetadata) -> TaggerOutputDictType:
def tag(self, row: InputSpecWithMetadata) -> TaggerOutputDictType: # type: ignore
"""Internal function that is used by the tagger to get data"""
doc = DocumentWithMetadata.from_spec(row)
doc_result = self.predict(doc)
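For context on this hunk: `tag` builds a `DocumentWithMetadata` from the incoming row and hands it to `predict`, so a subclass only ever sees the typed document. A hypothetical tagger showing the intended usage (the scoring rule and tagger name are invented, and the import locations for `DocResult` and `Span` are assumed):

```python
from dolma.core.data_types import DocResult, DocumentWithMetadata, Span
from dolma.core.taggers import BaseTaggerWithMetadata


class HasTitleTagger(BaseTaggerWithMetadata):
    """Hypothetical tagger: scores 1.0 when the document metadata carries a title."""

    def predict(self, doc: DocumentWithMetadata) -> DocResult:  # type: ignore
        score = 1.0 if doc.metadata.get("title") else 0.0
        span = Span(start=0, end=len(doc.text), type="has_title", score=score)
        return DocResult(doc=doc, spans=[span])
```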
5 changes: 3 additions & 2 deletions python/dolma/core/utils.py
@@ -184,5 +184,6 @@ def _handle_zstd(file_obj, mode):

register_compressor(".zstd", _handle_zstd)
else:
# add zstd compression
add_compression()
# add zstd compression; in case smart_open has zstd support already, this will error out
# with mypy, so we need the type: ignore[unreachable] comment
add_compression() # type: ignore[unreachable]
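For reference, smart_open's compressor hook takes a callback that wraps the raw binary stream. A minimal sketch of what a `.zstd` handler can look like, assuming the `zstandard` package; dolma's actual `_handle_zstd` lives above this hunk and may differ:

```python
import smart_open
import zstandard


def handle_zstd(file_obj, mode):
    # wrap the underlying binary stream with a zstd (de)compressor,
    # mirroring how smart_open handles .gz and .bz2 out of the box
    if "r" in mode:
        return zstandard.ZstdDecompressor().stream_reader(file_obj)
    return zstandard.ZstdCompressor().stream_writer(file_obj)


smart_open.register_compressor(".zstd", handle_zstd)

with smart_open.open("example.jsonl.zstd", "wb") as fout:  # hypothetical path
    fout.write(b'{"id": "doc-0"}\n')
```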
8 changes: 4 additions & 4 deletions python/dolma/taggers/language.py
@@ -17,21 +17,21 @@
from ..core.utils import split_paragraphs

with necessary.necessary("cld3", soft=True) as CLD3_AVAILABLE:
if CLD3_AVAILABLE or TYPE_CHECKING:
if CLD3_AVAILABLE or TYPE_CHECKING: # type: ignore[unreachable]
import cld3 # pyright:ignore pylint:disable=import-error

with necessary.necessary("pycld2", soft=True) as CLD2_AVAILABLE:
if CLD2_AVAILABLE or TYPE_CHECKING:
if CLD2_AVAILABLE or TYPE_CHECKING: # type: ignore[unreachable]
import pycld2 as cld2 # pyright:ignore pylint:disable=import-error


with necessary.necessary("langdetect", soft=True) as LANGDETECT_AVAILABLE:
if LANGDETECT_AVAILABLE or TYPE_CHECKING:
if LANGDETECT_AVAILABLE or TYPE_CHECKING: # type: ignore[unreachable]
from langdetect import PROFILES_DIRECTORY, DetectorFactory, LangDetectException


with necessary.necessary("lingua", soft=True) as LINGUA_AVAILABLE:
if LINGUA_AVAILABLE or TYPE_CHECKING:
if LINGUA_AVAILABLE or TYPE_CHECKING: # type: ignore[unreachable]
from lingua import Language, LanguageDetectorBuilder


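All of these guards follow the same shape: `necessary(..., soft=True)` yields a flag saying whether the optional dependency can be imported, and the `TYPE_CHECKING` branch keeps the import visible to type checkers even when the package is absent, which is exactly why mypy marks the branch unreachable on machines without it. A small sketch of the pattern with a hypothetical optional package:

```python
from typing import TYPE_CHECKING

import necessary

# soft=True means "do not raise if the package is missing"; the context manager
# yields a flag reporting whether the import would succeed
with necessary.necessary("somepackage", soft=True) as SOMEPACKAGE_AVAILABLE:
    if SOMEPACKAGE_AVAILABLE or TYPE_CHECKING:  # type: ignore[unreachable]
        import somepackage  # noqa: F401  # hypothetical optional dependency


def do_work() -> None:
    if not SOMEPACKAGE_AVAILABLE:
        raise RuntimeError("somepackage is required for this feature")
    somepackage.run()  # hypothetical call
```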
20 changes: 14 additions & 6 deletions python/dolma/tokenizer/tokenizer.py
@@ -11,7 +11,14 @@
from os import PathLike
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Generator, List, Optional, Tuple, Union
from typing import ( # type: ignore[unreachable,unused-ignore]
TYPE_CHECKING,
Generator,
List,
Optional,
Tuple,
Union,
)

import msgspec
import numpy as np
@@ -25,8 +32,10 @@
from .data_types import InputSpec, TokenizerOutput

with necessary("transformers", soft=True) as TRANSFORMERS_AVAILABLE:
if TYPE_CHECKING or TRANSFORMERS_AVAILABLE:
from transformers import AutoTokenizer # pylint: disable=import-error
if TYPE_CHECKING or TRANSFORMERS_AVAILABLE: # type: ignore[unreachable,unused-ignore]
from transformers import ( # pyright: ignore # pylint: disable=import-error
AutoTokenizer,
)

PathOrStr = Union[str, PathLike]

@@ -365,7 +374,6 @@ def tokenize_file(
file, each containing a field named `text`.
"""
tokenizer = make_tokenizer(tokenizer_name_or_path, **tokenizer_kwargs)
dtype = deepcopy(tokenizer.dtype)
decoder = msgspec.json.Decoder(InputSpec)
with smart_open.open(path, mode="rt") as input_stream:
for i, line in enumerate(input_stream, start=1):
@@ -376,8 +384,8 @@
tokens = tokenizer.encode(text, add_special_tokens=True)
if refresh_tokenizer_every:
# extra copy to prevent memory leaks
tokens = np.array(tokens, dtype=dtype)
yield TokenizerOutput.from_tokens(id=row.id, src=path, loc=i, tokens=tokens) # pyright: ignore
tokens = deepcopy(tokens)
yield TokenizerOutput.from_tokens(id=row.id, src=path, loc=i, tokens=tokens)

if refresh_tokenizer_every > 0 and i % refresh_tokenizer_every == 0:
# to prevent memory leaks, we refresh the tokenizer every so often
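The change above swaps the dtype-coerced `np.array(tokens, dtype=dtype)` copy for a plain `deepcopy(tokens)` and drops the now-unused `dtype` variable. The surrounding loop is what keeps memory bounded: copy the tokens out so they no longer reference the tokenizer's buffers, then rebuild the tokenizer every `refresh_tokenizer_every` rows. A condensed sketch of that loop; `make_tokenizer` is assumed importable from `dolma.tokenizer.tokenizer`, as the diff suggests:

```python
from copy import deepcopy

from dolma.tokenizer.tokenizer import make_tokenizer  # assumed location


def encode_stream(lines, tokenizer_name_or_path: str, refresh_tokenizer_every: int = 0):
    # sketch of the refresh pattern used in tokenize_file
    tokenizer = make_tokenizer(tokenizer_name_or_path)
    for i, text in enumerate(lines, start=1):
        tokens = tokenizer.encode(text, add_special_tokens=True)
        if refresh_tokenizer_every:
            # extra copy so the yielded tokens do not pin the tokenizer's internal buffers
            tokens = deepcopy(tokens)
        yield tokens
        if refresh_tokenizer_every > 0 and i % refresh_tokenizer_every == 0:
            # rebuild the tokenizer periodically to release accumulated memory
            tokenizer = make_tokenizer(tokenizer_name_or_path)
```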
4 changes: 2 additions & 2 deletions python/dolma/warc/linearizers.py
@@ -8,12 +8,12 @@
from .utils import raise_warc_dependency_error

with necessary("trafilatura", soft=True) as TRAFILATURA_AVAILABLE:
if TRAFILATURA_AVAILABLE or TYPE_CHECKING:
if TRAFILATURA_AVAILABLE or TYPE_CHECKING: # type: ignore[unreachable]
import trafilatura # noqa: F401
import trafilatura.meta # noqa: F401

with necessary("resiliparse", soft=True) as RESILIPARSE_AVAILABLE:
if RESILIPARSE_AVAILABLE or TYPE_CHECKING:
if RESILIPARSE_AVAILABLE or TYPE_CHECKING: # type: ignore[unreachable]
from resiliparse.extract.html2text import extract_plain_text # noqa: F401
from resiliparse.parse.encoding import detect_encoding # noqa: F401
from resiliparse.parse.html import HTMLTree # noqa: F401
18 changes: 10 additions & 8 deletions python/dolma/warc/processor.py
@@ -23,11 +23,11 @@
from .utils import UrlNormalizer, raise_warc_dependency_error

with necessary("fastwarc", soft=True) as FASTWARC_AVAILABLE:
if FASTWARC_AVAILABLE or TYPE_CHECKING:
from fastwarc.warc import ArchiveIterator, WarcRecordType
if FASTWARC_AVAILABLE or TYPE_CHECKING: # type: ignore[unreachable]
from fastwarc.warc import ArchiveIterator, WarcHeaderMap, WarcRecordType

with necessary("dateparser", soft=True) as DATEPARSER_AVAILABLE:
if DATEPARSER_AVAILABLE or TYPE_CHECKING:
if DATEPARSER_AVAILABLE or TYPE_CHECKING: # type: ignore[unreachable]
import dateparser


@@ -164,11 +164,13 @@ def process_single(
if not decoded_content:
continue

# metadata
ctype, *_ = (record.http_headers.get("Content-Type") or "").split(";")
date = cls._parse_warc_timestamp(record.http_headers.get("Date"))
target_uri = record.headers.get("WARC-Target-URI")
payload_id = record.headers.get("WARC-Payload-Digest").split(":")[1].lower()
# collect metadata
# in newer versions of fastwarc, the http_headers could be None if not found
http_headers = record.http_headers or WarcHeaderMap()
ctype, *_ = (http_headers.get("Content-Type") or "").split(";")
date = cls._parse_warc_timestamp(http_headers.get("Date") or "")
target_uri = record.headers.get("WARC-Target-URI") or ""
payload_id = (record.headers.get("WARC-Payload-Digest") or "").split(":")[1].lower()
metadata = dict(
warc_url=target_uri,
url=url_normalizer(target_uri),
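The hunk above guards every header lookup because newer fastwarc releases can hand back `None` for `http_headers`. The same defensive shape, expressed with plain dictionaries so it runs without fastwarc (field names follow the diff, and a missing payload digest is guarded rather than assumed present):

```python
from typing import Dict, Optional


def extract_metadata(http_headers: Optional[Dict[str, str]], warc_headers: Dict[str, str]) -> Dict[str, str]:
    # fall back to an empty mapping when the record carries no HTTP headers at all
    headers = http_headers or {}
    ctype, *_ = (headers.get("Content-Type") or "").split(";")
    date = headers.get("Date") or ""
    target_uri = warc_headers.get("WARC-Target-URI") or ""
    # "sha1:ABC..." becomes "abc..."; the ":" fallback keeps the index access safe
    payload_id = (warc_headers.get("WARC-Payload-Digest") or ":").split(":")[1].lower()
    return {"warc_url": target_uri, "content_type": ctype, "warc_date": date, "warc_payload_id": payload_id}


extract_metadata(None, {"WARC-Target-URI": "http://example.com/"})  # record with no HTTP headers
```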
4 changes: 2 additions & 2 deletions python/dolma/warc/utils.py
@@ -6,11 +6,11 @@
from ..core.errors import DolmaFatalError

with necessary("w3lib", soft=True) as W3LIB_AVAILABLE:
if W3LIB_AVAILABLE or TYPE_CHECKING:
if W3LIB_AVAILABLE or TYPE_CHECKING: # type: ignore[unreachable]
from w3lib.url import canonicalize_url # noqa: F401

with necessary("url_normalize", soft=True) as URL_NORMALIZE_AVAILABLE:
if URL_NORMALIZE_AVAILABLE or TYPE_CHECKING:
if URL_NORMALIZE_AVAILABLE or TYPE_CHECKING: # type: ignore[unreachable]
from url_normalize import url_normalize # noqa: F401


2 changes: 1 addition & 1 deletion scripts/sample_prefix.py
@@ -5,7 +5,7 @@
import necessary

with necessary.necessary("click") as CLICK_AVAILABLE:
if CLICK_AVAILABLE or TYPE_CHECKING:
if CLICK_AVAILABLE or TYPE_CHECKING: # type: ignore[unreachable]
import click

