From 58dc815999cf2c6227b77babebbec08dd9f57390 Mon Sep 17 00:00:00 2001
From: JulesBelveze
Date: Sun, 3 Nov 2024 18:05:05 +0100
Subject: [PATCH 1/6] [.github] - devops: update dependency installation
 command in PR workflow

- Change the `uv sync` command to install all extras during PR checks
---
 .github/workflows/pr.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml
index 4180da1..5b96f2c 100644
--- a/.github/workflows/pr.yml
+++ b/.github/workflows/pr.yml
@@ -34,7 +34,7 @@ jobs:
           cache: true # enable caching

       - name: Install dependencies
-        run: uv sync --extra dev
+        run: uv sync --all-extras

       - name: Run tests
         run: uv run task test

From 123dca33685dd7ce806b7ef5f3dc915a1c1a0ab4 Mon Sep 17 00:00:00 2001
From: JulesBelveze
Date: Sun, 3 Nov 2024 18:05:29 +0100
Subject: [PATCH 2/6] [bert_squeeze/data/modules] - feature: enable trusting
 remote dataset code during loading

- Allow the `datasets` library to execute remote code by setting
  `trust_remote_code=True`, improving compatibility with externally hosted
  datasets
---
 bert_squeeze/data/modules/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bert_squeeze/data/modules/base.py b/bert_squeeze/data/modules/base.py
index a833cd3..f074143 100644
--- a/bert_squeeze/data/modules/base.py
+++ b/bert_squeeze/data/modules/base.py
@@ -27,7 +27,7 @@ def prepare_data(self) -> None:
         Returns:
             None
         """
-        self.dataset = datasets.load_dataset(self.dataset_config.path)
+        self.dataset = datasets.load_dataset(self.dataset_config.path, trust_remote_code=True)

         if "percent" not in self.dataset_config:
             return
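A note on patches 2/6 and 3/6: recent releases of the `datasets` library no
longer execute a dataset repository's loading script by default, so
script-backed datasets fail to load unless the caller opts in. A minimal
sketch of the opt-in, using a placeholder dataset name rather than one taken
from the patches:

    import datasets

    # trust_remote_code=True consents to running the loading script that ships
    # with the dataset repository; purely file-based datasets do not need it.
    ds = datasets.load_dataset("some-org/script-backed-dataset", trust_remote_code=True)

Since prepare_data() now passes the flag unconditionally, the configured
dataset paths should all point to trusted sources.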
From a49991b224d95b50d1cb3391d3b2f90c970cb161 Mon Sep 17 00:00:00 2001
From: JulesBelveze
Date: Sat, 9 Nov 2024 08:42:06 +0100
Subject: [PATCH 3/6] [bert_squeeze/data] - refactor: improve readability of
 dataset loading

- Split the dataset-loading call across multiple lines for better readability
- Behaviour is unchanged: remote dataset code is still trusted via
  `trust_remote_code=True`
---
 bert_squeeze/data/modules/base.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/bert_squeeze/data/modules/base.py b/bert_squeeze/data/modules/base.py
index f074143..b8b81c4 100644
--- a/bert_squeeze/data/modules/base.py
+++ b/bert_squeeze/data/modules/base.py
@@ -27,7 +27,9 @@ def prepare_data(self) -> None:
         Returns:
             None
         """
-        self.dataset = datasets.load_dataset(self.dataset_config.path, trust_remote_code=True)
+        self.dataset = datasets.load_dataset(
+            self.dataset_config.path, trust_remote_code=True
+        )

         if "percent" not in self.dataset_config:
             return

From 84f7648863d070044a372be5d2c335de1b0f23e7 Mon Sep 17 00:00:00 2001
From: JulesBelveze
Date: Sat, 9 Nov 2024 09:09:29 +0100
Subject: [PATCH 4/6] [bert_squeeze/data] - refactor: remove ConferenceDataset
 module and references

- Deleted the ConferenceDataset class to streamline local_datasets
- Removed the ConferenceDataset import from `__init__.py` to clean up package
  initialization

[docs] - docs: update data documentation to reflect removed ConferenceDataset

- Removed the ConferenceDataset reference from data.rst to keep the
  documentation accurate
---
 bert_squeeze/data/__init__.py                 |  1 -
 bert_squeeze/data/local_datasets/__init__.py  |  1 -
 .../data/local_datasets/conference_dataset.py | 80 ------------------
 docs/data.rst                                 |  4 -
 4 files changed, 86 deletions(-)
 delete mode 100644 bert_squeeze/data/local_datasets/conference_dataset.py

diff --git a/bert_squeeze/data/__init__.py b/bert_squeeze/data/__init__.py
index de50778..2e8995b 100644
--- a/bert_squeeze/data/__init__.py
+++ b/bert_squeeze/data/__init__.py
@@ -1,2 +1 @@
-from .local_datasets import ConferenceDataset, DatasetUnlabeled
 from .modules import DistillationDataModule, LrDataModule, TransformerDataModule
diff --git a/bert_squeeze/data/local_datasets/__init__.py b/bert_squeeze/data/local_datasets/__init__.py
index 3d56cc5..c224109 100644
--- a/bert_squeeze/data/local_datasets/__init__.py
+++ b/bert_squeeze/data/local_datasets/__init__.py
@@ -1,2 +1 @@
-from .conference_dataset import ConferenceDataset
 from .unlabeled_dataset import DatasetUnlabeled
diff --git a/bert_squeeze/data/local_datasets/conference_dataset.py b/bert_squeeze/data/local_datasets/conference_dataset.py
deleted file mode 100644
index eaf22a6..0000000
--- a/bert_squeeze/data/local_datasets/conference_dataset.py
+++ /dev/null
@@ -1,80 +0,0 @@
-from typing import List
-
-import datasets
-import pandas as pd
-
-_DESCRIPTION = (
-    "Dataset used for testing purposes. Taken from here: "
-    "https://raw.githubusercontent.com/susanli2016/NLP-with-Python/master/data/title_conference.csv"
-)
-
-
-class ConferenceConfig(datasets.BuilderConfig):
-    def __init__(self, **kwargs):
-        """BuilderConfig for the conference dataset.
-        Args:
-            **kwargs: keyword arguments forwarded to super.
-        """
-        super(ConferenceConfig, self).__init__(
-            version=datasets.Version("1.0.0", ""), **kwargs
-        )
-
-
-class ConferenceDataset(datasets.GeneratorBasedBuilder):
-    """
-    Conference dataset
-    """
-
-    BUILDER_CONFIG_CLASS = ConferenceConfig
-    BUILDER_CONFIGS = [
-        ConferenceConfig(
-            name="default", description=_DESCRIPTION, data_dir="classification/"
-        ),
-        ConferenceConfig(
-            name="debug",
-            description="small chunk of the 'default' configuration.",
-            data_dir="debug/",
-        ),
-    ]
-    DEFAULT_CONFIG_NAME = "default"
-
-    def _info(self):
-        return datasets.DatasetInfo(
-            description=_DESCRIPTION,
-            features=datasets.Features(
-                {
-                    "title": datasets.Value("string"),
-                    "label": datasets.ClassLabel(
-                        names=['ISCAS', 'INFOCOM', 'WWW', 'SIGGRAPH', 'VLDB']
-                    ),
-                }
-            ),
-            supervised_keys=None,
-        )
-
-    def _split_generators(
-        self, dl_manager: datasets.DownloadManager
-    ) -> List[datasets.SplitGenerator]:
-        """Returns SplitGenerators."""
-        urls_to_download = {
-            "train": self.config.data_dir + "train.csv",
-            "test": self.config.data_dir + "test.csv",
-        }
-        downloaded_files = dl_manager.download_and_extract(urls_to_download)
-        return [
-            datasets.SplitGenerator(
-                name=datasets.Split.TRAIN,
-                gen_kwargs={"filepath": downloaded_files["train"]},
-            ),
-            datasets.SplitGenerator(
-                name=datasets.Split.TEST,
-                gen_kwargs={"filepath": downloaded_files["test"]},
-            ),
-        ]
-
-    def _generate_examples(self, filepath):
-        """Yields examples."""
-        df = pd.read_csv(filepath)
-
-        for id, row in df.iterrows():
-            yield id, {"title": row["Title"], "label": row["Conference"]}
diff --git a/docs/data.rst b/docs/data.rst
index d632362..91d07ba 100644
--- a/docs/data.rst
+++ b/docs/data.rst
@@ -4,10 +4,6 @@ Data
 bert_squeeze.data.local_datasets
 ----------------------------------

-.. automodule:: bert_squeeze.data.local_datasets.conference_dataset
-    :members:
-    :exclude-members: ConferenceConfig
-
 .. automodule:: bert_squeeze.data.local_datasets.parallel_dataset
     :members:
     :exclude-members: ParallelConfig
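The ConferenceDataset builder removed in patch 4/6 wrapped a single public CSV
whose URL survives in its `_DESCRIPTION`. If that data is ever needed again, a
sketch of loading it through the generic CSV builder instead of a custom
`GeneratorBasedBuilder` (the single-split layout is an assumption; the deleted
class expected pre-split train.csv/test.csv files in its data directory):

    import datasets

    URL = (
        "https://raw.githubusercontent.com/susanli2016/NLP-with-Python/"
        "master/data/title_conference.csv"
    )

    # The generic CSV builder yields plain string columns ("Title",
    # "Conference") rather than the ClassLabel feature the old _info() declared.
    ds = datasets.load_dataset("csv", data_files={"train": URL})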
From 5182b3365b704f1872b15555d87ce05d67d6417e Mon Sep 17 00:00:00 2001
From: JulesBelveze
Date: Sat, 9 Nov 2024 09:14:08 +0100
Subject: [PATCH 5/6] [misc] - refactor: remove the local_datasets module and
 related data entries

- Deleted the `local_datasets` module, which provided the parallel and
  unlabeled datasets
- Removed the corresponding bert-squeeze data directories from `.gitignore`,
  as they are no longer used

[docs] - docs: update documentation to reflect codebase changes

- Removed the documentation entries for the now-deleted `local_datasets`
  module in `bert_squeeze`
---
 .gitignore                                    |  4 -
 bert_squeeze/data/local_datasets/__init__.py  |  1 -
 .../data/local_datasets/parallel_dataset.py   | 84 ------------------
 .../data/local_datasets/unlabeled_dataset.py  | 63 --------------
 docs/data.rst                                 | 11 ---
 5 files changed, 163 deletions(-)
 delete mode 100644 bert_squeeze/data/local_datasets/__init__.py
 delete mode 100644 bert_squeeze/data/local_datasets/parallel_dataset.py
 delete mode 100644 bert_squeeze/data/local_datasets/unlabeled_dataset.py

diff --git a/.gitignore b/.gitignore
index 412446d..18b495a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,10 +10,6 @@
 # docs
 docs/_build

-# data
-bert-squeeze/data/classification
-bert-squeeze/data/unlabeled
-
 # various
 .neptune/
 outputs/
diff --git a/bert_squeeze/data/local_datasets/__init__.py b/bert_squeeze/data/local_datasets/__init__.py
deleted file mode 100644
index c224109..0000000
--- a/bert_squeeze/data/local_datasets/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .unlabeled_dataset import DatasetUnlabeled
diff --git a/bert_squeeze/data/local_datasets/parallel_dataset.py b/bert_squeeze/data/local_datasets/parallel_dataset.py
deleted file mode 100644
index 7a86c2f..0000000
--- a/bert_squeeze/data/local_datasets/parallel_dataset.py
+++ /dev/null
@@ -1,84 +0,0 @@
-import json
-from typing import List
-
-import datasets
-
-_DESCRIPTION = "Dataset containing parallel data."
-
-
-class ParallelConfig(datasets.BuilderConfig):
-    def __init__(self, **kwargs):
-        """BuilderConfig for the below dataset.
-        Args:
-            **kwargs: keyword arguments forwarded to super.
-        """
-        super(ParallelConfig, self).__init__(
-            version=datasets.Version("1.0.0", ""), **kwargs
-        )
-
-
-class ParallelDataset(datasets.GeneratorBasedBuilder):
-    """
-    Dataset for parallel distillation
-    """
-
-    BUILDER_CONFIG_CLASS = ParallelConfig
-    BUILDER_CONFIGS = [
-        ParallelConfig(name="default", description=_DESCRIPTION, data_dir="parallel/"),
-        ParallelConfig(
-            name="debug",
-            description="small chunk of the 'default' configuration.",
-            data_dir="debug",
-        ),
-    ]
-    DEFAULT_CONFIG_NAME = "default"
-
-    def _info(self):
-        return datasets.DatasetInfo(
-            description=_DESCRIPTION,
-            features=datasets.Features(
-                {
-                    "text": datasets.Value("string"),
-                    "translation": datasets.Value("string"),
-                    "lang": datasets.Value("string"),
-                }
-            ),
-            supervised_keys=None,
-        )
-
-    def _split_generators(
-        self, dl_manager: datasets.DownloadManager
-    ) -> List[datasets.SplitGenerator]:
-        """Returns SplitGenerators."""
-        urls_to_download = {
-            "train": self.config.data_dir + "train.json",
-            "test": self.config.data_dir + "test.json",
-            "validation": self.config.data_dir + "validation.json",
-        }
-        downloaded_files = dl_manager.download_and_extract(urls_to_download)
-        return [
-            datasets.SplitGenerator(
-                name=datasets.Split.TRAIN,
-                gen_kwargs={"filepath": downloaded_files["train"]},
-            ),
-            datasets.SplitGenerator(
-                name=datasets.Split.TEST,
-                gen_kwargs={"filepath": downloaded_files["test"]},
-            ),
-            datasets.SplitGenerator(
-                name=datasets.Split.VALIDATION,
-                gen_kwargs={"filepath": downloaded_files["validation"]},
-            ),
-        ]
-
-    def _generate_examples(self, filepath):
-        """Yields examples."""
-        with open(filepath, "r") as reader:
-            data = json.load(reader)
-
-        for id, row in enumerate(data):
-            yield id, {
-                "text": row["text"],
-                "translation": row["translation"],
-                "lang": row["lang"],
-            }
- """ - super(ParallelConfig, self).__init__( - version=datasets.Version("1.0.0", ""), **kwargs - ) - - -class ParallelDataset(datasets.GeneratorBasedBuilder): - """ - Dataset for parallel distillation - """ - - BUILDER_CONFIG_CLASS = ParallelConfig - BUILDER_CONFIGS = [ - ParallelConfig(name="default", description=_DESCRIPTION, data_dir="parallel/"), - ParallelConfig( - name="debug", - description="small chunk of the 'default' configuration.", - data_dir="debug", - ), - ] - DEFAULT_CONFIG_NAME = "default" - - def _info(self): - return datasets.DatasetInfo( - description=_DESCRIPTION, - features=datasets.Features( - { - "text": datasets.Value("string"), - "translation": datasets.Value("string"), - "lang": datasets.Value("string"), - } - ), - supervised_keys=None, - ) - - def _split_generators( - self, dl_manager: datasets.DownloadManager - ) -> List[datasets.SplitGenerator]: - """Returns SplitGenerators.""" - urls_to_download = { - "train": self.config.data_dir + "train.json", - "test": self.config.data_dir + "test.json", - "validation": self.config.data_dir + "validation.json", - } - downloaded_files = dl_manager.download_and_extract(urls_to_download) - return [ - datasets.SplitGenerator( - name=datasets.Split.TRAIN, - gen_kwargs={"filepath": downloaded_files["train"]}, - ), - datasets.SplitGenerator( - name=datasets.Split.TEST, - gen_kwargs={"filepath": downloaded_files["test"]}, - ), - datasets.SplitGenerator( - name=datasets.Split.VALIDATION, - gen_kwargs={"filepath": downloaded_files["validation"]}, - ), - ] - - def _generate_examples(self, filepath): - """Yields examples.""" - with open(filepath, "r") as reader: - data = json.load(reader) - - for id, row in enumerate(data): - yield id, { - "text": row["text"], - "translation": row["translation"], - "lang": row["lang"], - } diff --git a/bert_squeeze/data/local_datasets/unlabeled_dataset.py b/bert_squeeze/data/local_datasets/unlabeled_dataset.py deleted file mode 100644 index fce895f..0000000 --- a/bert_squeeze/data/local_datasets/unlabeled_dataset.py +++ /dev/null @@ -1,63 +0,0 @@ -from typing import List - -import datasets -import pandas as pd - -_DESCRIPTION = "Helper dataset to perform soft distillation." - - -class UnlabeledConfig(datasets.BuilderConfig): - def __init__(self, **kwargs): - """BuilderConfig for the below dataset. - Args: - **kwargs: keyword arguments forwarded to super. - """ - super(UnlabeledConfig, self).__init__( - version=datasets.Version("1.0.0", ""), **kwargs - ) - - -class DatasetUnlabeled(datasets.GeneratorBasedBuilder): - """ - Dataset to use for soft distillation. 
- """ - - BUILDER_CONFIG_CLASS = UnlabeledConfig - BUILDER_CONFIGS = [ - UnlabeledConfig(name="default", description=_DESCRIPTION, data_dir="unlabeled/"), - UnlabeledConfig( - name="debug", - description="small chunk of the 'default' configuration.", - data_dir="debug/", - ), - ] - DEFAULT_CONFIG_NAME = "default" - - def _info(self): - return datasets.DatasetInfo( - description=_DESCRIPTION, - features=datasets.Features( - {"text": datasets.Value("string"), "id": datasets.Value("int16")} - ), - supervised_keys=None, - ) - - def _split_generators( - self, dl_manager: datasets.DownloadManager - ) -> List[datasets.SplitGenerator]: - """Returns SplitGenerators.""" - urls_to_download = {"train": self.config.data_dir + "train.csv"} - downloaded_files = dl_manager.download_and_extract(urls_to_download) - return [ - datasets.SplitGenerator( - name=datasets.Split.TRAIN, - gen_kwargs={"filepath": downloaded_files["train"]}, - ) - ] - - def _generate_examples(self, filepath): - """Yields examples.""" - df = pd.read_csv(filepath) - - for id, row in df.iterrows(): - yield id, {"text": row["text"], "id": id} diff --git a/docs/data.rst b/docs/data.rst index 91d07ba..082dff4 100644 --- a/docs/data.rst +++ b/docs/data.rst @@ -1,17 +1,6 @@ Data ====================== -bert_squeeze.data.local_datasets ----------------------------------- - -.. automodule:: bert_squeeze.data.local_datasets.parallel_dataset - :members: - :exclude-members: ParallelConfig - -.. automodule:: bert_squeeze.data.local_datasets.unlabeled_dataset - :members: - :exclude-members: UnlabeledConfig - bert_squeeze.data.modules ---------------------------- From 22eb8a1b236a83a5d25fdf1dcbca63b467c8e6fb Mon Sep 17 00:00:00 2001 From: JulesBelveze Date: Sat, 9 Nov 2024 16:10:31 +0100 Subject: [PATCH 6/6] [bert_squeeze] - refactor: update TransformerParallelDataModule for translation tasks - Removed hardcoded text column name in favor of dynamic translation column configuration - Added a filter to exclude entries without translations before tokenization - Fixed mismatched attention mask column name in tokenized_dataset [tests] - test: change DistilAssistant test to use `kmfoda/booksum` dataset parameters - Updated test cases to use `booksum` dataset path and specific configuration parameters like `percent`, `target_col`, and `source_col` - Modified asserts to expect different lengths for train and validation data loaders based on `booksum` dataset --- .../data/modules/transformer_module.py | 9 ++-- tests/assistants/test_distil_assistant.py | 48 ++++++++++--------- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/bert_squeeze/data/modules/transformer_module.py b/bert_squeeze/data/modules/transformer_module.py index 4535188..28181c9 100644 --- a/bert_squeeze/data/modules/transformer_module.py +++ b/bert_squeeze/data/modules/transformer_module.py @@ -134,8 +134,8 @@ class TransformerParallelDataModule(TransformerDataModule): def __init__( self, dataset_config: DictConfig, tokenizer_name: str, max_length: int, **kwargs ): - dataset_config.text_col = "text" dataset_config.label_col = None + self.translation_col = dataset_config.get("translation_col", "translation") super().__init__(dataset_config, tokenizer_name, max_length, **kwargs) @overrides @@ -144,6 +144,7 @@ def featurize(self) -> datasets.DatasetDict: Returns: DatasetDict: featurized dataset """ + self.dataset = self.dataset.filter(lambda x: x[self.translation_col] is not None) tokenized_dataset = self.dataset.map( lambda x: self.tokenizer( x[self.text_col], @@ -156,7 +157,7 @@ 
diff --git a/tests/assistants/test_distil_assistant.py b/tests/assistants/test_distil_assistant.py
index adabcb5..51847a6 100644
--- a/tests/assistants/test_distil_assistant.py
+++ b/tests/assistants/test_distil_assistant.py
@@ -410,10 +410,12 @@ def test_two_hf_models(self, caplog):
         distil_assistant = DistilAssistant(
             "distil-parallel",
             data_kwargs={
-                "path": resource_filename(
-                    "bert_squeeze", "data/local_datasets/parallel_dataset.py"
-                ),
-                "is_local": True,
+                "path": "kmfoda/booksum",
+                "percent": 5,
+                "target_col": "summary_text",
+                "source_col": "chapter",
+                "train_batch_size": 16,
+                "eval_batch_size": 4,
             },
         )
         assert distil_assistant.teacher is None
@@ -430,10 +432,10 @@ def test_student_torch_model(self, caplog):
         distil_assistant = DistilAssistant(
             "distil-parallel",
             data_kwargs={
-                "path": resource_filename(
-                    "bert_squeeze", "data/local_datasets/parallel_dataset.py"
-                ),
-                "is_local": True,
+                "path": "kmfoda/booksum",
+                "percent": 5,
+                "target_col": "summary_text",
+                "source_col": "chapter",
             },
             student_kwargs={
                 "_target_": "tests.fixtures.dummy_models.Lr",
@@ -453,10 +455,10 @@ def test_teacher_torch_model(self, caplog):
         distil_assistant = DistilAssistant(
             "distil-parallel",
             data_kwargs={
-                "path": resource_filename(
-                    "bert_squeeze", "data/local_datasets/parallel_dataset.py"
-                ),
-                "is_local": True,
+                "path": "kmfoda/booksum",
+                "percent": 5,
+                "target_col": "summary_text",
+                "source_col": "chapter",
             },
             teacher_kwargs={
                 "_target_": "tests.fixtures.dummy_models.Lr",
@@ -476,10 +478,12 @@ def test_torch_models(self, caplog):
         distil_assistant = DistilAssistant(
             "distil-parallel",
             data_kwargs={
-                "path": resource_filename(
-                    "bert_squeeze", "data/local_datasets/parallel_dataset.py"
-                ),
-                "is_local": True,
+                "path": "kmfoda/booksum",
+                "percent": 5,
+                "target_col": "summary_text",
+                "source_col": "chapter",
+                "train_batch_size": 16,
+                "eval_batch_size": 4,
             },
             teacher_kwargs={
                 "_target_": "tests.fixtures.dummy_models.Lr",
@@ -510,17 +514,17 @@ def test_data(self):
                 "_target_": "tests.fixtures.dummy_models.Lr",
             },
             data_kwargs={
-                "path": resource_filename(
-                    "bert_squeeze", "data/local_datasets/parallel_dataset.py"
-                ),
-                "is_local": True,
+                "path": "kmfoda/booksum",
+                "percent": 5,
+                "text_col": "summary_text",
+                "translation_col": "summary_analysis",
                 "train_batch_size": 16,
                 "eval_batch_size": 4,
             },
         )
         assert isinstance(distil_assistant.data.train_dataloader(), DataLoader)
-        assert len(distil_assistant.data.train_dataloader()) == 187
-        assert len(distil_assistant.data.val_dataloader()) == 125
+        assert len(distil_assistant.data.train_dataloader()) == 22
+        assert len(distil_assistant.data.val_dataloader()) == 5


 class TestDistilSeq2SeqAssistant:
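For reference, a minimal sketch of driving the reworked
TransformerParallelDataModule after this series, mirroring the `data_kwargs`
exercised in `test_data` above. The tokenizer name and `max_length` are
illustrative assumptions, not values taken from the patches:

    from omegaconf import OmegaConf

    from bert_squeeze.data.modules.transformer_module import (
        TransformerParallelDataModule,
    )

    dataset_config = OmegaConf.create(
        {
            "path": "kmfoda/booksum",  # hub dataset, no local builder required
            "percent": 5,  # keep 5% of the data, as in the tests
            "text_col": "summary_text",
            "translation_col": "summary_analysis",  # defaults to "translation"
            "train_batch_size": 16,
            "eval_batch_size": 4,
        }
    )

    dm = TransformerParallelDataModule(
        dataset_config=dataset_config,
        tokenizer_name="bert-base-uncased",  # assumption
        max_length=256,  # assumption
    )
    dm.prepare_data()  # calls load_dataset(..., trust_remote_code=True), patch 2/6

After patch 6/6, featurize() drops rows whose translation column is None
before tokenizing and keeps translation_attention_mask instead of a duplicated
translation_input_ids entry.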