Commit 8a2ef26

[fix] Pickling issues with xformer models (#290) (#309)
* Tentatively fixing pickling issues, lazy init
1 parent 3ec97d9 · commit 8a2ef26

File tree

3 files changed (+43 / -17 lines)


tests/test_block_factory.py (+6 / -6)

@@ -27,13 +27,13 @@
 VOCAB_SIZE = 64


-@pytest.mark.parametrize("attn_dropout", [0.0, 0.1])
-@pytest.mark.parametrize("residual_dropout", [0.0, 0.1])
+@pytest.mark.parametrize("attn_dropout", [0.1])
+@pytest.mark.parametrize("residual_dropout", [0.1])
 @pytest.mark.parametrize("heads", [1, 2])
 @pytest.mark.parametrize("activation", [a.value for a in Activation])
 @pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())
 @pytest.mark.parametrize("feedforward_name", FEEDFORWARD_REGISTRY.keys())
-@pytest.mark.parametrize("layer_norm_style", ["pre", "post"])
+@pytest.mark.parametrize("layer_norm_style", ["pre", "post", "deepnorm"])
 @pytest.mark.parametrize("device", DEVICES)
 @pytest.mark.parametrize("reversible", [True, False])
 @pytest.mark.skipif(

@@ -127,15 +127,15 @@ def test_xformer_encoder_block(
     _ = block(inputs, input_mask=input_mask)


-@pytest.mark.parametrize("attn_dropout", [0.0, 0.1])
-@pytest.mark.parametrize("residual_dropout", [0.0, 0.1])
+@pytest.mark.parametrize("attn_dropout", [0.1])
+@pytest.mark.parametrize("residual_dropout", [0.1])
 @pytest.mark.parametrize("causal", [True, False])
 @pytest.mark.parametrize("heads", [1, 2])
 @pytest.mark.parametrize("activation", [a.value for a in Activation])
 @pytest.mark.parametrize("rotary_embeddings", [False, True])
 @pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())
 @pytest.mark.parametrize("feedforward_name", FEEDFORWARD_REGISTRY.keys())
-@pytest.mark.parametrize("layer_norm_style", ["pre", "post"])
+@pytest.mark.parametrize("layer_norm_style", ["pre", "post", "deepnorm"])
 @pytest.mark.parametrize("device", DEVICES)
 @pytest.mark.skipif(
     not torch.cuda.is_available(), reason="This test requires a CUDA device"
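
Context for the trimmed dropout values: stacked pytest.mark.parametrize decorators expand to the full cross product of their values, so every extra value multiplies the number of generated cases across all the other decorators. A small illustrative sketch of that mechanism (standalone example, not part of this test file):

import pytest


# Stacked parametrize decorators combine multiplicatively: 2 x 2 = 4 cases here.
@pytest.mark.parametrize("attn_dropout", [0.0, 0.1])
@pytest.mark.parametrize("residual_dropout", [0.0, 0.1])
def test_cross_product(attn_dropout, residual_dropout):
    assert 0.0 <= attn_dropout <= 1.0
    assert 0.0 <= residual_dropout <= 1.0


# With a single value per decorator, as in the change above, only 1 case is
# generated instead of 4; the saving compounds with every stacked decorator.
@pytest.mark.parametrize("attn_dropout", [0.1])
@pytest.mark.parametrize("residual_dropout", [0.1])
def test_single_value(attn_dropout, residual_dropout):
    assert attn_dropout == residual_dropout == 0.1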

tests/test_pickling.py (+19 / -7)

@@ -7,14 +7,17 @@
 # https://github.com/facebookresearch/xformers/issues/203

 import pickle
+from copy import deepcopy

+import pytest
 from torch import nn

+from xformers import _is_triton_available
 from xformers.factory import xFormer, xFormerConfig

 test_config = [
     {
-        "reversible": False,  # Turn on to test the effect of using reversible layers
+        "reversible": False,
         "block_type": "encoder",
         "num_layers": 2,
         "dim_model": 768,

@@ -30,7 +33,7 @@
             },
         },
         "feedforward_config": {
-            "name": "MLP",  # FIXME: Test with FusedMLP also
+            "name": "FusedMLP",
             "dropout": 0.1,
             "activation": "gelu",
             "hidden_layer_multiplier": 4,

@@ -40,11 +43,20 @@


 class ViT(nn.Module):
-    def __init__(self):
+    def __init__(self, mlp):
         super().__init__()
-        self.xformer = xFormer.from_config(xFormerConfig(test_config))
+        test_config[0]["feedforward_config"]["name"] = mlp
+        xformer_config = xFormerConfig(test_config)
+        self.xformer = xFormer.from_config(xformer_config)


-def test_pickling():
-    test = ViT()
-    pickle.dumps(test)
+MLPs = ["MLP"]
+if _is_triton_available:
+    MLPs.append("FusedMLP")
+
+
+@pytest.mark.parametrize("mlp", MLPs)
+def test_pickling(mlp):
+    test = ViT(mlp)
+    _ = pickle.dumps(test)
+    _ = deepcopy(test)
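
The reworked test targets the failure reported in facebookresearch/xformers#290: when a module stores a non-picklable callable at construction time (for example a JIT-compiled Triton kernel handle), pickle.dumps on the whole model raises. A minimal, self-contained sketch of that failure mode, with a local lambda standing in for the kernel handle (illustrative code only, not the xformers implementation):

import pickle

import torch


class EagerKernelModule(torch.nn.Module):
    """Builds a non-picklable helper eagerly, in __init__."""

    def __init__(self):
        super().__init__()
        # A local lambda cannot be pickled, standing in here for a compiled kernel handle.
        self.kernel = lambda x: x * 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.kernel(x)


try:
    pickle.dumps(EagerKernelModule())
except Exception as exc:  # PicklingError or AttributeError, depending on the Python version
    print(f"pickling fails: {exc}")

The parametrized test also runs deepcopy on the model, presumably because copying a model (for example before handing it to worker processes) exercises a closely related serialization path.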

xformers/triton/dropout.py (+18 / -4)

@@ -7,7 +7,7 @@
 # CREDITS: This comes almost as-is from the Triton dropout tutorial
 # https://raw.githubusercontent.com/openai/triton/master/python/tutorials/04-low-memory-dropout.py

-from typing import Optional
+from typing import Any, Optional

 import torch
 import triton

@@ -196,6 +196,11 @@ def dropout(


 class FusedDropoutBias(torch.nn.Module):
+    """
+    A layer which fuses the computation of Dropout(Activation(x))
+    in a single GPU kernel
+    """
+
     def __init__(
         self,
         p: float,

@@ -216,15 +221,24 @@ def __init__(
             if bias_shape is not None
             else None
         )
-        self.activation = get_triton_activation_kernel(activation)
-        self.pytorch_activation = build_activation(self.activation_type)
-        self.activation_grad = get_triton_activation_bwd_kernel(activation)
+
+        self.activation: Optional[Any] = None
+        self.activation_grad: Optional[Any] = None
+        self.activation_pytorch: Optional[Any] = None

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Convenience, catch a possible type or device mismatch
         if self.bias is not None:
             self.bias = self.bias.to(dtype=x.dtype, device=x.device)  # type: ignore

+        # Lazy init (helps with pickling)
+        if self.activation is None:
+            self.activation = get_triton_activation_kernel(self.activation_type)
+            self.pytorch_activation = build_activation(self.activation_type)
+            self.activation_grad = get_triton_activation_bwd_kernel(
+                self.activation_type
+            )
+
         # Train/inference
         p = self.p if self.training else 0.0
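
The dropout change is the core of the fix: the Triton activation kernels are no longer created in __init__, so a freshly constructed module holds only plain, picklable attributes, and the handles are rebuilt on the first forward call. A generic sketch of this lazy-init pattern (a standalone stand-in, not the FusedDropoutBias class; the lambda again plays the role of kernel compilation):

import pickle
from copy import deepcopy
from typing import Any, Optional

import torch


class LazyKernelModule(torch.nn.Module):
    """Defers building non-picklable helpers until the first forward call."""

    def __init__(self, p: float):
        super().__init__()
        self.p = p
        # Placeholder only; nothing non-picklable is created at construction time.
        self.kernel: Optional[Any] = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Lazy init, mirroring the change to FusedDropoutBias above.
        if self.kernel is None:
            self.kernel = lambda t: torch.nn.functional.dropout(t, self.p, self.training)
        return self.kernel(x)


module = LazyKernelModule(p=0.1)
_ = pickle.dumps(module)  # succeeds: the kernel handle has not been built yet
_ = deepcopy(module)      # likewise, which is what the updated test checks

The trade-off, as in the real change, is that the helpers are built on first use rather than up front; in this sketch, pickling after a forward pass would fail again, which is acceptable when models are serialized or copied right after construction.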
