From 46e65a6d0daee95814e4b75deec7d19f6887aeda Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Tue, 29 Oct 2024 11:45:50 -0700 Subject: [PATCH] minor fixes --- .../src/dolma_classifiers/inference.py | 16 ++------ classifiers/src/dolma_classifiers/models.py | 12 +++--- classifiers/src/dolma_classifiers/train.py | 38 +++++++++++-------- classifiers/src/dolma_classifiers/utils.py | 3 ++ 4 files changed, 36 insertions(+), 33 deletions(-) diff --git a/classifiers/src/dolma_classifiers/inference.py b/classifiers/src/dolma_classifiers/inference.py index 1ca59a26..79ce7b93 100644 --- a/classifiers/src/dolma_classifiers/inference.py +++ b/classifiers/src/dolma_classifiers/inference.py @@ -1,5 +1,4 @@ import argparse -import multiprocessing as mp import time from collections import defaultdict from functools import partial @@ -7,6 +6,7 @@ from multiprocessing import Event, Process from queue import Empty from queue import Queue as QueueType +from threading import Event as EventType from typing import Any, Generator, NamedTuple from urllib.parse import urlparse @@ -25,7 +25,7 @@ from transformers import BatchEncoding, PreTrainedTokenizer from .loggers import ProgressLogger, WandbLogger, get_logger -from .models import Prediction, Registry +from .models import Registry from .utils import cleanup, get_local_gpu_rank, sanitize_model_name, setup @@ -99,13 +99,11 @@ def __iter__(self) -> Generator[Batch, None, None]: self.output_paths_queue.put(OutputPath(source=path, count=count)) - def collate_batch(batch: list[Batch], pad_token_id: int) -> Batch: - max_lengths = [len(b.encoding['input_ids'][0]) for b in batch] # pyright: ignore padded_encodings = { key: pad_sequence( # assuming first dimension is batch size - [b.encoding[key][-1,:] for b in batch], # pyright: ignore + [b.encoding[key][-1, :] for b in batch], # pyright: ignore batch_first=True, padding_value=pad_token_id, ) @@ -119,14 +117,13 @@ def collate_batch(batch: list[Batch], pad_token_id: int) -> Batch: ) - class AttributeRow(NamedTuple): sources: list[str] attributes: list[dict[str, Any]] def writer_worker( - error_event: Event, + error_event: EventType, scores_queue: QueueType[AttributeRow | None], output_paths_queue: QueueType[OutputPath], source_destination_mapping: dict[str, str], @@ -218,8 +215,6 @@ def process_documents( suffix: str | None = None ): """Processes a batch of files using distributed processing.""" - console_logger = get_logger("process_documents") - classifier = Registry.get( model_name=model_name, @@ -232,9 +227,6 @@ def process_documents( # to check if destination path exists (file already processed) fs = fsspec.get_filesystem_class(urlparse(source_paths[0]).scheme)() - # this encoder will be used to write the attributes to the destination file - encoder = msgspec.json.Encoder() - source_destination_mapping = { source_path: destination_path for source_path, destination_path in zip(source_paths, destination_paths) diff --git a/classifiers/src/dolma_classifiers/models.py b/classifiers/src/dolma_classifiers/models.py index 4d8c5ac1..ef1a0dd1 100644 --- a/classifiers/src/dolma_classifiers/models.py +++ b/classifiers/src/dolma_classifiers/models.py @@ -1,4 +1,3 @@ -from functools import partial from typing import NamedTuple, Type import torch @@ -16,7 +15,7 @@ from transformers.modeling_outputs import SequenceClassifierOutput from .loggers import get_logger -from .utils import get_local_gpu_rank, sanitize_model_name +from .utils import sanitize_model_name class Prediction(NamedTuple): @@ -43,12 +42,14 @@ def __init__( compile=compile, trust_remote_code=trust_remote_code, ) - self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.tokenizer = AutoTokenizer.from_pretrained(model_name) # pyright: ignore if len(self.model.config.id2label) > 1: - label_name_fn = lambda label: f"{sanitize_model_name(model_name)}_{sanitize_model_name(label)}" + def label_name_fn(label: str): + return f"{sanitize_model_name(model_name)}_{sanitize_model_name(label)}" else: - label_name_fn = lambda label: sanitize_model_name(model_name) + def label_name_fn(label: str): + return sanitize_model_name(model_name) self.labels_map = { id_: label_name_fn(label) @@ -137,7 +138,6 @@ def forward(self, input_ids, attention_mask, **kwargs): return SequenceClassifierOutput(logits=outputs[:, 0, :]) - @Registry.add("nvidia/quality-classifier-deberta") class DebertaQualityClassifier(BaseQualityClassifier): def _make_model( diff --git a/classifiers/src/dolma_classifiers/train.py b/classifiers/src/dolma_classifiers/train.py index 64e45483..23747b34 100644 --- a/classifiers/src/dolma_classifiers/train.py +++ b/classifiers/src/dolma_classifiers/train.py @@ -1,12 +1,12 @@ import multiprocessing from dataclasses import dataclass -from typing import Callable, Generator, NamedTuple +from functools import partial +from typing import Callable from urllib.parse import urlparse import fsspec import jq import smart_open -import torch from msgspec.json import Decoder from torch.utils.data import Dataset from tqdm import tqdm @@ -18,15 +18,18 @@ class Document: label: str -def read_file(path: str, label: str | None = None, selector: str | None = None) -> list[Document]: +def _label_selector_fn(row: dict, selector: Callable | None, label: str | None) -> str: if selector is not None: - compiled_selector = jq.compile(selector) - label_fn = lambda row: str(compiled_selector.input(row).first()) + return str(selector(row).first()) elif label is not None: - label_fn = lambda row: str(label) + return str(label) else: raise ValueError("Either `label` or `selector` must be provided") + +def read_file(path: str, label: str | None = None, selector: str | None = None) -> list[Document]: + label_fn = partial(_label_selector_fn, label=label, selector=(jq.compile(selector) if selector else None)) + decoder = Decoder() documents = [] @@ -45,10 +48,12 @@ class DataConfig: label: str | None = None selector: str | None = None - def expand(self, fs: fsspec.AbstractFileSystem | None = None) -> list["DataConfig"]: - fs = fs or fsspec.get_filesystem_class(urlparse(self.path).scheme)() - paths = [str(p) for p in fs.glob(self.path)] if "*" in self.path else [self.path] - return [DataConfig(path=path, label=self.label, selector=self.selector) for path in paths] + @staticmethod + def expand(data_config: "DataConfig", fs: fsspec.AbstractFileSystem | None = None) -> list["DataConfig"]: + fs = fs or fsspec.get_filesystem_class(urlparse(data_config.path).scheme)() + assert fs is not None, f"Could not determine filesystem for {data_config.path}" + paths = [str(p) for p in fs.glob(data_config.path)] if "*" in data_config.path else [data_config.path] + return [DataConfig(path=path, label=data_config.label, selector=data_config.selector) for path in paths] class ClassifierDataset(Dataset): @@ -58,19 +63,22 @@ def __init__( workers: int = 1, ): with multiprocessing.Pool(workers) as pool: - expanded_configs = list( - tqdm( - pool.imap_unordered(lambda c: c.expand(), configs), + expanded_configs: list[DataConfig] = [ + data_config + for data_configs in tqdm( + pool.imap_unordered(DataConfig.expand, configs), total=len(configs), desc="Expanding configs", ) - ) + for data_config in data_configs + ] with multiprocessing.Pool(workers) as pool: self.documents = list( tqdm( pool.imap_unordered( - lambda c: read_file(path=c.path, label=c.label, selector=c.selector), expanded_configs + lambda c: read_file(path=c.path, label=c.label, selector=c.selector), + expanded_configs ), total=len(expanded_configs), desc="Reading files", diff --git a/classifiers/src/dolma_classifiers/utils.py b/classifiers/src/dolma_classifiers/utils.py index 56cbc303..b95f1c0b 100644 --- a/classifiers/src/dolma_classifiers/utils.py +++ b/classifiers/src/dolma_classifiers/utils.py @@ -31,6 +31,9 @@ def get_local_gpu_rank() -> int: def setup() -> tuple[int, int]: if (rank := os.environ.get("RANK")) and (world_size := os.environ.get("WORLD_SIZE")): dist.init_process_group("nccl", rank=int(rank), world_size=int(world_size)) + + os.environ["CUDA_VISIBLE_DEVICES"] = str(get_local_gpu_rank()) + return get_rank_and_world_size()