chore: fix hybrid models (#765)
jfrery authored Jun 27, 2024
1 parent be6255c commit 23cb7dc
Showing 2 changed files with 38 additions and 44 deletions.
37 changes: 16 additions & 21 deletions src/concrete/ml/torch/hybrid_model.py
@@ -37,6 +37,7 @@ class HybridFHEMode(enum.Enum):
SIMULATE = "simulate" # Use FHE simulation
CALIBRATE = "calibrate" # Use calibration (to run before FHE compilation)
EXECUTE = "execute" # Use FHE execution
TORCH = "torch" # Use torch layers


def tuple_to_underscore_str(tup: Tuple) -> str:
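
Note: the new TORCH member maps the string "torch" onto the enum alongside the existing modes. A minimal sketch of the string-to-member lookup; the import path follows this file's location and is illustrative only:

from concrete.ml.torch.hybrid_model import HybridFHEMode

# "torch" resolves to the new member the same way the existing values do.
assert HybridFHEMode("torch") is HybridFHEMode.TORCH
assert HybridFHEMode("simulate") is HybridFHEMode.SIMULATE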
@@ -236,15 +237,15 @@ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, QuantTensor]:
Raises:
ValueError: if local_fhe_mode is not supported
"""
# - disable: torch module
# - disable: quantized module
# - remote: client-server
# - simulate: compiled simulation
# - calibrate: calibration

if self.fhe_local_mode not in {
HybridFHEMode.DISABLE,
HybridFHEMode.CALIBRATE,
HybridFHEMode.REMOTE,
HybridFHEMode.TORCH,
None,
}:
# Using quantized module
@@ -253,14 +254,6 @@ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, QuantTensor]:
self.private_q_module.forward(x.detach().numpy(), fhe=self.fhe_local_mode.value)
)

elif self.fhe_local_mode == HybridFHEMode.DISABLE:
# Calling torch
assert self.private_module is not None
y = self.private_module.forward(
x.detach(),
)
assert isinstance(y, (QuantTensor, torch.Tensor))

elif self.fhe_local_mode == HybridFHEMode.CALIBRATE:
# Calling torch + gathering calibration data
assert self.private_module is not None
@@ -271,7 +264,10 @@ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, QuantTensor]:
elif self.fhe_local_mode == HybridFHEMode.REMOTE: # pragma:no cover
# Remote call
y = self.remote_call(x)

elif self.fhe_local_mode == HybridFHEMode.TORCH:
# Using torch layers
assert self.private_module is not None
y = self.private_module(x)
else: # pragma:no cover
# Shouldn't happen
raise ValueError(f"{self.fhe_local_mode} is not recognized")
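
Note: with the branches above, "disable" now routes through the quantized module while the new "torch" mode calls the original submodule directly. A hedged usage sketch, assuming a compiled HybridFHEModel instance hybrid_model and an input tensor x (the fhe strings match the calls in the test file below):

y_torch = hybrid_model(x, fhe="torch")     # original torch layers, run locally
y_quant = hybrid_model(x, fhe="disable")   # quantized module, no encryption
y_sim = hybrid_model(x, fhe="simulate")    # FHE simulation of the compiled circuit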
@@ -369,7 +365,7 @@ def __init__(
name: self._get_module_by_name(self.model, name) for name in self.module_names
}
self.remote_modules: Dict[str, RemoteModule] = {}
self.private_q_modules: dict = {}
self.private_q_modules: Dict[str, QuantizedModule] = {}
self.configuration: Optional[Configuration] = None
self.model_name = model_name
self.verbose = verbose
@@ -413,9 +409,7 @@ def __call__(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:
Returns:
(torch.Tensor): The output tensor.
"""
# Set the fhe mode in each remote module
for module in self.remote_modules.values():
module.fhe_local_mode = HybridFHEMode(fhe)
self.set_fhe_mode(fhe)
x = self.model(x)
return x
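
Note: the per-module loop is folded into set_fhe_mode. A sketch of what such a helper amounts to, based only on the loop it replaces (the actual implementation lives elsewhere in hybrid_model.py and may differ):

def set_fhe_mode(self, hybrid_fhe_mode):
    """Propagate the requested FHE mode to every remote module (illustrative)."""
    for module in self.remote_modules.values():
        module.fhe_local_mode = HybridFHEMode(hybrid_fhe_mode)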

@@ -478,9 +472,7 @@ def compile_model(
encryption parameters. If not specified, a default configuration is used.
"""
# We do a forward pass where we accumulate inputs to use for compilation
for name in self.module_names:
# default is "calibrate"
self.remote_modules[name].fhe_local_mode = HybridFHEMode.CALIBRATE
self.set_fhe_mode(HybridFHEMode.CALIBRATE)
self.model(x)

self.configuration = configuration
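
Note: compile_model now puts every remote module into calibration mode via set_fhe_mode before the calibration forward pass. Typical usage mirrors the call in the test file further down (model, inputs and the module name are placeholders):

hybrid_model = HybridFHEModel(model, ["transformer.h.0.mlp"])
hybrid_model.compile_model(inputs, p_error=0.1, n_bits=9, rounding_threshold_bits=8)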
@@ -523,12 +515,15 @@ def _save_fhe_circuit(self, path: Path, via_mlir=False):

model_path = Path(path)
for module_name in self.module_names:
onnx_model = self.private_q_modules[module_name].onnx_model

# mypy
assert onnx_model is not None
input_shapes = [
tuple(elt.dim_value for elt in onnx_input.type.tensor_type.shape.dim)
for onnx_input in self.private_q_modules[ # pylint: disable=protected-access
self.module_names[0]
]._onnx_model.graph.input
for onnx_input in onnx_model.graph.input
]

assert len(input_shapes) == 1, "Multi-input circuits not supported yet"
model_module_path = model_path.resolve() / module_name
model_module_path.mkdir(exist_ok=True)
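
Note: the input shapes are now read from each module's own onnx_model instead of always indexing the first module. A standalone sketch of the same shape extraction on a loaded ONNX model (the file name is hypothetical):

import onnx

onnx_model = onnx.load("module.onnx")  # hypothetical path
input_shapes = [
    tuple(elt.dim_value for elt in onnx_input.type.tensor_type.shape.dim)
    for onnx_input in onnx_model.graph.input
]
assert len(input_shapes) == 1, "Multi-input circuits not supported yet"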
45 changes: 22 additions & 23 deletions tests/torch/test_hybrid_converter.py
@@ -50,7 +50,7 @@ def run_hybrid_llm_test(
# Create a hybrid model
hybrid_model = HybridFHEModel(model, module_names)
hybrid_model.compile_model(
inputs, p_error=0.01, n_bits=8, rounding_threshold_bits=8, configuration=configuration
inputs, p_error=0.1, n_bits=9, rounding_threshold_bits=8, configuration=configuration
)

if has_pbs:
@@ -71,35 +71,34 @@ def run_hybrid_llm_test(
# Check we can run the simulate locally
logits_simulate = hybrid_model(inputs, fhe="simulate").logits
logits_disable = hybrid_model(inputs, fhe="disable").logits
logits_original = model(inputs).logits
logits_original = hybrid_model(inputs, fhe="torch").logits

# Ensure logits_disable and logits_original return the same output for the logits
assert torch.allclose(logits_disable, logits_original, atol=1e-7), "Outputs do not match!"
assert torch.allclose(logits_disable, logits_simulate, atol=1e-7), "Outputs do not match!"

# Compare the topk accuracy of the FHE simulate circuit vs. the original.
k = 100
k = 5

# Check that the topk next tokens are similar for the different FHE modes
# and the original model.

# Get the topk indices for logits_disable and logits_simulate
topk_disable = logits_disable.topk(k, dim=-1).indices
topk_simulate = logits_simulate.topk(k, dim=-1).indices
topk_original = logits_original.topk(k, dim=-1).indices

# Prepare tensors for broadcasting
expanded_simulate = topk_simulate.unsqueeze(-1)
expanded_disable = topk_disable.unsqueeze(-2)

# Compute if elements of topk_simulate are in topk_disable for each token
(expanded_simulate == expanded_disable).any(-1)

# Make sure accuracy is above a certain threshold
# Even with a small tolerance the test is flaky
# Commenting the assertion for now until issue is resolved
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3905
# Compute accuracy of disable and simulate by checking
# how many labels correspond with the topk_original
accuracy_disable = (topk_disable == topk_original).float().mean().item()
accuracy_simulate = (topk_simulate == topk_original).float().mean().item()

# Compute average of these counts (the accuracy)
# accuracy = is_in.float().mean()
# To use expected accuracy until the check is done
assert expected_accuracy > -1
# assert accuracy >= expected_accuracy, "Expected accuracy GPT2 hybrid not matched."
# Assert that both accuracy values are above the expected threshold
assert (
accuracy_disable >= expected_accuracy
), f"Disable accuracy {accuracy_disable:.4f} is below the expected {expected_accuracy:.4f}"
assert (
accuracy_simulate >= expected_accuracy
), f"Simulate accuracy {accuracy_simulate:.4f} is below the expected {expected_accuracy:.4f}"

with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
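
Note: the flaky top-k overlap check is replaced by a position-wise top-k agreement against the torch reference, asserted for both the disable and simulate outputs; order within the top-k matters. A tiny worked example of the metric with dummy logits and k=2:

import torch

logits_ref = torch.tensor([[0.1, 0.7, 0.2, 0.0]])
logits_sim = torch.tensor([[0.1, 0.6, 0.3, 0.0]])
topk_ref = logits_ref.topk(2, dim=-1).indices  # tensor([[1, 2]])
topk_sim = logits_sim.topk(2, dim=-1).indices  # tensor([[1, 2]])
accuracy = (topk_ref == topk_sim).float().mean().item()  # 1.0: same tokens, same rank order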
@@ -127,9 +126,9 @@ def run_hybrid_llm_test(
@pytest.mark.parametrize(
"list_or_str_private_modules_names, expected_accuracy, has_pbs",
[
("transformer.h.0.mlp", 0.934, True),
(["transformer.h.0.mlp", "transformer.h.1.mlp"], 0.42, True),
("transformer.h.0.mlp.c_fc", 0.986, False),
("transformer.h.0.mlp", 0.95, True),
(["transformer.h.0.mlp", "transformer.h.1.mlp"], 0.40, True),
("transformer.h.0.mlp.c_fc", 1.0, False),
],
)
def test_gpt2_hybrid_mlp(list_or_str_private_modules_names, expected_accuracy, has_pbs):
