Skip to content

Commit

Permalink
renamed; reformatted
Browse files Browse the repository at this point in the history
  • Loading branch information
soldni committed Nov 24, 2023
1 parent 825d0cb commit f50ce07
Show file tree
Hide file tree
Showing 8 changed files with 220 additions and 102 deletions.
9 changes: 6 additions & 3 deletions python/dolma/core/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,12 @@ def process_single(
# too full
update_interval = 1

# running document count; gets reset every time we update the progress
# bar
# running document count; gets reset every time we update the progress bar
docs_cnt = 0

# total number of documents processed
total_docs_cnt = 0

# creating dedicated decoder speeds up the process
# if any of the taggers require metadata, we use a decoder that can handle it
# otherwise, we use a decoder that does not parse metadata, which is faster
Expand Down Expand Up @@ -291,8 +293,9 @@ def process_single(

# increment the number of documents processed so far
docs_cnt += 1
total_docs_cnt += 1

if steps is not None and docs_cnt >= steps:
if steps is not None and total_docs_cnt >= steps:
# if we have reached the maximum number of steps, we break
break

Expand Down
2 changes: 1 addition & 1 deletion python/dolma/taggers/code/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .taggers import (
from .code_taggers import (
CodeCopyrightTagger,
CodeRedPajamaTaggers,
CodeSecretsTagger,
Expand Down
File renamed without changes.
11 changes: 6 additions & 5 deletions python/dolma/taggers/repetitions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .taggers import ParagraphRepetitionsTagger, RepetitionsTagger
from .repetitions_taggers import (
ParagraphRepetitionsTagger,
RepetitionsTagger,
TokenizerRepetitionsTagger,
)

__all__ = [
"RepetitionsTagger",
"ParagraphRepetitionsTagger",
]
__all__ = ["RepetitionsTagger", "ParagraphRepetitionsTagger", "TokenizerRepetitionsTagger"]
136 changes: 136 additions & 0 deletions python/dolma/taggers/repetitions/repetitions_taggers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""
Taggers to detect repetitions in the text.
@soldni
"""

import re
from abc import abstractmethod
from typing import Generator, List

import numpy as np
from tokenizers import Tokenizer

from ...core.data_types import DocResult, Document, Span
from ...core.registry import TaggerRegistry
from ...core.taggers import BaseTagger
from ...core.utils import split_paragraphs
from .utils import find_periodic_sequences


class BaseRepetitionsTagger(BaseTagger):
@abstractmethod
def _extract_from_text(self, text: str) -> Generator[Span, None, None]:
raise NotImplementedError()

def _extract_from_doc(self, doc: Document) -> Generator[Span, None, None]:
yield from self._extract_from_text(doc.text)

def _compute_document_stats(self, spans: List[Span], doc: Document) -> List[Span]:
doc_max_span = Span(
start=0,
end=len(doc.text),
type="doc_max_repetition",
score=max(spans, key=lambda s: s.score).score if spans else 0.0,
)
doc_mean_reps_span = Span(
start=0,
end=len(doc.text),
type="doc_mean_repetition",
score=float(np.mean([s.score for s in spans]) if spans else 0),
)
doc_frac_reps_span = Span(
start=0,
end=len(doc.text),
type="doc_frac_repetition",
score=float(sum([s.score for s in spans]) / len(doc.text) if spans else 0),
)
return [doc_max_span, doc_mean_reps_span, doc_frac_reps_span]

def predict(self, doc: Document) -> DocResult:
"""Predict method for the tagger."""
reps_spans = list(self._extract_from_doc(doc))
document_stats_spans = self._compute_document_stats(spans=reps_spans, doc=doc)
return DocResult(doc=doc, spans=reps_spans + document_stats_spans)


@TaggerRegistry.add("repetitions_v1")
class RepetitionsTagger(BaseRepetitionsTagger):
"""Tagger to detect repetitions of of groups of characters.
Only repetitions that occur at least 4 times are detected."""

def __init__(self) -> None:
self.re_char_repetitions = re.compile(r"(.+?)(\s?\1){3,}")
super().__init__()

def _extract_from_text(self, text: str) -> Generator[Span, None, None]:
"""Extract repetitions of characters in the text."""
for match in self.re_char_repetitions.finditer(text):
yield Span(
start=(start := match.start()),
end=(end := match.end()),
type="repetition",
score=float(end - start),
)


@TaggerRegistry.add("paragraph_repetitions_v1")
class ParagraphRepetitionsTagger(RepetitionsTagger):
"""Tagger to detect repetitions of paragraphs.
It's faster than the char repetition tagger, but it does not account for
repetitions of characters that span multiple paragraphs."""

def _extract_from_doc(self, doc: Document) -> Generator[Span, None, None]:
offset = 0
for paragraph in split_paragraphs(doc.text, remove_empty=False):
for span in self._extract_from_text(paragraph.text):
span.start += offset
span.end += offset
yield span
offset += len(paragraph.text)


@TaggerRegistry.add("tokenizer_repetitions_v1")
class TokenizerRepetitionsTagger(BaseRepetitionsTagger):
"""Tagger to detect repetitions of tokens.
It uses a tokenizer to split the text into tokens, and then identifies
sequences of tokens that repeat at least 3 times."""

TOKENIZER_IDENTIFIER = "allenai/eleuther-ai-gpt-neox-20b-pii-special"
MIN_PERIOD = 1
MAX_PERIOD = 13

def __init__(self) -> None:
self.tokenizer = Tokenizer.from_pretrained(self.TOKENIZER_IDENTIFIER)

def _extract_from_text(self, text: str) -> Generator[Span, None, None]:
tokens = self.tokenizer.encode(text, add_special_tokens=False)
sequences_iter = find_periodic_sequences(
arr=np.array(tokens.ids), min_period=self.MIN_PERIOD, max_period=self.MAX_PERIOD
)
for seq in sequences_iter:
yield Span(
start=(s := tokens.offsets[seq.start][0]),
end=(e := tokens.offsets[seq.end - 1][1]),
type="repetition",
score=float(e - s),
)


@TaggerRegistry.add("paragraph_tokenizer_repetitions_v1")
class ParagraphTokenizerRepetitionsTagger(TokenizerRepetitionsTagger):
"""Tagger to detect repetitions of tokens in paragraphs.
It's faster than the tokenizer repetition tagger, but it does not account for
repetitions of tokens that span multiple paragraphs."""

def _extract_from_doc(self, doc: Document) -> Generator[Span, None, None]:
offset = 0
for paragraph in split_paragraphs(doc.text, remove_empty=False):
# space is required to avoid first symbol in the paragraph to be
# tokenized as a different token.
for span in self._extract_from_text(" " + paragraph.text):
span.start += offset - 1
span.end += offset - 1
yield span
offset += len(paragraph.text)
77 changes: 0 additions & 77 deletions python/dolma/taggers/repetitions/taggers.py

This file was deleted.

17 changes: 14 additions & 3 deletions python/dolma/taggers/repetitions/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Generator, List, NamedTuple, Tuple
from typing import Generator, List, NamedTuple

import numpy as np

Expand Down Expand Up @@ -29,14 +29,16 @@ def group_consecutive_values(arr: np.ndarray, stepsize: int = 1) -> List[np.ndar


class RepetitionTuple(NamedTuple):
"""Tuple to store information about a periodic sequence."""

start: int
end: int
period: int
times: int


def find_periodic_sequences(
arr: np.ndarray, max_period: int, min_period: int = 1
arr: np.ndarray, max_period: int, min_period: int = 1, mask_value: int = -1
) -> Generator[RepetitionTuple, None, None]:
"""Function to find periodic sequences in an array.
Expand All @@ -52,15 +54,24 @@ def find_periodic_sequences(
end at the end of each row), we check the end of the previous row and the
start of the next row to determine the actual start and end positions of the
sequence.
Args:
arr (np.ndarray): The array to search for periodic sequences.
max_period (int): The maximum period to check for.
min_period (int, optional): The minimum period to check for. Defaults to 1.
mask_value (int, optional): The value to use to pad the array. Defaults to -1.
"""
# make sure the mask_value is not in the array
if (arr == mask_value).sum() > 0:
raise ValueError("`mask_value` is in the array")

# no since we can only detect sequences that repeat at least 3 times,
# there is no point in checking for periods greater than 1/3 of the length
max_period = min(max_period, len(arr) // 3)

for period in range(min_period, max_period + 1):
# pad the array so that it can be reshaped into a matrix matching the period
padded_arr = np.pad(arr, (0, period - (len(arr) % period)), constant_values=-1)
padded_arr = np.pad(arr, (0, period - (len(arr) % period)), constant_values=mask_value)
shaped_arr = padded_arr.reshape(-1, period)

# find rows that are equal to the previous row; these are the possibly-periodic sequences
Expand Down
Loading

0 comments on commit f50ce07

Please sign in to comment.