Skip to content

Commit

Permalink
Sparsity docs update (#1590)
Browse files Browse the repository at this point in the history
  • Loading branch information
jainapurva authored Jan 21, 2025
1 parent ea7910e commit 5d1444b
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
6 changes: 3 additions & 3 deletions docs/source/api_ref_sparsity.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ torchao.sparsity

WandaSparsifier
PerChannelNormObserver
apply_sparse_semi_structured
apply_fake_sparsity


sparsify_
semi_sparse_weight
int8_dynamic_activation_int8_semi_sparse_weight
32 changes: 16 additions & 16 deletions torchao/sparsity/sparse_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def sparsify_(
apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor],
filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
) -> torch.nn.Module:
"""Convert the weight of linear modules in the model with `apply_tensor_subclass`
"""Convert the weight of linear modules in the model with `apply_tensor_subclass`.
This function is essentially the same as quantize, but for sparsity subclasses.
Currently, we support three options for sparsity:
Expand All @@ -54,26 +54,26 @@ def sparsify_(
Args:
model (torch.nn.Module): input model
apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (sparsified) tensor subclass instance (e.g. affine quantized tensor instance)
filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on
the weight of the module
filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on the weight of the module
Example::
import torch
import torch.nn as nn
from torchao.sparsity import sparsify_
**Example:**
::
import torch
import torch.nn as nn
from torchao.sparsity import sparsify_
def filter_fn(module: nn.Module, fqn: str) -> bool:
return isinstance(module, nn.Linear)
def filter_fn(module: nn.Module, fqn: str) -> bool:
return isinstance(module, nn.Linear)
m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
# for 2:4 sparsity
from torchao.sparse_api import semi_sparse_weight
m = sparsify_(m, semi_sparse_weight(), filter_fn)
# for 2:4 sparsity
from torchao.sparse_api import semi_sparse_weight
m = sparsify_(m, semi_sparse_weight(), filter_fn)
# for int8 dynamic quantization + 2:4 sparsity
from torchao.dtypes import SemiSparseLayout
m = quantize_(m, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout), filter_fn)
# for int8 dynamic quantization + 2:4 sparsity
from torchao.dtypes import SemiSparseLayout
m = quantize_(m, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout), filter_fn)
"""
_replace_with_custom_fn_if_matches_filter(
model,
Expand Down

0 comments on commit 5d1444b

Please sign in to comment.