Add Basic Pylint and Mypy Tooling (mlc-ai#1100)
Add pylint/mypy tooling into pyproject.toml

This PR establishes the initial Python tooling infrastructure with Pylint and
Mypy. Currently only the newest modules, i.e. `mlc_chat.support` and
`mlc_chat.compiler`, are covered; extending coverage to the entire package is
tracked in mlc-ai#1101.
junrushao authored Oct 21, 2023
1 parent 03c641a commit 46d11e6
Showing 10 changed files with 100 additions and 21 deletions.
34 changes: 31 additions & 3 deletions .github/workflows/python_lint.yml
@@ -1,9 +1,7 @@
 name: Python Lint

 on: [push, pull_request]

 env:
-  IMAGE: 'mlcaidev/ci-cpu:8a87699'
+  IMAGE: 'mlcaidev/ci-cpu:2c03e7f'

 jobs:
   isort:
@@ -35,3 +33,33 @@ jobs:
       - name: Lint
         run: |
           ./ci/bash.sh $IMAGE bash ./ci/task/black.sh
+  mypy:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          submodules: 'recursive'
+      - name: Version
+        run: |
+          wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh
+          chmod u+x ./ci/bash.sh
+          ./ci/bash.sh $IMAGE "conda env export --name ci-lint"
+      - name: Lint
+        run: |
+          ./ci/bash.sh $IMAGE bash ./ci/task/mypy.sh
+  pylint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          submodules: 'recursive'
+      - name: Version
+        run: |
+          wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh
+          chmod u+x ./ci/bash.sh
+          ./ci/bash.sh $IMAGE "conda env export --name ci-lint"
+      - name: Lint
+        run: |
+          ./ci/bash.sh $IMAGE bash ./ci/task/pylint.sh
3 changes: 2 additions & 1 deletion ci/task/black.sh
@@ -3,7 +3,8 @@ set -eo pipefail

 source ~/.bashrc
 micromamba activate ci-lint
-NUM_THREADS=$(nproc)
+export NUM_THREADS=$(nproc)
+export PYTHONPATH="./python:$PYTHONPATH"

 black --check --workers $NUM_THREADS ./python/
 black --check --workers $NUM_THREADS ./tests/python
3 changes: 2 additions & 1 deletion ci/task/isort.sh
@@ -3,7 +3,8 @@ set -eo pipefail

 source ~/.bashrc
 micromamba activate ci-lint
-NUM_THREADS=$(nproc)
+export NUM_THREADS=$(nproc)
+export PYTHONPATH="./python:$PYTHONPATH"

 isort --check-only -j $NUM_THREADS --profile black ./python/
 isort --check-only -j $NUM_THREADS --profile black ./tests/python/
10 changes: 10 additions & 0 deletions ci/task/mypy.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+set -eo pipefail
+
+source ~/.bashrc
+micromamba activate ci-lint
+export NUM_THREADS=$(nproc)
+export PYTHONPATH="./python:$PYTHONPATH"
+
+mypy ./python/mlc_chat/compiler ./python/mlc_chat/support
+mypy ./tests/python/model ./tests/python/parameter
13 changes: 13 additions & 0 deletions ci/task/pylint.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+set -eo pipefail
+
+source ~/.bashrc
+micromamba activate ci-lint
+export NUM_THREADS=$(nproc)
+export PYTHONPATH="./python:$PYTHONPATH"
+
+# TVM Unity is a dependency of this testing
+pip install --quiet --pre -U -f https://mlc.ai/wheels mlc-ai-nightly
+
+pylint --jobs $NUM_THREADS ./python/mlc_chat/compiler ./python/mlc_chat/support
+pylint --jobs $NUM_THREADS --recursive=y ./tests/python/model ./tests/python/parameter
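A note on the `pip install` step above: pylint resolves the imports of the code it analyzes, so the TVM Unity wheel must be available in the lint environment. A minimal illustration (hypothetical module, not repo code):

```python
# Hypothetical module showing why the wheel install matters: pylint resolves
# imports in the files it checks, so a missing dependency is reported as
# E0401 (import-error) rather than as a style finding.
import tvm  # without the mlc-ai-nightly wheel: E0401: Unable to import 'tvm'

print(tvm.__version__)
```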
17 changes: 16 additions & 1 deletion pyproject.toml
@@ -19,4 +19,19 @@ profile = "black"

 [tool.black]
 line-length = 100
 target-version = ['py310']
+
+[tool.mypy]
+ignore_missing_imports = true
+show_column_numbers = true
+show_error_context = true
+follow_imports = "skip"
+ignore_errors = false
+strict_optional = false
+install_types = true
+non_interactive = true
+
+[tool.pylint.messages_control]
+max-line-length = 100
+disable = """
+duplicate-code,
+"""
16 changes: 9 additions & 7 deletions python/mlc_chat/compiler/model/llama_parameter.py
@@ -2,6 +2,8 @@
 This file specifies how MLC's Llama parameter maps from other formats, for example HuggingFace
 PyTorch, HuggingFace safetensors.
 """
+from typing import Callable, Dict, List
+
 import numpy as np

 from ..parameter import ExternMapping
@@ -26,33 +28,33 @@ def hf_torch(model_config: LlamaConfig) -> ExternMapping:
     _, named_params = model.export_tvm(spec=model.get_default_spec())
     parameter_names = {name for name, _ in named_params}

-    param_map = {}
-    map_func = {}
+    param_map: Dict[str, List[str]] = {}
+    map_func: Dict[str, Callable] = {}
     unused_params = set()

     for i in range(model_config.num_hidden_layers):
         # Add QKV in self attention
         attn = f"model.layers.{i}.self_attn"
         assert f"{attn}.qkv_proj.weight" in parameter_names
         map_func[f"{attn}.qkv_proj.weight"] = lambda q, k, v: np.concatenate([q, k, v], axis=0)
-        param_map[f"{attn}.qkv_proj.weight"] = (
+        param_map[f"{attn}.qkv_proj.weight"] = [
             f"{attn}.q_proj.weight",
             f"{attn}.k_proj.weight",
             f"{attn}.v_proj.weight",
-        )
+        ]
         # Add gates in MLP
         mlp = f"model.layers.{i}.mlp"
         assert f"{mlp}.gate_up_proj.weight" in parameter_names
         map_func[f"{mlp}.gate_up_proj.weight"] = lambda gate, up: np.concatenate([gate, up], axis=0)
-        param_map[f"{mlp}.gate_up_proj.weight"] = (
+        param_map[f"{mlp}.gate_up_proj.weight"] = [
             f"{mlp}.gate_proj.weight",
             f"{mlp}.up_proj.weight",
-        )
+        ]
         # inv_freq is not used in the model
         unused_params.add(f"{attn}.rotary_emb.inv_freq")

     for name in parameter_names:
         if name not in map_func:
             map_func[name] = lambda x: x
-            param_map[name] = (name,)
+            param_map[name] = [name]
     return ExternMapping(param_map, map_func, unused_params)
18 changes: 13 additions & 5 deletions python/mlc_chat/compiler/parameter/mapping.py
@@ -1,10 +1,18 @@
 """Parameter mapping for converting different LLM implementations to MLC LLM."""
 import dataclasses
-from typing import Callable, Dict, List, Set
+from typing import Callable, Dict, List, Set, Union

 import numpy as np
 from tvm.runtime import NDArray

+MapFuncVariadic = Union[
+    Callable[[], np.ndarray],
+    Callable[[np.ndarray], np.ndarray],
+    Callable[[np.ndarray, np.ndarray], np.ndarray],
+    Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray],
+    Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray],
+]
+

 @dataclasses.dataclass
 class ExternMapping:
@@ -33,8 +41,8 @@ class ExternMapping:
     """

     param_map: Dict[str, List[str]]
-    map_func: Dict[str, Callable[[np.ndarray, ...], np.ndarray]]
-    unused_params: Set[str] = dataclasses.field(default_factory=dict)
+    map_func: Dict[str, MapFuncVariadic]
+    unused_params: Set[str] = dataclasses.field(default_factory=set)


 @dataclasses.dataclass
@@ -72,8 +80,8 @@ class QuantizeMapping:
     used to convert the quantized parameters into the desired form.
     """

-    param_map: Dict[str, Callable[str, List[str]]]
-    map_func: Dict[str, Callable[NDArray, List[NDArray]]]
+    param_map: Dict[str, Callable[[str], List[str]]]
+    map_func: Dict[str, Callable[[NDArray], List[NDArray]]]


 __all__ = ["ExternMapping", "QuantizeMapping"]
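The `MapFuncVariadic` alias exists because the previous annotation, `Callable[[np.ndarray, ...], np.ndarray]`, is not valid typing syntax: `...` cannot be mixed with explicit parameter types inside `Callable`. The commit instead enumerates the arities that actually occur (zero through four arrays). A sketch of the trade-off this makes (illustrative only):

```python
from typing import Callable, Union

import numpy as np

# Legal but loose: accepts a function of any signature returning an ndarray,
# so mypy cannot catch a wrong argument count at the call site.
AnyArity = Callable[..., np.ndarray]

# The commit's approach: spell out each supported arity explicitly.
OneOrTwo = Union[
    Callable[[np.ndarray], np.ndarray],
    Callable[[np.ndarray, np.ndarray], np.ndarray],
]


def identity(x: np.ndarray) -> np.ndarray:
    return x


f: OneOrTwo = identity  # OK: matches the one-argument member of the Union
```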
4 changes: 2 additions & 2 deletions python/mlc_chat/support/config.py
@@ -37,10 +37,10 @@ def from_dict(cls: Type[ConfigClass], source: Dict[str, Any]) -> ConfigClass:
         cfg : ConfigClass
             An instance of the config object.
         """
-        field_names = [field.name for field in dataclasses.fields(cls)]
+        field_names = [field.name for field in dataclasses.fields(cls)]  # type: ignore[arg-type]
         fields = {k: v for k, v in source.items() if k in field_names}
         kwargs = {k: v for k, v in source.items() if k not in field_names}
-        return cls(**fields, kwargs=kwargs)
+        return cls(**fields, kwargs=kwargs)  # type: ignore[call-arg]

     @classmethod
     def from_file(cls: Type[ConfigClass], source: Path) -> ConfigClass:
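The two targeted ignores are needed because `cls` is a `Type[ConfigClass]` where `ConfigClass` is a plain `TypeVar`: mypy cannot prove that an arbitrary `ConfigClass` is a dataclass with a `kwargs` field, even though every concrete config class in the package is. A minimal reproduction of the first error (assuming the same generic signature):

```python
import dataclasses
from typing import List, Type, TypeVar

ConfigClass = TypeVar("ConfigClass")


def field_names(cls: Type[ConfigClass]) -> List[str]:
    # mypy cannot see that ConfigClass is a dataclass, so without the
    # targeted ignore it rejects this call, exactly as in from_dict().
    return [f.name for f in dataclasses.fields(cls)]  # type: ignore[arg-type]
```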
3 changes: 2 additions & 1 deletion tests/python/parameter/test_hf_torch_loader.py
@@ -1,6 +1,7 @@
 # pylint: disable=missing-docstring
 import logging
 from pathlib import Path
+from typing import Union

 import pytest
 from mlc_chat.compiler.model.llama import LlamaConfig
@@ -24,7 +25,7 @@
         "./dist/models/Llama-2-70b-hf",
     ],
 )
-def test_load_llama(base_path: str):
+def test_load_llama(base_path: Union[str, Path]):
     base_path = Path(base_path)
     path_config = base_path / "config.json"
     path_params = base_path / "pytorch_model.bin.index.json"
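The widened annotation fixes a common mypy complaint: `base_path` was declared `str` but immediately rebound to a `Path`, and a parameter's reassignment must stay within its declared type. A minimal reproduction (illustrative, not repo code):

```python
from pathlib import Path
from typing import Union


def load_config(base_path: Union[str, Path]) -> Path:
    # With `base_path: str` this rebinding would be flagged:
    # Incompatible types in assignment (expression has type "Path",
    # variable has type "str")
    base_path = Path(base_path)
    return base_path / "config.json"
```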
