diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000..da99824aa3
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,141 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# VSCode
+.vscode
+
+# IntelliJ
+.idea
+
+# Mac .DS_Store
+.DS_Store
+
+# Weights & Biases logs
+wandb/
\ No newline at end of file
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000000..1aba38f67a
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1 @@
+include LICENSE
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000000..e1c15c5959
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,19 @@
+.PHONY: quality style
+
+check_dirs := src
+
+# Check that source code meets quality standards
+
+# this target runs checks on all files
+quality:
+ black --check $(check_dirs)
+ isort --check-only $(check_dirs)
+ flake8 $(check_dirs)
+ python utils/style_doc.py src --max_len 119 --check_only
+
+# Format source code automatically and check if there are any problems left that need manual fixing
+style:
+ black $(check_dirs)
+ isort $(check_dirs)
+ python utils/style_doc.py src --max_len 119
+
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000..b7465bb131
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,3 @@
+[tool.black]
+line-length = 119
+target-version = ['py37']
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000000..6b26312126
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,23 @@
+[isort]
+default_section = FIRSTPARTY
+ensure_newline_before_comments = True
+force_grid_wrap = 0
+include_trailing_comma = True
+known_first_party = pet
+known_third_party =
+ numpy
+ torch
+ accelerate
+ transformers
+
+line_length = 119
+lines_after_imports = 2
+multi_line_output = 3
+use_parentheses = True
+
+[flake8]
+ignore = E203, E722, E501, E741, W503, W605
+max-line-length = 119
+
+[tool:pytest]
+doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000..9c8edd91d3
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,78 @@
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from setuptools import setup
+from setuptools import find_packages
+
+extras = {}
+extras["quality"] = ["black ~= 22.0", "isort >= 5.5.4", "flake8 >= 3.8.3"]
+extras["dev"] = extras["quality"]
+
+setup(
+ name="pets",
+ version="0.1.0.dev0",
+ description="Parameter-Efficient Tuning at Scale (PETS)",
+ long_description=open("README.md", "r", encoding="utf-8").read(),
+ long_description_content_type="text/markdown",
+ keywords="deep learning",
+ license="Apache",
+ author="The HuggingFace team",
+ author_email="sourab@huggingface.co",
+ url="https://github.com/huggingface/pets",
+ package_dir={"": "src"},
+ packages=find_packages("src"),
+ entry_points={},
+ python_requires=">=3.7.0",
+ install_requires=[
+ "numpy>=1.17",
+ "packaging>=20.0",
+ "psutil",
+ "pyyaml",
+ "torch>=1.4.0",
+ "transformers",
+ "accelerate",
+ ],
+ extras_require=extras,
+ classifiers=[
+ "Development Status :: 5 - Production/Stable",
+ "Intended Audience :: Developers",
+ "Intended Audience :: Education",
+ "Intended Audience :: Science/Research",
+ "License :: OSI Approved :: Apache Software License",
+ "Operating System :: OS Independent",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.7",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ ],
+)
+
+# Release checklist
+# 1. Change the version in __init__.py and setup.py.
+# 2. Commit these changes with the message: "Release: VERSION"
+# 3. Add a tag in git to mark the release: "git tag VERSION -m 'Adds tag VERSION for pypi' "
+# Push the tag to git: git push --tags origin main
+# 4. Run the following commands in the top-level directory:
+# python setup.py bdist_wheel
+# python setup.py sdist
+# 5. Upload the package to the pypi test server first:
+# twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
+# 6. Check that you can install it in a virtualenv by running:
+# pip install -i https://testpypi.python.org/pypi pets
+# python -c "import pet"
+# 7. Upload the final version to actual pypi:
+# twine upload dist/* -r pypi
+# 8. Add release notes to the tag in github once everything is looking hunky-dory.
+# 9. Update the version in __init__.py, setup.py to the new version "-dev" and push to main.
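For step 6 of the checklist, a minimal smoke test of the freshly installed package could look like this (a sketch; it only assumes `__version__` is exported from `src/pet/__init__.py`, as the next hunk shows):

```python
# Post-install smoke test (sketch): run inside the virtualenv from step 6.
import pet

# The dev version string set in src/pet/__init__.py.
assert pet.__version__ == "0.1.0.dev0", pet.__version__
print("pet", pet.__version__, "imported OK")
```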
diff --git a/src/pet/__init__.py b/src/pet/__init__.py
new file mode 100644
index 0000000000..2436302799
--- /dev/null
+++ b/src/pet/__init__.py
@@ -0,0 +1,18 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module while preserving other warnings, so don't check this module at all.
+
+__version__ = "0.1.0.dev0"
+
+from .pet_model import (
+ ParameterEfficientTuningModel,
+ ParameterEfficientTuningModelForSequenceClassification,
+ PromptEncoderType,
+)
+from .tuners import (
+ PrefixEncoder,
+ PromptEmbedding,
+ PromptEncoder,
+ PromptEncoderReparameterizationType,
+ PromptTuningInit,
+)
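With these re-exports in place, downstream code can import everything from the `pet` namespace instead of reaching into submodules. A hedged usage sketch (assumes an editable install, e.g. `pip install -e .`; only the names come from this hunk):

```python
# Import sketch: all names below are re-exported by src/pet/__init__.py.
from pet import (
    ParameterEfficientTuningModel,
    ParameterEfficientTuningModelForSequenceClassification,
    PromptEncoderType,
)
from pet.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder

# Enum members are defined in pet_model.py; just show what's available.
print(list(PromptEncoderType))
```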
diff --git a/src/pet.py b/src/pet/pet_model.py
similarity index 97%
rename from src/pet.py
rename to src/pet/pet_model.py
index 7ac208b27d..9345539e51 100644
--- a/src/pet.py
+++ b/src/pet/pet_model.py
@@ -1,12 +1,14 @@
-from collections import OrderedDict
import enum
import warnings
+from collections import OrderedDict
+
import torch
+from accelerate.state import AcceleratorState
from transformers import PreTrainedModel
+
from tuners.p_tuning import PromptEncoder
from tuners.prefix_tuning import PrefixEncoder
from tuners.prompt_tuning import PromptEmbedding
-from accelerate.state import AcceleratorState
class PromptEncoderType(str, enum.Enum):
@@ -88,8 +90,8 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False):
def load_state_dict(self, state_dict, strict: bool = True):
"""
- Custom load state dict method that only loads prompt table and prompt encoder
- parameters. Matching load method for this class' custom state dict method.
+ Custom load state dict method that only loads prompt table and prompt encoder parameters. Matching load method
+ for this class' custom state dict method.
"""
self.prompt_encoder.embedding.load_state_dict({"weight": state_dict["prompt_embeddings"]}, strict)
@@ -187,8 +189,8 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False):
def load_state_dict(self, state_dict, strict: bool = True):
"""
- Custom load state dict method that only loads prompt table and prompt encoder
- parameters. Matching load method for this class' custom state dict method.
+ Custom load state dict method that only loads prompt table and prompt encoder parameters. Matching load method
+ for this class' custom state dict method.
"""
super().load_state_dict(state_dict["prompt_encoder"], strict)
self.model.classifier.load_state_dict(state_dict["classifier"], strict)
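Since `state_dict`/`load_state_dict` are overridden to cover only the prompt parameters, checkpoints stay small relative to the frozen backbone. A sketch of the round trip (the model construction itself is outside this diff; the prompt-only behavior is taken from the hunk):

```python
import torch


def save_prompt_checkpoint(model, path="prompt_checkpoint.pt"):
    # `model` is assumed to be a ParameterEfficientTuningModel; its overridden
    # state_dict() returns only the prompt table / prompt encoder weights.
    torch.save(model.state_dict(), path)


def load_prompt_checkpoint(model, path="prompt_checkpoint.pt"):
    # The matching overridden load_state_dict() restores just those parameters.
    model.load_state_dict(torch.load(path))
```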
diff --git a/src/prompt_learning_legacy.py b/src/pet/prompt_learning_legacy.py
similarity index 91%
rename from src/prompt_learning_legacy.py
rename to src/pet/prompt_learning_legacy.py
index ffb6934396..b08d01851d 100644
--- a/src/prompt_learning_legacy.py
+++ b/src/pet/prompt_learning_legacy.py
@@ -1,34 +1,27 @@
import enum
-import torch
+import functools
import math
import os
+from collections import OrderedDict
-from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
-from transformers import PreTrainedModel
-from transformers.modeling_outputs import SequenceClassifierOutput
-from transformers import AutoModelForSequenceClassification
-from datasets import load_dataset
-import evaluate
import torch
-from transformers import AutoTokenizer, get_linear_schedule_with_warmup, set_seed
-from torch.utils.data import DataLoader
from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin
-import functools
-from torch.distributed.fsdp import (
- FullyShardedDataParallel,
- CPUOffload,
-)
-from torch.distributed.fsdp.wrap import (
- enable_wrap,
- wrap,
- ModuleWrapPolicy,
- transformer_auto_wrap_policy,
- lambda_auto_wrap_policy,
- _or_policy,
+from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from torch.utils.data import DataLoader
+from transformers import (
+ AutoModelForSequenceClassification,
+ AutoTokenizer,
+ PreTrainedModel,
+ get_linear_schedule_with_warmup,
+ set_seed,
)
-from collections import OrderedDict
+from transformers.modeling_outputs import SequenceClassifierOutput
+
+import evaluate
+from datasets import load_dataset
class PromptEncoderReparameterizationType(str, enum.Enum):
@@ -49,8 +42,7 @@ class PromptTuningInit(str, enum.Enum):
class PromptEncoder(torch.nn.Module):
"""
- The prompt encoder network that is used to generate the virtual
- token embeddings for p-tuning.
+ The prompt encoder network that is used to generate the virtual token embeddings for p-tuning.
"""
def __init__(self, config):
@@ -92,13 +84,23 @@ def __init__(self, config):
)
elif self.encoder_type == PromptEncoderReparameterizationType.MLP:
- layers = [torch.nn.Linear(self.input_size, self.hidden_size), torch.nn.ReLU()]
- layers.extend([torch.nn.Linear(self.hidden_size, self.hidden_size), torch.nn.ReLU()])
+ layers = [
+ torch.nn.Linear(self.input_size, self.hidden_size),
+ torch.nn.ReLU(),
+ ]
+ layers.extend(
+ [
+ torch.nn.Linear(self.hidden_size, self.hidden_size),
+ torch.nn.ReLU(),
+ ]
+ )
layers.append(torch.nn.Linear(self.hidden_size, self.output_size))
self.mlp_head = torch.nn.Sequential(*layers)
else:
- raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.")
+ raise ValueError(
+ "Prompt encoder type not recognized. " " Please use one of MLP (recommended) or LSTM."
+ )
def forward(self, indices):
input_embeds = self.embedding(indices)
@@ -130,11 +132,15 @@ def __init__(self, config):
self.trans = torch.nn.Sequential(
torch.nn.Linear(config["token_dim"], config["prompt_hidden_size"]),
torch.nn.Tanh(),
- torch.nn.Linear(config["prompt_hidden_size"], config["num_layers"] * 2 * config["token_dim"]),
+ torch.nn.Linear(
+ config["prompt_hidden_size"],
+ config["num_layers"] * 2 * config["token_dim"],
+ ),
)
else:
self.embedding = torch.nn.Embedding(
- config["num_virtual_tokens"], config["num_layers"] * 2 * config["token_dim"]
+ config["num_virtual_tokens"],
+ config["num_layers"] * 2 * config["token_dim"],
)
def forward(self, prefix: torch.Tensor):
@@ -247,8 +253,8 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False):
def load_state_dict(self, state_dict, strict: bool = True):
"""
- Custom load state dict method that only loads prompt table and prompt encoder
- parameters. Matching load method for this class' custom state dict method.
+ Custom load state dict method that only loads prompt table and prompt encoder parameters. Matching load method
+ for this class' custom state dict method.
"""
self.prompt_encoder.embedding.load_state_dict({"weight": state_dict["prompt_embeddings"]}, strict)
@@ -389,8 +395,8 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False):
def load_state_dict(self, state_dict, strict: bool = True):
"""
- Custom load state dict method that only loads prompt table and prompt encoder
- parameters. Matching load method for this class' custom state dict method.
+ Custom load state dict method that only loads prompt table and prompt encoder parameters. Matching load method
+ for this class' custom state dict method.
"""
super().load_state_dict(state_dict["prompt_encoder"], strict)
self.classifier.load_state_dict(state_dict["classifier"], strict)
@@ -528,7 +534,7 @@ def main():
batch_size = 16
lr = 5e-3
num_epochs = 100
- device = "cuda"
+ # device = "cuda"
seed = 11
set_seed(seed)
@@ -544,7 +550,12 @@ def main():
def tokenize_function(examples):
# max_length=None => use the model max length (it's actually the default)
- outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
+ outputs = tokenizer(
+ examples["sentence1"],
+ examples["sentence2"],
+ truncation=True,
+ max_length=None,
+ )
return outputs
# Apply the method we just defined to all the examples in all the splits of the dataset
@@ -564,10 +575,16 @@ def collate_fn(examples):
# Instantiate dataloaders.
train_dataloader = DataLoader(
- tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size
+ tokenized_datasets["train"],
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=batch_size,
)
eval_dataloader = DataLoader(
- tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size
+ tokenized_datasets["validation"],
+ shuffle=False,
+ collate_fn=collate_fn,
+ batch_size=batch_size,
)
# Instantiate optimizer
@@ -582,9 +599,13 @@ def collate_fn(examples):
accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)
- model, train_dataloader, eval_dataloader, optimizer, lr_scheduler = accelerator.prepare(
- model, train_dataloader, eval_dataloader, optimizer, lr_scheduler
- )
+ (
+ model,
+ train_dataloader,
+ eval_dataloader,
+ optimizer,
+ lr_scheduler,
+ ) = accelerator.prepare(model, train_dataloader, eval_dataloader, optimizer, lr_scheduler)
accelerator.print(model)
for epoch in range(num_epochs):
@@ -616,17 +637,14 @@ def collate_fn(examples):
accelerator.print(f"epoch {epoch}:", eval_metric)
accelerator.print(f"epoch {epoch} train loss:", total_loss / len(train_dataloader))
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
- from torch.distributed.fsdp.fully_sharded_data_parallel import (
- BackwardPrefetch,
- CPUOffload,
- FullStateDictConfig,
- ShardingStrategy,
- StateDictType,
- )
+ from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
FSDP.set_state_dict_type(
- model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+ model,
+ StateDictType.FULL_STATE_DICT,
+ FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
)
state_dict = model.state_dict()
state_dict = model.clean_state_dict(state_dict)
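The training loop above calls `fsdp_auto_wrap_policy(model)` without showing its body. Judging by the imports this hunk keeps (`_or_policy`, `lambda_auto_wrap_policy`, `transformer_auto_wrap_policy`), a plausible sketch wraps trainable leaf modules (the prompt encoder, classifier head) separately from the frozen transformer blocks; the transformer layer class below is an assumed placeholder, not something this diff pins down:

```python
import functools

from torch.distributed.fsdp.wrap import (
    _or_policy,
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)


def fsdp_auto_wrap_policy(model):
    # Wrap any leaf module that still has trainable weights (e.g. the prompt
    # encoder or classifier head) in its own FSDP unit.
    def lambda_policy_fn(module):
        return (
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        )

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)

    # Assumed placeholder: the layer class to shard on depends on the backbone.
    from transformers.models.bert.modeling_bert import BertLayer

    transformer_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={BertLayer})
    return functools.partial(_or_policy, policies=[lambda_policy, transformer_policy])
```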
diff --git a/src/pet/tuners/__init__.py b/src/pet/tuners/__init__.py
new file mode 100644
index 0000000000..41ae36c5c8
--- /dev/null
+++ b/src/pet/tuners/__init__.py
@@ -0,0 +1,7 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module while preserving other warnings, so don't check this module at all.
+
+from .p_tuning import PromptEncoder, PromptEncoderReparameterizationType
+from .prefix_tuning import PrefixEncoder
+from .prompt_tuning import PromptEmbedding, PromptTuningInit
diff --git a/src/tuners/lora.py b/src/pet/tuners/lora.py
similarity index 100%
rename from src/tuners/lora.py
rename to src/pet/tuners/lora.py
diff --git a/src/tuners/p_tuning.py b/src/pet/tuners/p_tuning.py
similarity index 87%
rename from src/tuners/p_tuning.py
rename to src/pet/tuners/p_tuning.py
index 2c103c3f5f..5b161c39a2 100644
--- a/src/tuners/p_tuning.py
+++ b/src/pet/tuners/p_tuning.py
@@ -1,6 +1,7 @@
-import torch
import enum
+import torch
+
class PromptEncoderReparameterizationType(str, enum.Enum):
MLP = "MLP"
@@ -11,8 +12,7 @@ class PromptEncoderReparameterizationType(str, enum.Enum):
# with some refactor
class PromptEncoder(torch.nn.Module):
"""
- The prompt encoder network that is used to generate the virtual
- token embeddings for p-tuning.
+ The prompt encoder network that is used to generate the virtual token embeddings for p-tuning.
"""
def __init__(self, config):
@@ -54,8 +54,16 @@ def __init__(self, config):
)
elif self.encoder_type == PromptEncoderReparameterizationType.MLP:
- layers = [torch.nn.Linear(self.input_size, self.hidden_size), torch.nn.ReLU()]
- layers.extend([torch.nn.Linear(self.hidden_size, self.hidden_size), torch.nn.ReLU()])
+ layers = [
+ torch.nn.Linear(self.input_size, self.hidden_size),
+ torch.nn.ReLU(),
+ ]
+ layers.extend(
+ [
+ torch.nn.Linear(self.hidden_size, self.hidden_size),
+ torch.nn.ReLU(),
+ ]
+ )
layers.append(torch.nn.Linear(self.hidden_size, self.output_size))
self.mlp_head = torch.nn.Sequential(*layers)
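For reference, the MLP branch above reduces to a five-layer stack mapping virtual-token embeddings to the backbone's hidden size. A standalone shape check with assumed sizes (this hunk does not show which config keys feed `input_size`, `hidden_size`, and `output_size`):

```python
import torch

# Assumed sizes, for illustration only.
input_size, hidden_size, output_size = 768, 256, 768
num_virtual_tokens = 20

# Same structure the MLP branch builds via `layers`.
mlp_head = torch.nn.Sequential(
    torch.nn.Linear(input_size, hidden_size),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, hidden_size),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, output_size),
)

embedding = torch.nn.Embedding(num_virtual_tokens, input_size)
indices = torch.arange(num_virtual_tokens).unsqueeze(0)  # shape (1, 20)
print(mlp_head(embedding(indices)).shape)  # torch.Size([1, 20, 768])
```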
diff --git a/src/tuners/prefix_tuning.py b/src/pet/tuners/prefix_tuning.py
similarity index 81%
rename from src/tuners/prefix_tuning.py
rename to src/pet/tuners/prefix_tuning.py
index c593b1eff5..761545f55b 100644
--- a/src/tuners/prefix_tuning.py
+++ b/src/pet/tuners/prefix_tuning.py
@@ -1,5 +1,6 @@
import torch
+
# Based on https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py
# with some refactor
class PrefixEncoder(torch.nn.Module):
@@ -20,11 +21,15 @@ def __init__(self, config):
self.trans = torch.nn.Sequential(
torch.nn.Linear(config["token_dim"], config["prompt_hidden_size"]),
torch.nn.Tanh(),
- torch.nn.Linear(config["prompt_hidden_size"], config["num_layers"] * 2 * config["token_dim"]),
+ torch.nn.Linear(
+ config["prompt_hidden_size"],
+ config["num_layers"] * 2 * config["token_dim"],
+ ),
)
else:
self.embedding = torch.nn.Embedding(
- config["num_virtual_tokens"], config["num_layers"] * 2 * config["token_dim"]
+ config["num_virtual_tokens"],
+ config["num_layers"] * 2 * config["token_dim"],
)
def forward(self, prefix: torch.Tensor):
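Both branches of `PrefixEncoder` emit `num_layers * 2 * token_dim` values per virtual token, i.e. one key and one value vector for every layer. A sketch of the non-projected branch with assumed config values:

```python
import torch

# Assumed config values, for illustration only.
config = {"num_virtual_tokens": 20, "token_dim": 768, "num_layers": 12}

# Non-projected branch from the hunk: one embedding row per virtual token,
# holding the flattened per-layer key/value prefix.
embedding = torch.nn.Embedding(
    config["num_virtual_tokens"],
    config["num_layers"] * 2 * config["token_dim"],
)
prefix = torch.arange(config["num_virtual_tokens"]).unsqueeze(0)  # (1, 20)
past_key_values = embedding(prefix)
print(past_key_values.shape)  # torch.Size([1, 20, 18432]) == (1, 20, 12 * 2 * 768)
```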
diff --git a/src/tuners/prompt_tuning.py b/src/pet/tuners/prompt_tuning.py
similarity index 99%
rename from src/tuners/prompt_tuning.py
rename to src/pet/tuners/prompt_tuning.py
index cb4fc8df0a..c3bfed7ece 100644
--- a/src/tuners/prompt_tuning.py
+++ b/src/pet/tuners/prompt_tuning.py
@@ -1,7 +1,8 @@
-import torch
import enum
import math
+import torch
+
class PromptTuningInit(str, enum.Enum):
TEXT = "TEXT"
diff --git a/src/utils/constants.py b/src/pet/utils/constants.py
similarity index 100%
rename from src/utils/constants.py
rename to src/pet/utils/constants.py
diff --git a/utils/style_doc.py b/utils/style_doc.py
new file mode 100644
index 0000000000..0422ebeb4e
--- /dev/null
+++ b/utils/style_doc.py
@@ -0,0 +1,556 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Style utils for the .rst and the docstrings."""
+
+import argparse
+import os
+import re
+import warnings
+
+import black
+
+
+BLACK_AVOID_PATTERNS = {}
+
+
+# Regexes
+# Re pattern that catches list introduction (with potential indent)
+_re_list = re.compile(r"^(\s*-\s+|\s*\*\s+|\s*\d+\.\s+)")
+# Re pattern that catches code block introduction (with potential indent)
+_re_code = re.compile(r"^(\s*)```(.*)$")
+# Re pattern that catches rst args blocks of the form `Parameters:`.
+_re_args = re.compile(r"^\s*(Args?|Arguments?|Params?|Parameters?):\s*$")
+# Re pattern that catches return blocks of the form `Return:`.
+_re_returns = re.compile(r"^\s*Returns?:\s*$")
+# Matches the special tag to ignore some paragraphs.
+_re_doc_ignore = re.compile(r"(\.\.|#)\s*docstyle-ignore")
+# Re pattern that matches <Tip>, </Tip> and <Tip warning={true}> blocks.
+_re_tip = re.compile(r"^\s*</?Tip(>|\s+warning={true}>)\s*$")
+
+DOCTEST_PROMPTS = [">>>", "..."]
+
+
+def is_empty_line(line):
+ return len(line) == 0 or line.isspace()
+
+
+def find_indent(line):
+ """
+ Returns the number of spaces that start a line indent.
+ """
+ search = re.search(r"^(\s*)(?:\S|$)", line)
+ if search is None:
+ return 0
+ return len(search.groups()[0])
+
+
+def parse_code_example(code_lines):
+ """
+ Parses a code example
+
+ Args:
+ code_lines (`List[str]`): The code lines to parse.
+
+ Returns:
+ (List[`str`], List[`str`]): The list of code samples and the list of outputs.
+ """
+ has_doctest = code_lines[0][:3] in DOCTEST_PROMPTS
+
+ code_samples = []
+ outputs = []
+ in_code = True
+ current_bit = []
+
+ for line in code_lines:
+ if in_code and has_doctest and not is_empty_line(line) and line[:3] not in DOCTEST_PROMPTS:
+ code_sample = "\n".join(current_bit)
+ code_samples.append(code_sample.strip())
+ in_code = False
+ current_bit = []
+ elif not in_code and line[:3] in DOCTEST_PROMPTS:
+ output = "\n".join(current_bit)
+ outputs.append(output.strip())
+ in_code = True
+ current_bit = []
+
+ # Add the line without doctest prompt
+ if line[:3] in DOCTEST_PROMPTS:
+ line = line[4:]
+ current_bit.append(line)
+
+ # Add last sample
+ if in_code:
+ code_sample = "\n".join(current_bit)
+ code_samples.append(code_sample.strip())
+ else:
+ output = "\n".join(current_bit)
+ outputs.append(output.strip())
+
+ return code_samples, outputs
+
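+# Example (sketch, hand-checked against the logic above): doctest-style lines
+# are split into code samples and their outputs.
+#
+#     >>> parse_code_example([">>> print(1 + 1)", "2"])
+#     (['print(1 + 1)'], ['2'])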
+
+def format_code_example(code: str, max_len: int, in_docstring: bool = False):
+ """
+ Format a code example using black. Will take into account the doctest syntax as well as any initial indentation in
+ the code provided.
+
+ Args:
+ code (`str`): The code example to format.
+ max_len (`int`): The maximum length per line.
+ in_docstring (`bool`, *optional*, defaults to `False`): Whether or not the code example is inside a docstring.
+
+ Returns:
+ `str`: The formatted code.
+ """
+ code_lines = code.split("\n")
+
+ # Find initial indent
+ idx = 0
+ while idx < len(code_lines) and is_empty_line(code_lines[idx]):
+ idx += 1
+ if idx >= len(code_lines):
+ return "", ""
+ indent = find_indent(code_lines[idx])
+
+ # Remove the initial indent for now, we will add it back after styling.
+ # Note that l[indent:] works for empty lines
+ code_lines = [l[indent:] for l in code_lines[idx:]]
+ has_doctest = code_lines[0][:3] in DOCTEST_PROMPTS
+
+ code_samples, outputs = parse_code_example(code_lines)
+
+ # Let's blackify the code! We put everything in one big text to go faster.
+ delimiter = "\n\n### New code sample ###\n"
+ full_code = delimiter.join(code_samples)
+ line_length = max_len - indent
+ if has_doctest:
+ line_length -= 4
+
+ for k, v in BLACK_AVOID_PATTERNS.items():
+ full_code = full_code.replace(k, v)
+ try:
+ mode = black.Mode(target_versions={black.TargetVersion.PY37}, line_length=line_length)
+ formatted_code = black.format_str(full_code, mode=mode)
+ error = ""
+ except Exception as e:
+ formatted_code = full_code
+ error = f"Code sample:\n{full_code}\n\nError message:\n{e}"
+
+ # Let's get back the formatted code samples
+ for k, v in BLACK_AVOID_PATTERNS.items():
+ formatted_code = formatted_code.replace(v, k)
+ # Triple quotes will mess up docstrings.
+ if in_docstring:
+ formatted_code = formatted_code.replace('"""', "'''")
+
+ code_samples = formatted_code.split(delimiter)
+ # We can have one output less than code samples
+ if len(outputs) == len(code_samples) - 1:
+ outputs.append("")
+
+ formatted_lines = []
+ for code_sample, output in zip(code_samples, outputs):
+ # black may have added some new lines, we remove them
+ code_sample = code_sample.strip()
+ in_triple_quotes = False
+ in_decorator = False
+ for line in code_sample.strip().split("\n"):
+ if has_doctest and not is_empty_line(line):
+ prefix = (
+ "... "
+ if line.startswith(" ") or line in [")", "]", "}"] or in_triple_quotes or in_decorator
+ else ">>> "
+ )
+ else:
+ prefix = ""
+ indent_str = "" if is_empty_line(line) else (" " * indent)
+ formatted_lines.append(indent_str + prefix + line)
+
+ if '"""' in line:
+ in_triple_quotes = not in_triple_quotes
+ if line.startswith(" "):
+ in_decorator = False
+ if line.startswith("@"):
+ in_decorator = True
+
+ formatted_lines.extend([" " * indent + line for line in output.split("\n")])
+ if not output.endswith("===PT-TF-SPLIT==="):
+ formatted_lines.append("")
+
+ result = "\n".join(formatted_lines)
+ return result.rstrip(), error
+
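+# Example (sketch, hand-checked): black normalizes the sample and the doctest
+# prompt is re-attached afterwards.
+#
+#     >>> format_code_example(">>> x=1", max_len=119)
+#     ('>>> x = 1', '')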
+
+def format_text(text, max_len, prefix="", min_indent=None):
+ """
+ Format a text in the biggest lines possible with the constraint of a maximum length and an indentation.
+
+ Args:
+ text (`str`): The text to format
+ max_len (`int`): The maximum length per line to use
+ prefix (`str`, *optional*, defaults to `""`): A prefix that will be added to the text.
+ The prefix doesn't count toward the indent (like a - introducing a list).
+ min_indent (`int`, *optional*): The minimum indent of the text.
+ If not set, will default to the length of the `prefix`.
+
+ Returns:
+ `str`: The formatted text.
+ """
+ text = re.sub(r"\s+", " ", text)
+ if min_indent is not None:
+ if len(prefix) < min_indent:
+ prefix = " " * (min_indent - len(prefix)) + prefix
+
+ indent = " " * len(prefix)
+ new_lines = []
+ words = text.split(" ")
+ current_line = f"{prefix}{words[0]}"
+ for word in words[1:]:
+ try_line = f"{current_line} {word}"
+ if len(try_line) > max_len:
+ new_lines.append(current_line)
+ current_line = f"{indent}{word}"
+ else:
+ current_line = try_line
+ new_lines.append(current_line)
+ return "\n".join(new_lines)
+
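+# Example (sketch, hand-checked): words are packed greedily under the budget.
+#
+#     >>> format_text("one two three four", max_len=10)
+#     'one two\nthree four'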
+
+def split_line_on_first_colon(line):
+ splits = line.split(":")
+ return splits[0], ":".join(splits[1:])
+
+
+def style_docstring(docstring, max_len):
+ """
+ Style a docstring by making sure there is no useless whitespace and the maximum horizontal space is used.
+
+ Args:
+ docstring (`str`): The docstring to style.
+ max_len (`int`): The maximum length of each line.
+
+ Returns:
+ `str`: The styled docstring
+ """
+ lines = docstring.split("\n")
+ new_lines = []
+
+ # Initialization
+ current_paragraph = None
+ current_indent = -1
+ in_code = False
+ param_indent = -1
+ prefix = ""
+ black_errors = []
+
+ # Special case for docstrings that begin with continuation of Args with no Args block.
+ idx = 0
+ while idx < len(lines) and is_empty_line(lines[idx]):
+ idx += 1
+ if (
+ len(lines[idx]) > 1
+ and lines[idx].rstrip().endswith(":")
+ and find_indent(lines[idx + 1]) > find_indent(lines[idx])
+ ):
+ param_indent = find_indent(lines[idx])
+
+ for idx, line in enumerate(lines):
+ # Doing all re searches once for the one we need to repeat.
+ list_search = _re_list.search(line)
+ code_search = _re_code.search(line)
+
+ # Are we starting a new paragraph?
+ # New indentation or new line:
+ new_paragraph = find_indent(line) != current_indent or is_empty_line(line)
+ # List item
+ new_paragraph = new_paragraph or list_search is not None
+ # Code block beginning
+ new_paragraph = new_paragraph or code_search is not None
+ # Beginning/end of tip
+ new_paragraph = new_paragraph or _re_tip.search(line)
+
+ # In this case, we treat the current paragraph
+ if not in_code and new_paragraph and current_paragraph is not None and len(current_paragraph) > 0:
+ paragraph = " ".join(current_paragraph)
+ new_lines.append(format_text(paragraph, max_len, prefix=prefix, min_indent=current_indent))
+ current_paragraph = None
+
+ if code_search is not None:
+ if not in_code:
+ current_paragraph = []
+ current_indent = len(code_search.groups()[0])
+ current_code = code_search.groups()[1]
+ prefix = ""
+ if current_indent < param_indent:
+ param_indent = -1
+ else:
+ current_indent = -1
+ code = "\n".join(current_paragraph)
+ if current_code in ["py", "python"]:
+ formatted_code, error = format_code_example(code, max_len, in_docstring=True)
+ new_lines.append(formatted_code)
+ if len(error) > 0:
+ black_errors.append(error)
+ else:
+ new_lines.append(code)
+ current_paragraph = None
+ new_lines.append(line)
+ in_code = not in_code
+
+ elif in_code:
+ current_paragraph.append(line)
+ elif is_empty_line(line):
+ current_paragraph = None
+ current_indent = -1
+ prefix = ""
+ new_lines.append(line)
+ elif list_search is not None:
+ prefix = list_search.groups()[0]
+ current_indent = len(prefix)
+ current_paragraph = [line[current_indent:]]
+ elif _re_args.search(line):
+ new_lines.append(line)
+ param_indent = find_indent(lines[idx + 1])
+ elif _re_tip.search(line):
+ # Add a new line before if not present
+ if len(new_lines) > 0 and not is_empty_line(new_lines[-1]):
+ new_lines.append("")
+ new_lines.append(line)
+ # Add a new line after if not present
+ if idx < len(lines) - 1 and not is_empty_line(lines[idx + 1]):
+ new_lines.append("")
+ elif current_paragraph is None or find_indent(line) != current_indent:
+ indent = find_indent(line)
+ # Special behavior for parameters intros.
+ if indent == param_indent:
+ # Special rules for some docstrings where the Returns block has the same indent as the parameters.
+ if _re_returns.search(line) is not None:
+ param_indent = -1
+ new_lines.append(line)
+ elif len(line) < max_len:
+ new_lines.append(line)
+ else:
+ intro, description = split_line_on_first_colon(line)
+ new_lines.append(intro + ":")
+ if len(description) != 0:
+ if find_indent(lines[idx + 1]) > indent:
+ current_indent = find_indent(lines[idx + 1])
+ else:
+ current_indent = indent + 4
+ current_paragraph = [description.strip()]
+ prefix = ""
+ else:
+ # Check if we have exited the parameter block
+ if indent < param_indent:
+ param_indent = -1
+
+ current_paragraph = [line.strip()]
+ current_indent = find_indent(line)
+ prefix = ""
+ elif current_paragraph is not None:
+ current_paragraph.append(line.lstrip())
+
+ if current_paragraph is not None and len(current_paragraph) > 0:
+ paragraph = " ".join(current_paragraph)
+ new_lines.append(format_text(paragraph, max_len, prefix=prefix, min_indent=current_indent))
+
+ return "\n".join(new_lines), "\n\n".join(black_errors)
+
+
+def style_docstrings_in_code(code, max_len=119):
+ """
+ Style all docstrings in some code.
+
+ Args:
+ code (`str`): The code in which we want to style the docstrings.
+ max_len (`int`): The maximum number of characters per line.
+
+ Returns:
+ `Tuple[str, str]`: A tuple with the clean code and the black errors (if any)
+ """
+ # fmt: off
+ splits = code.split('\"\"\"')
+ splits = [
+ (s if i % 2 == 0 or _re_doc_ignore.search(splits[i - 1]) is not None else style_docstring(s, max_len=max_len))
+ for i, s in enumerate(splits)
+ ]
+ black_errors = "\n\n".join([s[1] for s in splits if isinstance(s, tuple) and len(s[1]) > 0])
+ splits = [s[0] if isinstance(s, tuple) else s for s in splits]
+ clean_code = '\"\"\"'.join(splits)
+ # fmt: on
+
+ return clean_code, black_errors
+
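+# Example (sketch, hand-checked): splitting on triple quotes makes docstrings
+# the odd-numbered segments; only those are restyled.
+#
+#     >>> style_docstrings_in_code('"""Summary   line."""\n')[0]
+#     '"""Summary line."""\n'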
+
+def style_file_docstrings(code_file, max_len=119, check_only=False):
+ """
+ Style all docstrings in a given file.
+
+ Args:
+ code_file (`str` or `os.PathLike`): The file in which we want to style the docstring.
+ max_len (`int`): The maximum number of characters per line.
+ check_only (`bool`, *optional*, defaults to `False`):
+ Whether to restyle the file or just check whether it should be restyled.
+
+ Returns:
+ `Tuple[bool, str]`: Whether or not the file was or should be restyled, and the black errors (if any).
+ """
+ with open(code_file, "r", encoding="utf-8", newline="\n") as f:
+ code = f.read()
+
+ clean_code, black_errors = style_docstrings_in_code(code, max_len=max_len)
+
+ diff = clean_code != code
+ if not check_only and diff:
+ print(f"Overwriting content of {code_file}.")
+ with open(code_file, "w", encoding="utf-8", newline="\n") as f:
+ f.write(clean_code)
+
+ return diff, black_errors
+
+
+def style_mdx_file(mdx_file, max_len=119, check_only=False):
+ """
+ Style a MDX file by formatting all Python code samples.
+
+ Args:
+ mdx_file (`str` or `os.PathLike`): The file in which we want to style the examples.
+ max_len (`int`): The maximum number of characters per line.
+ check_only (`bool`, *optional*, defaults to `False`):
+ Whether to restyle the file or just check whether it should be restyled.
+
+ Returns:
+ `Tuple[bool, str]`: Whether or not the file was or should be restyled, and the black errors (if any).
+ """
+ with open(mdx_file, "r", encoding="utf-8", newline="\n") as f:
+ content = f.read()
+
+ lines = content.split("\n")
+ current_code = []
+ current_language = ""
+ in_code = False
+ new_lines = []
+ black_errors = []
+
+ for line in lines:
+ if _re_code.search(line) is not None:
+ in_code = not in_code
+ if in_code:
+ current_language = _re_code.search(line).groups()[1]
+ current_code = []
+ else:
+ code = "\n".join(current_code)
+ if current_language in ["py", "python"]:
+ code, error = format_code_example(code, max_len)
+ if len(error) > 0:
+ black_errors.append(error)
+ new_lines.append(code)
+
+ new_lines.append(line)
+ elif in_code:
+ current_code.append(line)
+ else:
+ new_lines.append(line)
+
+ if in_code:
+ raise ValueError(f"There was a problem when styling {mdx_file}. A code block is opened without being closed.")
+
+ clean_content = "\n".join(new_lines)
+ diff = clean_content != content
+ if not check_only and diff:
+ print(f"Overwriting content of {mdx_file}.")
+ with open(mdx_file, "w", encoding="utf-8", newline="\n") as f:
+ f.write(clean_content)
+
+ return diff, "\n\n".join(black_errors)
+
+
+def style_doc_files(*files, max_len=119, check_only=False):
+ """
+ Applies doc styling or checks everything is correct in a list of files.
+
+ Args:
+ files (several `str` or `os.PathLike`): The files to treat.
+ max_len (`int`): The maximum number of characters per line.
+ check_only (`bool`, *optional*, defaults to `False`):
+ Whether to restyle the files or just check whether they should be restyled.
+
+ Returns:
+ List[`str`]: The list of files changed or that should be restyled.
+ """
+ changed = []
+ black_errors = []
+ for file in files:
+ # Treat folders
+ if os.path.isdir(file):
+ files = [os.path.join(file, f) for f in os.listdir(file)]
+ files = [f for f in files if os.path.isdir(f) or f.endswith(".mdx") or f.endswith(".py")]
+ changed += style_doc_files(*files, max_len=max_len, check_only=check_only)
+ # Treat mdx
+ elif file.endswith(".mdx"):
+ try:
+ diff, black_error = style_mdx_file(file, max_len=max_len, check_only=check_only)
+ if diff:
+ changed.append(file)
+ if len(black_error) > 0:
+ black_errors.append(
+ f"There was a problem while formatting an example in {file} with black:\m{black_error}"
+ )
+ except Exception:
+ print(f"There is a problem in {file}.")
+ raise
+ # Treat python files
+ elif file.endswith(".py"):
+ try:
+ diff, black_error = style_file_docstrings(file, max_len=max_len, check_only=check_only)
+ if diff:
+ changed.append(file)
+ if len(black_error) > 0:
+ black_errors.append(
+ f"There was a problem while formatting an example in {file} with black:\m{black_error}"
+ )
+ except Exception:
+ print(f"There is a problem in {file}.")
+ raise
+ else:
+ warnings.warn(f"Ignoring {file} because it's not a py or an mdx file or a folder.")
+ if len(black_errors) > 0:
+ black_message = "\n\n".join(black_errors)
+ raise ValueError(
+ "Some code examples can't be interpreted by black, which means they aren't regular python:\n\n"
+ + black_message
+ + "\n\nMake sure to fix the corresponding docstring or doc file, or remove the py/python after ``` if it "
+ + "was not supposed to be a Python code sample."
+ )
+ return changed
+
+
+def main(*files, max_len=119, check_only=False):
+ changed = style_doc_files(*files, max_len=max_len, check_only=check_only)
+ if check_only and len(changed) > 0:
+ raise ValueError(f"{len(changed)} files should be restyled!")
+ elif len(changed) > 0:
+ print(f"Cleaned {len(changed)} files!")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("files", nargs="+", help="The file(s) or folder(s) to restyle.")
+ parser.add_argument("--max_len", type=int, help="The maximum length of lines.")
+ parser.add_argument("--check_only", action="store_true", help="Whether to only check and not fix styling issues.")
+ args = parser.parse_args()
+
+ main(*args.files, max_len=args.max_len, check_only=args.check_only)