Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
pacman100 committed Nov 25, 2022
1 parent 4eaf613 commit 61157ea
Show file tree
Hide file tree
Showing 16 changed files with 942 additions and 62 deletions.
141 changes: 141 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# VSCode
.vscode

# IntelliJ
.idea

# Mac .DS_Store
.DS_Store

# More test things
wandb
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include LICENSE
19 changes: 19 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
.PHONY: quality style test docs

check_dirs := src

# Check that source code meets quality standards

# this target runs checks on all files
quality:
black --check $(check_dirs)
isort --check-only $(check_dirs)
flake8 $(check_dirs)
python utils/style_doc.py src --max_len 119 --check_only

# Format source code automatically and check is there are any problems left that need manual fixing
style:
black $(check_dirs)
isort $(check_dirs)
python utils/style_doc.py src --max_len 119

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[tool.black]
line-length = 119
target-version = ['py36']
23 changes: 23 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
[isort]
default_section = FIRSTPARTY
ensure_newline_before_comments = True
force_grid_wrap = 0
include_trailing_comma = True
known_first_party = pet
known_third_party =
numpy
torch
accelerate
transformers

line_length = 119
lines_after_imports = 2
multi_line_output = 3
use_parentheses = True

[flake8]
ignore = E203, E722, E501, E741, W503, W605
max-line-length = 119

[tool:pytest]
doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS
78 changes: 78 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from setuptools import setup
from setuptools import find_packages

extras = {}
extras["quality"] = ["black ~= 22.0", "isort >= 5.5.4", "flake8 >= 3.8.3"]
extras["dev"] = extras["quality"]

setup(
name="pets",
version="0.1.0.dev0",
description="Parameter-Efficient Tuning at Scale (PETS)",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
keywords="deep learning",
license="Apache",
author="The HuggingFace team",
author_email="[email protected]",
url="https://github.com/huggingface/pets",
package_dir={"": "src"},
packages=find_packages("src"),
entry_points={},
python_requires=">=3.7.0",
install_requires=[
"numpy>=1.17",
"packaging>=20.0",
"psutil",
"pyyaml",
"torch>=1.4.0",
"transformers",
"accelerate",
],
extras_require=extras,
classifiers=[
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
)

# Release checklist
# 1. Change the version in __init__.py and setup.py.
# 2. Commit these changes with the message: "Release: VERSION"
# 3. Add a tag in git to mark the release: "git tag VERSION -m 'Adds tag VERSION for pypi' "
# Push the tag to git: git push --tags origin main
# 4. Run the following commands in the top-level directory:
# python setup.py bdist_wheel
# python setup.py sdist
# 5. Upload the package to the pypi test server first:
# twine upload dist/* -r pypitest
# twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
# 6. Check that you can install it in a virtualenv by running:
# pip install -i https://testpypi.python.org/pypi accelerate
# accelerate env
# accelerate test
# 7. Upload the final version to actual pypi:
# twine upload dist/* -r pypi
# 8. Add release notes to the tag in github once everything is looking hunky-dory.
# 9. Update the version in __init__.py, setup.py to the new version "-dev" and push to master
18 changes: 18 additions & 0 deletions src/pet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

__version__ = "0.1.0.dev0"

from .pet_model import (
ParameterEfficientTuningModel,
ParameterEfficientTuningModelForSequenceClassification,
PromptEncoderType,
)
from .tuners import (
PrefixEncoder,
PromptEmbedding,
PromptEncoder,
PromptEncoderReparameterizationType,
PromptTuningInit,
)
14 changes: 8 additions & 6 deletions src/pet.py → src/pet/pet_model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from collections import OrderedDict
import enum
import warnings
from collections import OrderedDict

import torch
from accelerate.state import AcceleratorState
from transformers import PreTrainedModel

from tuners.p_tuning import PromptEncoder
from tuners.prefix_tuning import PrefixEncoder
from tuners.prompt_tuning import PromptEmbedding
from accelerate.state import AcceleratorState


class PromptEncoderType(str, enum.Enum):
Expand Down Expand Up @@ -88,8 +90,8 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False):

def load_state_dict(self, state_dict, strict: bool = True):
"""
Custom load state dict method that only loads prompt table and prompt encoder
parameters. Matching load method for this class' custom state dict method.
Custom load state dict method that only loads prompt table and prompt encoder parameters. Matching load method
for this class' custom state dict method.
"""
self.prompt_encoder.embedding.load_state_dict({"weight": state_dict["prompt_embeddings"]}, strict)

Expand Down Expand Up @@ -187,8 +189,8 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False):

def load_state_dict(self, state_dict, strict: bool = True):
"""
Custom load state dict method that only loads prompt table and prompt encoder
parameters. Matching load method for this class' custom state dict method.
Custom load state dict method that only loads prompt table and prompt encoder parameters. Matching load method
for this class' custom state dict method.
"""
super().load_state_dict(state_dict["prompt_encoder"], strict)
self.model.classifier.load_state_dict(state_dict["classifier"], strict)
Expand Down
Loading

0 comments on commit 61157ea

Please sign in to comment.