
chore: implement hybrid model demo with GPT-2 #246

Merged: 1 commit merged into main on Sep 21, 2023

Conversation

@cla-bot cla-bot bot added the cla-signed label Sep 13, 2023
@fd0r fd0r force-pushed the llm_partial_fhe_pr branch 12 times, most recently from ac760f5 to 4504d5a Compare September 14, 2023 09:56
fd0r (Collaborator, Author) commented Sep 15, 2023:

Yes, I think so.
Currently fixing the batch-size/GEMM optimization issue.

@fd0r fd0r force-pushed the llm_partial_fhe_pr branch 4 times, most recently from c2a1289 to 0b6ef48 Compare September 18, 2023 17:21
@fd0r fd0r marked this pull request as ready for review September 18, 2023 18:32
@fd0r fd0r requested a review from a team as a code owner September 18, 2023 18:32
jfrery (Collaborator) left a comment:

Nicely done! I have a few comments, as well as some questions for my understanding.

Resolved review threads on:
use_case_examples/llm/QGPT2Evaluate.ipynb
src/concrete/ml/torch/compile.py
src/concrete/ml/torch/hybrid_model.py
tests/torch/test_compile_torch.py
use_case_examples/hybrid_model/README.md
use_case_examples/hybrid_model/compile_hybrid_llm.py
use_case_examples/hybrid_model/load_and_analyze_data.py
@@ -695,5 +695,5 @@
   }
  },
  "nbformat": 4,
- "nbformat_minor": 2
+ "nbformat_minor": 4
Collaborator:

Is this expected?

fd0r (Collaborator, Author):

No idea actually, nbqa probably? 🤷🏼

Collaborator:

Not sure. It's not important, I guess.

fd0r (Collaborator, Author):

I didn't modify the file myself.

fd0r (Collaborator, Author):

Can't do much about this one.

Collaborator:

Probably a version difference or something.

jfrery (Collaborator) commented Sep 19, 2023:

Could you add HybridFHEModel to __init__.py so that we can do from concrete.ml.torch import HybridFHEModel?

Collaborator:

Good idea.

fd0r (Collaborator, Author):

For some reason this results in a circular import.

Collaborator:

Weird.

fd0r (Collaborator, Author):

🤷🏼

Collaborator:

I guess it's because hybrid_model.py imports something from concrete.ml.torch? Not sure there is an easy solution then. Shall we leave it like this?
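
(For context, a circular import like the one described above can arise from a layout like this simplified, hypothetical sketch; the real module contents differ:)

# concrete/ml/torch/__init__.py  (hypothetical, simplified)
from .hybrid_model import HybridFHEModel

# concrete/ml/torch/hybrid_model.py  (hypothetical, simplified)
from concrete.ml.torch import compile_torch_model  # pulls __init__.py back in

# Importing concrete.ml.torch runs __init__.py, which imports hybrid_model.py,
# which imports concrete.ml.torch again while __init__.py is still only
# partially initialized, so the name it asks for is not bound yet and the
# import fails with an ImportError (circular import).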

andrei-stoian-zama (Collaborator) left a comment:

Looks good! Some minor changes, especially to the docs: for example, better docstrings for the arguments of the hybrid model class.

@@ -47,7 +47,58 @@ def get_equivalent_numpy_forward_and_onnx_model(
opset_version=OPSET_VERSION_FOR_ONNX_EXPORT,
)
equivalent_onnx_model = onnx.load_model(str(output_onnx_file_path))

# List of all currently supported onnx passes
Collaborator:

Do we keep this comment here, or can we link to the list in the onnxoptimizer repo?

fd0r (Collaborator, Author):

The link should be enough, I'll remove it.

fd0r (Collaborator, Author):

I removed it and kept only the link to the repository.

continue
# Store MatMul node output name
matmul_node_output_name = matmul_node.output[0]
assert len(matmul_node.output) == 1
Collaborator:

Good! We expect the MatMul node to always have a single output.

fd0r (Collaborator, Author):

Yes, that is expected; this is more of a guide for someone reading the code than a real assert.


from .onnx_utils import IMPLEMENTED_ONNX_OPS, execute_onnx_with_numpy, get_op_type

OPSET_VERSION_FOR_ONNX_EXPORT = 14


def get_equivalent_numpy_forward_and_onnx_model(
# pylint: disable=too-many-branches
def fuse_matmul_bias_to_gemm(onnx_model: onnx.ModelProto):
Collaborator:

Well done!

fd0r (Collaborator, Author):

Thanks! 🙏🏼
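
(For readers of this thread, a rough idea of what a MatMul-plus-bias to Gemm fusion does on an ONNX graph. This is a self-contained, simplified sketch that only handles the plain 2D case; it is not the PR's fuse_matmul_bias_to_gemm implementation, and the function and variable names are illustrative:)

import onnx
from onnx import helper

def fuse_matmul_add_into_gemm(model: onnx.ModelProto) -> onnx.ModelProto:
    """Fuse a MatMul followed by an Add of a bias initializer into one Gemm node."""
    graph = model.graph
    initializer_names = {init.name for init in graph.initializer}
    new_nodes, skip = [], set()
    for i, node in enumerate(graph.node):
        if i in skip:
            continue  # this Add was already fused into the previous Gemm
        if node.op_type == "MatMul" and i + 1 < len(graph.node):
            nxt = graph.node[i + 1]
            bias_inputs = [inp for inp in nxt.input if inp in initializer_names]
            if nxt.op_type == "Add" and node.output[0] in nxt.input and bias_inputs:
                # Gemm(A, B, C) computes A @ B + C, so it can replace MatMul + Add
                new_nodes.append(
                    helper.make_node(
                        "Gemm",
                        inputs=[node.input[0], node.input[1], bias_inputs[0]],
                        outputs=[nxt.output[0]],
                    )
                )
                skip.add(i + 1)
                continue
        new_nodes.append(node)
    del graph.node[:]
    graph.node.extend(new_nodes)
    return model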

# Optimize ONNX graph
# List of all currently supported onnx optimizer passes
# From https://github.com/onnx/optimizer/blob/master/onnxoptimizer/pass_registry.h
# onnx_passes = [
Collaborator:

Do we keep this here?

Collaborator:

Yes, maybe it's not necessary to keep them all here since the relevant link is given.

fd0r (Collaborator, Author):

Removed them.
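
(Since only the link is kept in the code now, readers who want the list of passes can also query it from the onnxoptimizer package directly; a small sketch, with an illustrative model path:)

import onnx
import onnxoptimizer

# Same list as pass_registry.h in the onnx/optimizer repository
print(onnxoptimizer.get_available_passes())

model = onnx.load("model.onnx")  # illustrative path
optimized = onnxoptimizer.optimize(
    model, ["fuse_matmul_add_bias_into_gemm", "eliminate_identity"]
)
onnx.save(optimized, "model_optimized.onnx")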

src/concrete/ml/torch/compile.py (resolved review thread)
rounding_threshold_bits: int = 8,
p_error=0.01,
configuration: Configuration = None,
rounding_threshold_bits: Optional[int] = 8,
Collaborator:

This seems like a weird default to me... why is 8 a good value? Do we know it's a good value for LLMs, or for any NN in general?

fd0r (Collaborator, Author):

I'll change this to the normal default (None); this was here from a previous PR.

fd0r (Collaborator, Author):

Changed the defaults back to normal.
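
(In practice callers can still opt in; a minimal usage sketch with a toy model, assuming the usual compile_torch_model entry point — exact arguments may differ:)

import numpy
import torch
from concrete.ml.torch.compile import compile_torch_model

model = torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.ReLU())
inputset = numpy.random.uniform(-1, 1, size=(100, 4))

# Default behaviour after this change: no rounding unless explicitly requested
quantized_module = compile_torch_model(model, inputset)

# Opting in explicitly, e.g. to the 8 bits that used to be hard-coded
quantized_module = compile_torch_model(model, inputset, rounding_threshold_bits=8)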

tests/torch/test_compile_torch.py (resolved review thread)
print(f"Using device: {device}")

# Get GPT2 from Huggingface
# TODO: migrate to auto-model with model_name
Collaborator:

Can you make an issue on this?

Collaborator:

This is done already, no? The following lines use the automodel.

fd0r (Collaborator, Author):

I removed the comment.

from concrete.ml.torch.hybrid_model import FHEMode, HybridFHEModel

if __name__ == "__main__":
configs = [
Collaborator:

Can you explain what is in this config structure?

RomanBredehoft (Collaborator) commented Sep 19, 2023:

Also, they are already defined in the compile file: maybe refactor this into a single config file to avoid any unwanted mismatches?

fd0r (Collaborator, Author):

I was thinking about it, but creating a file just for that seemed a bit overkill.

I can do it if you think it's necessary.
I also thought about dumping a JSON from the compile script with the configuration, which is then re-used in the inference script.
Would that work for you?

Collaborator:

I think having a config file is fine, but as you wish; the second solution seems OK too.

In any case, as Andrei said, a comment explaining what these configs are would be great as well!

fd0r (Collaborator, Author):

Compilation dumps a JSON that can then be used by the client.
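
(A rough illustration of that flow; the file name, keys, and values here are hypothetical rather than the ones used in the example scripts. The compile script writes the configuration it used, and the inference/serving script reads it back instead of re-declaring it:)

import json
from pathlib import Path

CONFIG_PATH = Path("configuration.json")  # hypothetical file name

# In the compile script: record what was compiled
config = {
    "model_name": "gpt2",  # illustrative values
    "module_names": ["transformer.h.0.attn.c_proj"],
}
CONFIG_PATH.write_text(json.dumps(config, indent=2))

# In the inference/serving script: reuse the exact same configuration
config = json.loads(CONFIG_PATH.read_text())
print(config["model_name"], config["module_names"])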

use_case_examples/hybrid_model/load_and_analyze_data.py (resolved review thread)
@fd0r fd0r force-pushed the llm_partial_fhe_pr branch 2 times, most recently from 0a414b8 to c964cf2 Compare September 20, 2023 12:38
@fd0r fd0r force-pushed the llm_partial_fhe_pr branch from c964cf2 to ea204f4 Compare September 20, 2023 12:50
@fd0r fd0r force-pushed the llm_partial_fhe_pr branch 2 times, most recently from 52be915 to 2c3cbb3 Compare September 20, 2023 14:35
@fd0r fd0r requested a review from jfrery September 20, 2023 15:06
default=["transformer.h.0.attn.c_proj"],
type=module_names_parser,
help="""The module(s) name(s) to compile to FHE.
Examples for GPT-2 model:
Collaborator:

Weird indent, but great, thanks!
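
(To connect the pieces for readers: the module name above is what the demo hands to HybridFHEModel so that just that submodule runs in FHE. A simplified sketch of the flow, using the public names from this PR; the checkpoint name and the exact compile call are assumptions and may differ from the example scripts:)

from transformers import AutoModelForCausalLM, AutoTokenizer

from concrete.ml.torch.hybrid_model import HybridFHEModel

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained("gpt2")
module_names = ["transformer.h.0.attn.c_proj"]  # default from the compile script

# Wrap the model: the named projection layer is swapped for a module that can
# run remotely / in FHE, while the rest of GPT-2 keeps running in the clear
hybrid_model = HybridFHEModel(model, module_names)

# Calibrate and compile the FHE part on representative token ids
input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"]
hybrid_model.compile_model(input_ids, n_bits=8)  # call shape is an assumption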

RomanBredehoft previously approved these changes Sep 20, 2023

RomanBredehoft (Collaborator) left a comment:

I would try to see if it's possible to use argparse in serve.py, but if not, then this looks good to me! Huge work here, thanks a lot for that.

jfrery previously approved these changes Sep 20, 2023

jfrery (Collaborator) left a comment:

Looks good to me. I am testing this with phi-1.5 to see how it generalizes.

Collaborator:

It's missing loguru, I think?

fd0r (Collaborator, Author) commented Sep 21, 2023:

Fixed @jfrery's issue and squashed the commits.

github-actions (bot):

Coverage passed ✅

Coverage details

---------- coverage: platform linux, python 3.8.18-final-0 -----------
Name    Stmts   Miss  Cover   Missing
-------------------------------------
TOTAL    5954      0   100%

50 files skipped due to complete coverage.

jfrery (Collaborator) left a comment:

Looks good, thanks.

@fd0r fd0r merged commit f1d1490 into main Sep 21, 2023
8 checks passed
@fd0r fd0r deleted the llm_partial_fhe_pr branch September 21, 2023 12:23