diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000..da99824aa3
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,141 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# VSCode
+.vscode
+
+# IntelliJ
+.idea
+
+# Mac .DS_Store
+.DS_Store
+
+# More test things
+wandb
\ No newline at end of file
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000000..1aba38f67a
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1 @@
+include LICENSE
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000000..e1c15c5959
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,19 @@
+.PHONY: quality style test docs
+
+check_dirs := src
+
+# Check that source code meets quality standards
+
+# this target runs checks on all files
+quality:
+	black --check $(check_dirs)
+	isort --check-only $(check_dirs)
+	flake8 $(check_dirs)
+	python utils/style_doc.py src --max_len 119 --check_only
+
+# Format source code automatically and check if there are any problems left that need manual fixing
+style:
+	black $(check_dirs)
+	isort $(check_dirs)
+	python utils/style_doc.py src --max_len 119
+
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000..b7465bb131
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,3 @@
+[tool.black]
+line-length = 119
+target-version = ['py36']
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000000..6b26312126
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,23 @@
+[isort]
+default_section = FIRSTPARTY
+ensure_newline_before_comments = True
+force_grid_wrap = 0
+include_trailing_comma = True
+known_first_party = pet
+known_third_party =
+    numpy
+    torch
+    accelerate
+    transformers
+
+line_length = 119
+lines_after_imports = 2
+multi_line_output = 3
+use_parentheses = True
+
+[flake8]
+ignore = E203, E722, E501, E741, W503, W605
+max-line-length = 119
+
+[tool:pytest]
+doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000..9c8edd91d3
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,76 @@
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from setuptools import setup
+from setuptools import find_packages
+
+extras = {}
+extras["quality"] = ["black ~= 22.0", "isort >= 5.5.4", "flake8 >= 3.8.3"]
+extras["dev"] = extras["quality"]
+
+setup(
+    name="pets",
+    version="0.1.0.dev0",
+    description="Parameter-Efficient Tuning at Scale (PETS)",
+    long_description=open("README.md", "r", encoding="utf-8").read(),
+    long_description_content_type="text/markdown",
+    keywords="deep learning",
+    license="Apache",
+    author="The HuggingFace team",
+    author_email="sourab@huggingface.co",
+    url="https://github.com/huggingface/pets",
+    package_dir={"": "src"},
+    packages=find_packages("src"),
+    entry_points={},
+    python_requires=">=3.7.0",
+    install_requires=[
+        "numpy>=1.17",
+        "packaging>=20.0",
+        "psutil",
+        "pyyaml",
+        "torch>=1.4.0",
+        "transformers",
+        "accelerate",
+    ],
+    extras_require=extras,
+    classifiers=[
+        "Development Status :: 5 - Production/Stable",
+        "Intended Audience :: Developers",
+        "Intended Audience :: Education",
+        "Intended Audience :: Science/Research",
+        "License :: OSI Approved :: Apache Software License",
+        "Operating System :: OS Independent",
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.7",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    ],
+)
+
+# Release checklist
+# 1. Change the version in __init__.py and setup.py.
+# 2. Commit these changes with the message: "Release: VERSION"
+# 3. Add a tag in git to mark the release: "git tag VERSION -m 'Adds tag VERSION for pypi' "
+#    Push the tag to git: git push --tags origin main
+# 4. Run the following commands in the top-level directory:
+#      python setup.py bdist_wheel
+#      python setup.py sdist
+# 5. Upload the package to the pypi test server first:
+#      twine upload dist/* -r pypitest
+#      twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
+# 6. Check that you can install it in a virtualenv by running:
+#      pip install -i https://testpypi.python.org/pypi pets
+# 7. Upload the final version to actual pypi:
+#      twine upload dist/* -r pypi
+# 8. Add release notes to the tag in github once everything is looking hunky-dory.
+# 9. Update the version in __init__.py and setup.py to the new "-dev" version and push to main
diff --git a/src/pet/__init__.py b/src/pet/__init__.py
new file mode 100644
index 0000000000..2436302799
--- /dev/null
+++ b/src/pet/__init__.py
@@ -0,0 +1,18 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...'
imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +__version__ = "0.1.0.dev0" + +from .pet_model import ( + ParameterEfficientTuningModel, + ParameterEfficientTuningModelForSequenceClassification, + PromptEncoderType, +) +from .tuners import ( + PrefixEncoder, + PromptEmbedding, + PromptEncoder, + PromptEncoderReparameterizationType, + PromptTuningInit, +) diff --git a/src/pet.py b/src/pet/pet_model.py similarity index 97% rename from src/pet.py rename to src/pet/pet_model.py index 7ac208b27d..9345539e51 100644 --- a/src/pet.py +++ b/src/pet/pet_model.py @@ -1,12 +1,14 @@ -from collections import OrderedDict import enum import warnings +from collections import OrderedDict + import torch +from accelerate.state import AcceleratorState from transformers import PreTrainedModel + from tuners.p_tuning import PromptEncoder from tuners.prefix_tuning import PrefixEncoder from tuners.prompt_tuning import PromptEmbedding -from accelerate.state import AcceleratorState class PromptEncoderType(str, enum.Enum): @@ -88,8 +90,8 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False): def load_state_dict(self, state_dict, strict: bool = True): """ - Custom load state dict method that only loads prompt table and prompt encoder - parameters. Matching load method for this class' custom state dict method. + Custom load state dict method that only loads prompt table and prompt encoder parameters. Matching load method + for this class' custom state dict method. """ self.prompt_encoder.embedding.load_state_dict({"weight": state_dict["prompt_embeddings"]}, strict) @@ -187,8 +189,8 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False): def load_state_dict(self, state_dict, strict: bool = True): """ - Custom load state dict method that only loads prompt table and prompt encoder - parameters. Matching load method for this class' custom state dict method. + Custom load state dict method that only loads prompt table and prompt encoder parameters. Matching load method + for this class' custom state dict method. 
""" super().load_state_dict(state_dict["prompt_encoder"], strict) self.model.classifier.load_state_dict(state_dict["classifier"], strict) diff --git a/src/prompt_learning_legacy.py b/src/pet/prompt_learning_legacy.py similarity index 91% rename from src/prompt_learning_legacy.py rename to src/pet/prompt_learning_legacy.py index ffb6934396..b08d01851d 100644 --- a/src/prompt_learning_legacy.py +++ b/src/pet/prompt_learning_legacy.py @@ -1,34 +1,27 @@ import enum -import torch +import functools import math import os +from collections import OrderedDict -from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss -from transformers import PreTrainedModel -from transformers.modeling_outputs import SequenceClassifierOutput -from transformers import AutoModelForSequenceClassification -from datasets import load_dataset -import evaluate import torch -from transformers import AutoTokenizer, get_linear_schedule_with_warmup, set_seed -from torch.utils.data import DataLoader from accelerate import Accelerator from accelerate.state import AcceleratorState from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin -import functools -from torch.distributed.fsdp import ( - FullyShardedDataParallel, - CPUOffload, -) -from torch.distributed.fsdp.wrap import ( - enable_wrap, - wrap, - ModuleWrapPolicy, - transformer_auto_wrap_policy, - lambda_auto_wrap_policy, - _or_policy, +from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.utils.data import DataLoader +from transformers import ( + AutoModelForSequenceClassification, + AutoTokenizer, + PreTrainedModel, + get_linear_schedule_with_warmup, + set_seed, ) -from collections import OrderedDict +from transformers.modeling_outputs import SequenceClassifierOutput + +import evaluate +from datasets import load_dataset class PromptEncoderReparameterizationType(str, enum.Enum): @@ -49,8 +42,7 @@ class PromptTuningInit(str, enum.Enum): class PromptEncoder(torch.nn.Module): """ - The prompt encoder network that is used to generate the virtual - token embeddings for p-tuning. + The prompt encoder network that is used to generate the virtual token embeddings for p-tuning. """ def __init__(self, config): @@ -92,13 +84,23 @@ def __init__(self, config): ) elif self.encoder_type == PromptEncoderReparameterizationType.MLP: - layers = [torch.nn.Linear(self.input_size, self.hidden_size), torch.nn.ReLU()] - layers.extend([torch.nn.Linear(self.hidden_size, self.hidden_size), torch.nn.ReLU()]) + layers = [ + torch.nn.Linear(self.input_size, self.hidden_size), + torch.nn.ReLU(), + ] + layers.extend( + [ + torch.nn.Linear(self.hidden_size, self.hidden_size), + torch.nn.ReLU(), + ] + ) layers.append(torch.nn.Linear(self.hidden_size, self.output_size)) self.mlp_head = torch.nn.Sequential(*layers) else: - raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.") + raise ValueError( + "Prompt encoder type not recognized. " " Please use one of MLP (recommended) or LSTM." 
+ ) def forward(self, indices): input_embeds = self.embedding(indices) @@ -130,11 +132,15 @@ def __init__(self, config): self.trans = torch.nn.Sequential( torch.nn.Linear(config["token_dim"], config["prompt_hidden_size"]), torch.nn.Tanh(), - torch.nn.Linear(config["prompt_hidden_size"], config["num_layers"] * 2 * config["token_dim"]), + torch.nn.Linear( + config["prompt_hidden_size"], + config["num_layers"] * 2 * config["token_dim"], + ), ) else: self.embedding = torch.nn.Embedding( - config["num_virtual_tokens"], config["num_layers"] * 2 * config["token_dim"] + config["num_virtual_tokens"], + config["num_layers"] * 2 * config["token_dim"], ) def forward(self, prefix: torch.Tensor): @@ -247,8 +253,8 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False): def load_state_dict(self, state_dict, strict: bool = True): """ - Custom load state dict method that only loads prompt table and prompt encoder - parameters. Matching load method for this class' custom state dict method. + Custom load state dict method that only loads prompt table and prompt encoder parameters. Matching load method + for this class' custom state dict method. """ self.prompt_encoder.embedding.load_state_dict({"weight": state_dict["prompt_embeddings"]}, strict) @@ -389,8 +395,8 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False): def load_state_dict(self, state_dict, strict: bool = True): """ - Custom load state dict method that only loads prompt table and prompt encoder - parameters. Matching load method for this class' custom state dict method. + Custom load state dict method that only loads prompt table and prompt encoder parameters. Matching load method + for this class' custom state dict method. """ super().load_state_dict(state_dict["prompt_encoder"], strict) self.classifier.load_state_dict(state_dict["classifier"], strict) @@ -528,7 +534,7 @@ def main(): batch_size = 16 lr = 5e-3 num_epochs = 100 - device = "cuda" + # device = "cuda" seed = 11 set_seed(seed) @@ -544,7 +550,12 @@ def main(): def tokenize_function(examples): # max_length=None => use the model max length (it's actually the default) - outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None) + outputs = tokenizer( + examples["sentence1"], + examples["sentence2"], + truncation=True, + max_length=None, + ) return outputs # Apply the method we just defined to all the examples in all the splits of the dataset @@ -564,10 +575,16 @@ def collate_fn(examples): # Instantiate dataloaders. 
train_dataloader = DataLoader( - tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size + tokenized_datasets["train"], + shuffle=True, + collate_fn=collate_fn, + batch_size=batch_size, ) eval_dataloader = DataLoader( - tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size + tokenized_datasets["validation"], + shuffle=False, + collate_fn=collate_fn, + batch_size=batch_size, ) # Instantiate optimizer @@ -582,9 +599,13 @@ def collate_fn(examples): accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model) - model, train_dataloader, eval_dataloader, optimizer, lr_scheduler = accelerator.prepare( - model, train_dataloader, eval_dataloader, optimizer, lr_scheduler - ) + ( + model, + train_dataloader, + eval_dataloader, + optimizer, + lr_scheduler, + ) = accelerator.prepare(model, train_dataloader, eval_dataloader, optimizer, lr_scheduler) accelerator.print(model) for epoch in range(num_epochs): @@ -616,17 +637,14 @@ def collate_fn(examples): accelerator.print(f"epoch {epoch}:", eval_metric) accelerator.print(f"epoch {epoch} train loss:", total_loss / len(train_dataloader)) + from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP - from torch.distributed.fsdp.fully_sharded_data_parallel import ( - BackwardPrefetch, - CPUOffload, - FullStateDictConfig, - ShardingStrategy, - StateDictType, - ) + from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType FSDP.set_state_dict_type( - model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), ) state_dict = model.state_dict() state_dict = model.clean_state_dict(state_dict) diff --git a/src/pet/tuners/__init__.py b/src/pet/tuners/__init__.py new file mode 100644 index 0000000000..41ae36c5c8 --- /dev/null +++ b/src/pet/tuners/__init__.py @@ -0,0 +1,7 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all + +from .p_tuning import PromptEncoder, PromptEncoderReparameterizationType +from .prefix_tuning import PrefixEncoder +from .prompt_tuning import PromptEmbedding, PromptTuningInit diff --git a/src/tuners/lora.py b/src/pet/tuners/lora.py similarity index 100% rename from src/tuners/lora.py rename to src/pet/tuners/lora.py diff --git a/src/tuners/p_tuning.py b/src/pet/tuners/p_tuning.py similarity index 87% rename from src/tuners/p_tuning.py rename to src/pet/tuners/p_tuning.py index 2c103c3f5f..5b161c39a2 100644 --- a/src/tuners/p_tuning.py +++ b/src/pet/tuners/p_tuning.py @@ -1,6 +1,7 @@ -import torch import enum +import torch + class PromptEncoderReparameterizationType(str, enum.Enum): MLP = "MLP" @@ -11,8 +12,7 @@ class PromptEncoderReparameterizationType(str, enum.Enum): # with some refactor class PromptEncoder(torch.nn.Module): """ - The prompt encoder network that is used to generate the virtual - token embeddings for p-tuning. + The prompt encoder network that is used to generate the virtual token embeddings for p-tuning. 
""" def __init__(self, config): @@ -54,8 +54,16 @@ def __init__(self, config): ) elif self.encoder_type == PromptEncoderReparameterizationType.MLP: - layers = [torch.nn.Linear(self.input_size, self.hidden_size), torch.nn.ReLU()] - layers.extend([torch.nn.Linear(self.hidden_size, self.hidden_size), torch.nn.ReLU()]) + layers = [ + torch.nn.Linear(self.input_size, self.hidden_size), + torch.nn.ReLU(), + ] + layers.extend( + [ + torch.nn.Linear(self.hidden_size, self.hidden_size), + torch.nn.ReLU(), + ] + ) layers.append(torch.nn.Linear(self.hidden_size, self.output_size)) self.mlp_head = torch.nn.Sequential(*layers) diff --git a/src/tuners/prefix_tuning.py b/src/pet/tuners/prefix_tuning.py similarity index 81% rename from src/tuners/prefix_tuning.py rename to src/pet/tuners/prefix_tuning.py index c593b1eff5..761545f55b 100644 --- a/src/tuners/prefix_tuning.py +++ b/src/pet/tuners/prefix_tuning.py @@ -1,5 +1,6 @@ import torch + # Based on https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py # with some refactor class PrefixEncoder(torch.nn.Module): @@ -20,11 +21,15 @@ def __init__(self, config): self.trans = torch.nn.Sequential( torch.nn.Linear(config["token_dim"], config["prompt_hidden_size"]), torch.nn.Tanh(), - torch.nn.Linear(config["prompt_hidden_size"], config["num_layers"] * 2 * config["token_dim"]), + torch.nn.Linear( + config["prompt_hidden_size"], + config["num_layers"] * 2 * config["token_dim"], + ), ) else: self.embedding = torch.nn.Embedding( - config["num_virtual_tokens"], config["num_layers"] * 2 * config["token_dim"] + config["num_virtual_tokens"], + config["num_layers"] * 2 * config["token_dim"], ) def forward(self, prefix: torch.Tensor): diff --git a/src/tuners/prompt_tuning.py b/src/pet/tuners/prompt_tuning.py similarity index 99% rename from src/tuners/prompt_tuning.py rename to src/pet/tuners/prompt_tuning.py index cb4fc8df0a..c3bfed7ece 100644 --- a/src/tuners/prompt_tuning.py +++ b/src/pet/tuners/prompt_tuning.py @@ -1,7 +1,8 @@ -import torch import enum import math +import torch + class PromptTuningInit(str, enum.Enum): TEXT = "TEXT" diff --git a/src/utils/constants.py b/src/pet/utils/constants.py similarity index 100% rename from src/utils/constants.py rename to src/pet/utils/constants.py diff --git a/utils/style_doc.py b/utils/style_doc.py new file mode 100644 index 0000000000..0422ebeb4e --- /dev/null +++ b/utils/style_doc.py @@ -0,0 +1,556 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Style utils for the .rst and the docstrings.""" + +import argparse +import os +import re +import warnings + +import black + + +BLACK_AVOID_PATTERNS = {} + + +# Regexes +# Re pattern that catches list introduction (with potential indent) +_re_list = re.compile(r"^(\s*-\s+|\s*\*\s+|\s*\d+\.\s+)") +# Re pattern that catches code block introduction (with potential indent) +_re_code = re.compile(r"^(\s*)```(.*)$") +# Re pattern that catches rst args blocks of the form `Parameters:`. 
+_re_args = re.compile(r"^\s*(Args?|Arguments?|Params?|Parameters?):\s*$")
+# Re pattern that catches return blocks of the form `Return:`.
+_re_returns = re.compile(r"^\s*Returns?:\s*$")
+# Matches the special tag to ignore some paragraphs.
+_re_doc_ignore = re.compile(r"(\.\.|#)\s*docstyle-ignore")
+# Re pattern that matches `<Tip>`, `</Tip>` and `<Tip warning={true}>` blocks.
+_re_tip = re.compile(r"^\s*</?Tip(>|\s+warning={true}>)\s*$")
+
+DOCTEST_PROMPTS = [">>>", "..."]
+
+
+def is_empty_line(line):
+    return len(line) == 0 or line.isspace()
+
+
+def find_indent(line):
+    """
+    Returns the number of spaces that start a line indent.
+    """
+    search = re.search(r"^(\s*)(?:\S|$)", line)
+    if search is None:
+        return 0
+    return len(search.groups()[0])
+
+
+def parse_code_example(code_lines):
+    """
+    Parses a code example.
+
+    Args:
+        code_lines (`List[str]`): The code lines to parse.
+
+    Returns:
+        (List[`str`], List[`str`]): The list of code samples and the list of outputs.
+    """
+    has_doctest = code_lines[0][:3] in DOCTEST_PROMPTS
+
+    code_samples = []
+    outputs = []
+    in_code = True
+    current_bit = []
+
+    for line in code_lines:
+        if in_code and has_doctest and not is_empty_line(line) and line[:3] not in DOCTEST_PROMPTS:
+            code_sample = "\n".join(current_bit)
+            code_samples.append(code_sample.strip())
+            in_code = False
+            current_bit = []
+        elif not in_code and line[:3] in DOCTEST_PROMPTS:
+            output = "\n".join(current_bit)
+            outputs.append(output.strip())
+            in_code = True
+            current_bit = []
+
+        # Add the line without doctest prompt
+        if line[:3] in DOCTEST_PROMPTS:
+            line = line[4:]
+        current_bit.append(line)
+
+    # Add last sample
+    if in_code:
+        code_sample = "\n".join(current_bit)
+        code_samples.append(code_sample.strip())
+    else:
+        output = "\n".join(current_bit)
+        outputs.append(output.strip())
+
+    return code_samples, outputs
+
+
+def format_code_example(code: str, max_len: int, in_docstring: bool = False):
+    """
+    Format a code example using black. Will take into account the doctest syntax as well as any initial indentation in
+    the code provided.
+
+    Args:
+        code (`str`): The code example to format.
+        max_len (`int`): The maximum length per line.
+        in_docstring (`bool`, *optional*, defaults to `False`): Whether or not the code example is inside a docstring.
+
+    Returns:
+        `Tuple[str, str]`: The formatted code and any error message reported by black.
+    """
+    code_lines = code.split("\n")
+
+    # Find initial indent
+    idx = 0
+    while idx < len(code_lines) and is_empty_line(code_lines[idx]):
+        idx += 1
+    if idx >= len(code_lines):
+        return "", ""
+    indent = find_indent(code_lines[idx])
+
+    # Remove the initial indent for now, we will add it back after styling.
+    # Note that l[indent:] works for empty lines
+    code_lines = [l[indent:] for l in code_lines[idx:]]
+    has_doctest = code_lines[0][:3] in DOCTEST_PROMPTS
+
+    code_samples, outputs = parse_code_example(code_lines)
+
+    # Let's blackify the code! We put everything in one big text to go faster.
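+    # The samples are joined with a delimiter that should never occur in real
+    # code, formatted by black in a single pass, then split back apart on the
+    # same delimiter further down.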
+ delimiter = "\n\n### New code sample ###\n" + full_code = delimiter.join(code_samples) + line_length = max_len - indent + if has_doctest: + line_length -= 4 + + for k, v in BLACK_AVOID_PATTERNS.items(): + full_code = full_code.replace(k, v) + try: + mode = black.Mode(target_versions={black.TargetVersion.PY37}, line_length=line_length) + formatted_code = black.format_str(full_code, mode=mode) + error = "" + except Exception as e: + formatted_code = full_code + error = f"Code sample:\n{full_code}\n\nError message:\n{e}" + + # Let's get back the formatted code samples + for k, v in BLACK_AVOID_PATTERNS.items(): + formatted_code = formatted_code.replace(v, k) + # Triple quotes will mess docstrings. + if in_docstring: + formatted_code = formatted_code.replace('"""', "'''") + + code_samples = formatted_code.split(delimiter) + # We can have one output less than code samples + if len(outputs) == len(code_samples) - 1: + outputs.append("") + + formatted_lines = [] + for code_sample, output in zip(code_samples, outputs): + # black may have added some new lines, we remove them + code_sample = code_sample.strip() + in_triple_quotes = False + in_decorator = False + for line in code_sample.strip().split("\n"): + if has_doctest and not is_empty_line(line): + prefix = ( + "... " + if line.startswith(" ") or line in [")", "]", "}"] or in_triple_quotes or in_decorator + else ">>> " + ) + else: + prefix = "" + indent_str = "" if is_empty_line(line) else (" " * indent) + formatted_lines.append(indent_str + prefix + line) + + if '"""' in line: + in_triple_quotes = not in_triple_quotes + if line.startswith(" "): + in_decorator = False + if line.startswith("@"): + in_decorator = True + + formatted_lines.extend([" " * indent + line for line in output.split("\n")]) + if not output.endswith("===PT-TF-SPLIT==="): + formatted_lines.append("") + + result = "\n".join(formatted_lines) + return result.rstrip(), error + + +def format_text(text, max_len, prefix="", min_indent=None): + """ + Format a text in the biggest lines possible with the constraint of a maximum length and an indentation. + + Args: + text (`str`): The text to format + max_len (`int`): The maximum length per line to use + prefix (`str`, *optional*, defaults to `""`): A prefix that will be added to the text. + The prefix doesn't count toward the indent (like a - introducing a list). + min_indent (`int`, *optional*): The minimum indent of the text. + If not set, will default to the length of the `prefix`. + + Returns: + `str`: The formatted text. + """ + text = re.sub(r"\s+", " ", text) + if min_indent is not None: + if len(prefix) < min_indent: + prefix = " " * (min_indent - len(prefix)) + prefix + + indent = " " * len(prefix) + new_lines = [] + words = text.split(" ") + current_line = f"{prefix}{words[0]}" + for word in words[1:]: + try_line = f"{current_line} {word}" + if len(try_line) > max_len: + new_lines.append(current_line) + current_line = f"{indent}{word}" + else: + current_line = try_line + new_lines.append(current_line) + return "\n".join(new_lines) + + +def split_line_on_first_colon(line): + splits = line.split(":") + return splits[0], ":".join(splits[1:]) + + +def style_docstring(docstring, max_len): + """ + Style a docstring by making sure there is no useless whitespace and the maximum horizontal space is used. + + Args: + docstring (`str`): The docstring to style. + max_len (`int`): The maximum length of each line. 
+ + Returns: + `str`: The styled docstring + """ + lines = docstring.split("\n") + new_lines = [] + + # Initialization + current_paragraph = None + current_indent = -1 + in_code = False + param_indent = -1 + prefix = "" + black_errors = [] + + # Special case for docstrings that begin with continuation of Args with no Args block. + idx = 0 + while idx < len(lines) and is_empty_line(lines[idx]): + idx += 1 + if ( + len(lines[idx]) > 1 + and lines[idx].rstrip().endswith(":") + and find_indent(lines[idx + 1]) > find_indent(lines[idx]) + ): + param_indent = find_indent(lines[idx]) + + for idx, line in enumerate(lines): + # Doing all re searches once for the one we need to repeat. + list_search = _re_list.search(line) + code_search = _re_code.search(line) + + # Are we starting a new paragraph? + # New indentation or new line: + new_paragraph = find_indent(line) != current_indent or is_empty_line(line) + # List item + new_paragraph = new_paragraph or list_search is not None + # Code block beginning + new_paragraph = new_paragraph or code_search is not None + # Beginning/end of tip + new_paragraph = new_paragraph or _re_tip.search(line) + + # In this case, we treat the current paragraph + if not in_code and new_paragraph and current_paragraph is not None and len(current_paragraph) > 0: + paragraph = " ".join(current_paragraph) + new_lines.append(format_text(paragraph, max_len, prefix=prefix, min_indent=current_indent)) + current_paragraph = None + + if code_search is not None: + if not in_code: + current_paragraph = [] + current_indent = len(code_search.groups()[0]) + current_code = code_search.groups()[1] + prefix = "" + if current_indent < param_indent: + param_indent = -1 + else: + current_indent = -1 + code = "\n".join(current_paragraph) + if current_code in ["py", "python"]: + formatted_code, error = format_code_example(code, max_len, in_docstring=True) + new_lines.append(formatted_code) + if len(error) > 0: + black_errors.append(error) + else: + new_lines.append(code) + current_paragraph = None + new_lines.append(line) + in_code = not in_code + + elif in_code: + current_paragraph.append(line) + elif is_empty_line(line): + current_paragraph = None + current_indent = -1 + prefix = "" + new_lines.append(line) + elif list_search is not None: + prefix = list_search.groups()[0] + current_indent = len(prefix) + current_paragraph = [line[current_indent:]] + elif _re_args.search(line): + new_lines.append(line) + param_indent = find_indent(lines[idx + 1]) + elif _re_tip.search(line): + # Add a new line before if not present + if not is_empty_line(new_lines[-1]): + new_lines.append("") + new_lines.append(line) + # Add a new line after if not present + if idx < len(lines) - 1 and not is_empty_line(lines[idx + 1]): + new_lines.append("") + elif current_paragraph is None or find_indent(line) != current_indent: + indent = find_indent(line) + # Special behavior for parameters intros. + if indent == param_indent: + # Special rules for some docstring where the Returns blocks has the same indent as the parameters. 
+ if _re_returns.search(line) is not None: + param_indent = -1 + new_lines.append(line) + elif len(line) < max_len: + new_lines.append(line) + else: + intro, description = split_line_on_first_colon(line) + new_lines.append(intro + ":") + if len(description) != 0: + if find_indent(lines[idx + 1]) > indent: + current_indent = find_indent(lines[idx + 1]) + else: + current_indent = indent + 4 + current_paragraph = [description.strip()] + prefix = "" + else: + # Check if we have exited the parameter block + if indent < param_indent: + param_indent = -1 + + current_paragraph = [line.strip()] + current_indent = find_indent(line) + prefix = "" + elif current_paragraph is not None: + current_paragraph.append(line.lstrip()) + + if current_paragraph is not None and len(current_paragraph) > 0: + paragraph = " ".join(current_paragraph) + new_lines.append(format_text(paragraph, max_len, prefix=prefix, min_indent=current_indent)) + + return "\n".join(new_lines), "\n\n".join(black_errors) + + +def style_docstrings_in_code(code, max_len=119): + """ + Style all docstrings in some code. + + Args: + code (`str`): The code in which we want to style the docstrings. + max_len (`int`): The maximum number of characters per line. + + Returns: + `Tuple[str, str]`: A tuple with the clean code and the black errors (if any) + """ + # fmt: off + splits = code.split('\"\"\"') + splits = [ + (s if i % 2 == 0 or _re_doc_ignore.search(splits[i - 1]) is not None else style_docstring(s, max_len=max_len)) + for i, s in enumerate(splits) + ] + black_errors = "\n\n".join([s[1] for s in splits if isinstance(s, tuple) and len(s[1]) > 0]) + splits = [s[0] if isinstance(s, tuple) else s for s in splits] + clean_code = '\"\"\"'.join(splits) + # fmt: on + + return clean_code, black_errors + + +def style_file_docstrings(code_file, max_len=119, check_only=False): + """ + Style all docstrings in a given file. + + Args: + code_file (`str` or `os.PathLike`): The file in which we want to style the docstring. + max_len (`int`): The maximum number of characters per line. + check_only (`bool`, *optional*, defaults to `False`): + Whether to restyle file or just check if they should be restyled. + + Returns: + `bool`: Whether or not the file was or should be restyled. + """ + with open(code_file, "r", encoding="utf-8", newline="\n") as f: + code = f.read() + + clean_code, black_errors = style_docstrings_in_code(code, max_len=max_len) + + diff = clean_code != code + if not check_only and diff: + print(f"Overwriting content of {code_file}.") + with open(code_file, "w", encoding="utf-8", newline="\n") as f: + f.write(clean_code) + + return diff, black_errors + + +def style_mdx_file(mdx_file, max_len=119, check_only=False): + """ + Style a MDX file by formatting all Python code samples. + + Args: + mdx_file (`str` or `os.PathLike`): The file in which we want to style the examples. + max_len (`int`): The maximum number of characters per line. + check_only (`bool`, *optional*, defaults to `False`): + Whether to restyle file or just check if they should be restyled. + + Returns: + `bool`: Whether or not the file was or should be restyled. 
+ """ + with open(mdx_file, "r", encoding="utf-8", newline="\n") as f: + content = f.read() + + lines = content.split("\n") + current_code = [] + current_language = "" + in_code = False + new_lines = [] + black_errors = [] + + for line in lines: + if _re_code.search(line) is not None: + in_code = not in_code + if in_code: + current_language = _re_code.search(line).groups()[1] + current_code = [] + else: + code = "\n".join(current_code) + if current_language in ["py", "python"]: + code, error = format_code_example(code, max_len) + if len(error) > 0: + black_errors.append(error) + new_lines.append(code) + + new_lines.append(line) + elif in_code: + current_code.append(line) + else: + new_lines.append(line) + + if in_code: + raise ValueError(f"There was a problem when styling {mdx_file}. A code block is opened without being closed.") + + clean_content = "\n".join(new_lines) + diff = clean_content != content + if not check_only and diff: + print(f"Overwriting content of {mdx_file}.") + with open(mdx_file, "w", encoding="utf-8", newline="\n") as f: + f.write(clean_content) + + return diff, "\n\n".join(black_errors) + + +def style_doc_files(*files, max_len=119, check_only=False): + """ + Applies doc styling or checks everything is correct in a list of files. + + Args: + files (several `str` or `os.PathLike`): The files to treat. + max_len (`int`): The maximum number of characters per line. + check_only (`bool`, *optional*, defaults to `False`): + Whether to restyle file or just check if they should be restyled. + + Returns: + List[`str`]: The list of files changed or that should be restyled. + """ + changed = [] + black_errors = [] + for file in files: + # Treat folders + if os.path.isdir(file): + files = [os.path.join(file, f) for f in os.listdir(file)] + files = [f for f in files if os.path.isdir(f) or f.endswith(".mdx") or f.endswith(".py")] + changed += style_doc_files(*files, max_len=max_len, check_only=check_only) + # Treat mdx + elif file.endswith(".mdx"): + try: + diff, black_error = style_mdx_file(file, max_len=max_len, check_only=check_only) + if diff: + changed.append(file) + if len(black_error) > 0: + black_errors.append( + f"There was a problem while formatting an example in {file} with black:\m{black_error}" + ) + except Exception: + print(f"There is a problem in {file}.") + raise + # Treat python files + elif file.endswith(".py"): + try: + diff, black_error = style_file_docstrings(file, max_len=max_len, check_only=check_only) + if diff: + changed.append(file) + if len(black_error) > 0: + black_errors.append( + f"There was a problem while formatting an example in {file} with black:\m{black_error}" + ) + except Exception: + print(f"There is a problem in {file}.") + raise + else: + warnings.warn(f"Ignoring {file} because it's not a py or an mdx file or a folder.") + if len(black_errors) > 0: + black_message = "\n\n".join(black_errors) + raise ValueError( + "Some code examples can't be interpreted by black, which means they aren't regular python:\n\n" + + black_message + + "\n\nMake sure to fix the corresponding docstring or doc file, or remove the py/python after ``` if it " + + "was not supposed to be a Python code sample." 
+        )
+    return changed
+
+
+def main(*files, max_len=119, check_only=False):
+    changed = style_doc_files(*files, max_len=max_len, check_only=check_only)
+    if check_only and len(changed) > 0:
+        raise ValueError(f"{len(changed)} files should be restyled!")
+    elif len(changed) > 0:
+        print(f"Cleaned {len(changed)} files!")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("files", nargs="+", help="The file(s) or folder(s) to restyle.")
+    parser.add_argument("--max_len", type=int, default=119, help="The maximum length of lines.")
+    parser.add_argument("--check_only", action="store_true", help="Whether to only check and not fix styling issues.")
+    args = parser.parse_args()
+
+    main(*args.files, max_len=args.max_len, check_only=args.check_only)
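For orientation, here is a minimal usage sketch of the relocated PrefixEncoder under the new src/pet layout. The import path follows the new src/pet/__init__.py exports, and the config keys (num_virtual_tokens, token_dim, num_layers, prompt_hidden_size) are the ones read by the PrefixEncoder code in this patch; the prefix_projection flag, the concrete values, and the expected output shape are assumptions, since the branch condition and the forward body are elided above.

    import torch

    from pet import PrefixEncoder

    # Config keys mirror those read in PrefixEncoder.__init__ above; the values
    # are illustrative. "prefix_projection" is an assumed name for the flag
    # guarding the two-layer reparameterization branch (its condition is
    # elided in this diff).
    config = {
        "num_virtual_tokens": 20,
        "token_dim": 768,
        "num_layers": 12,
        "prompt_hidden_size": 512,
        "prefix_projection": False,
    }

    encoder = PrefixEncoder(config)
    # One batch of virtual-token indices, shape (1, num_virtual_tokens)
    prefix = torch.arange(config["num_virtual_tokens"]).unsqueeze(0)
    # Expected: past key/value states of shape
    # (1, num_virtual_tokens, num_layers * 2 * token_dim)
    past_key_values = encoder(prefix)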