chore: implement hybrid model demo with GPT-2 #246
Conversation
Force-pushed from ac760f5 to 4504d5a
Yes, I think so.
Force-pushed from c2a1289 to 0b6ef48
Nicely done! I have a few comments as well as questions for my understanding.
@@ -695,5 +695,5 @@
   }
  },
  "nbformat": 4,
- "nbformat_minor": 2
+ "nbformat_minor": 4
Is this expected?
No idea actually, nbqa probably? 🤷🏼
Not sure. It's not important I guess.
I didn't modify the file myself
Can't do much about this one
probably a version difference or something
Could you add HybridFHEModel to __init__.py so that we can do from concrete.ml.torch import HybridFHEModel?
good idea
For some reason this results in a circular import.
weird
🤷🏼
I guess it's because hybrid_model.py imports something from concrete.ml.torch? Not sure there is an easy solution then. Shall we leave it like this?
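If someone wants to revisit the re-export later, one possible workaround is a lazy attribute lookup (PEP 562, Python ≥ 3.7) in the package's __init__.py, so the submodule is only imported on first access. This is just a sketch of the idea, not something done in this PR, and it assumes the module lives at concrete/ml/torch/hybrid_model.py:

```python
# concrete/ml/torch/__init__.py -- hypothetical sketch, not part of this PR.
# Lazily re-export HybridFHEModel so hybrid_model is not imported at package
# import time, which is what triggers the circular import.

def __getattr__(name):
    if name == "HybridFHEModel":
        from .hybrid_model import HybridFHEModel  # imported only on first access

        return HybridFHEModel
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```

With that in place, `from concrete.ml.torch import HybridFHEModel` would resolve through the module-level `__getattr__` without the eager import.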
Looks good! Some minor changes, especially to the docs, e.g. better docstrings for the arguments to the Hybrid class.
src/concrete/ml/onnx/convert.py (outdated)
@@ -47,7 +47,58 @@ def get_equivalent_numpy_forward_and_onnx_model(
         opset_version=OPSET_VERSION_FOR_ONNX_EXPORT,
     )
     equivalent_onnx_model = onnx.load_model(str(output_onnx_file_path))

+    # List of all currently supported onnx passes
Do we keep this comment here, or can we link to the list in the onnxoptimizer repo?
A link should be enough, I'll remove it.
I removed them and kept only the link to the repository.
continue
# Store MatMul node output name
matmul_node_output_name = matmul_node.output[0]
assert len(matmul_node.output) == 1
Good! We expect the MatMul node to always have a single output.
Yes, that is expected; this is more of a guide for someone reading the code than a real assert.
from .onnx_utils import IMPLEMENTED_ONNX_OPS, execute_onnx_with_numpy, get_op_type

OPSET_VERSION_FOR_ONNX_EXPORT = 14


def get_equivalent_numpy_forward_and_onnx_model(
# pylint: disable=too-many-branches
def fuse_matmul_bias_to_gemm(onnx_model: onnx.ModelProto):
well done!
Thanks! 🙏🏼
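For readers who want to see what this kind of graph rewrite looks like, here is a minimal, self-contained sketch of a MatMul + Add → Gemm fusion pass. It is an illustration only, not the PR's implementation; it assumes 2D inputs (Gemm does not accept the N-D inputs that MatMul does) and skips checks the real pass presumably performs, such as verifying that the bias is an initializer of matching shape.

```python
import onnx
from onnx import helper


def fuse_matmul_bias_to_gemm_sketch(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
    """Illustrative sketch: fuse MatMul followed by a bias Add into a single Gemm."""
    graph = onnx_model.graph
    nodes = list(graph.node)
    fused_nodes, consumed = [], set()
    for i, node in enumerate(nodes):
        if i in consumed:
            continue
        if node.op_type == "MatMul" and len(node.output) == 1:
            # Find an Add node that consumes the MatMul output (the bias add).
            match = next(
                (
                    (j, candidate)
                    for j, candidate in enumerate(nodes)
                    if candidate.op_type == "Add" and node.output[0] in candidate.input
                ),
                None,
            )
            if match is not None:
                j, add_node = match
                # The other Add input is taken to be the bias tensor.
                bias_name = [x for x in add_node.input if x != node.output[0]][0]
                fused_nodes.append(
                    helper.make_node(
                        "Gemm",
                        inputs=[node.input[0], node.input[1], bias_name],
                        outputs=list(add_node.output),
                        name=(node.name or f"matmul_{i}") + "_gemm",
                    )
                )
                consumed.add(j)
                continue
        fused_nodes.append(node)
    # Replace the graph nodes with the fused list (order is preserved).
    del graph.node[:]
    graph.node.extend(fused_nodes)
    return onnx_model
```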
src/concrete/ml/onnx/convert.py (outdated)
# Optimize ONNX graph
# List of all currently supported onnx optimizer passes
# From https://github.com/onnx/optimizer/blob/master/onnxoptimizer/pass_registry.h
# onnx_passes = [
Do we keep this here?
Yes, maybe it's not necessary to keep them all here since the relevant link is given.
Removed them
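For reference, applying the optimizer then presumably boils down to something like the sketch below. The pass names come from the onnxoptimizer registry linked above, but the exact set used in the PR is not shown here and may differ; `equivalent_onnx_model` refers to the model loaded in the diff above.

```python
import onnxoptimizer

# Illustrative subset of passes from onnxoptimizer's pass registry (see link above).
onnx_passes = [
    "fuse_matmul_add_bias_into_gemm",
    "eliminate_nop_transpose",
    "eliminate_unused_initializer",
]
# equivalent_onnx_model is the onnx.ModelProto loaded earlier in convert.py.
equivalent_onnx_model = onnxoptimizer.optimize(equivalent_onnx_model, onnx_passes)
```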
-    rounding_threshold_bits: int = 8,
     p_error=0.01,
     configuration: Configuration = None,
+    rounding_threshold_bits: Optional[int] = 8,
This seems like a weird default to me... why is 8 a good value? Do we know it's a good value for LLMs, or for any NN in general?
I'll change this to the normal default (None); this was here from a previous PR.
Changed defaults back to normal
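To make the trade-off concrete, here is a hypothetical call sketch. The keyword names follow the diff above; the method name, the surrounding variables (`hybrid_model`, `inputs`) and the rest of the signature are assumptions, not the PR's exact API.

```python
# Hypothetical sketch: rounding_threshold_bits=None keeps the library default
# (no accumulator rounding); setting it to an integer such as 8 would round
# accumulators to that many bits, trading accuracy for cheaper FHE execution.
hybrid_model.compile_model(
    inputs,
    n_bits=8,
    rounding_threshold_bits=None,  # back to the normal default, per this thread
    p_error=0.01,
    configuration=None,
)
```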
print(f"Using device: {device}") | ||
|
||
# Get GPT2 from Huggingface | ||
# TODO: migrate to auto-model with model_name |
Can you make an issue for this?
This is done already, no? The following lines use the auto-model.
I removed the comment
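For context, the auto-model loading referred to above would look roughly like this sketch; "gpt2" as the checkpoint name and the surrounding variable names are assumptions, the demo's exact arguments may differ.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # assumed; the demo may use a different checkpoint name
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load tokenizer and model generically from the model name.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
```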
from concrete.ml.torch.hybrid_model import FHEMode, HybridFHEModel

if __name__ == "__main__":
    configs = [
Can you explain what is in this config structure?
Also, they are already defined in the compile file: maybe refactor this into a single config file to avoid any unwanted mismatches?
I was thinking about it, but creating a file just for that seemed a bit overkill. I can do it if you think it's necessary. I also thought about dumping a JSON from the compile script with the configuration, which is then re-used in the inference script. Would that work for you?
I think having a config file is fine, but as you wish; the second solution seems OK. In any case, as Andrei said, a comment explaining what these configs are would be great as well!
Compilation dumps a JSON that can then be used by the client.
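A sketch of that idea, for clarity: the compile script writes a small JSON with the chosen model and FHE module names, and the inference/serving script reloads it so the two cannot drift apart. The file name and keys here are illustrative, not the PR's actual format.

```python
import json
from pathlib import Path

CONFIG_PATH = Path("configuration.json")  # illustrative name


def dump_compile_config(model_name: str, module_names: list) -> None:
    """Called by the compile script after a successful compilation."""
    CONFIG_PATH.write_text(
        json.dumps({"model_name": model_name, "module_names": module_names}, indent=2)
    )


def load_compile_config() -> dict:
    """Called by the inference/serving script to reuse the same settings."""
    return json.loads(CONFIG_PATH.read_text())
```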
Force-pushed from 0a414b8 to c964cf2
Force-pushed from c964cf2 to ea204f4
Force-pushed from 52be915 to 2c3cbb3
default=["transformer.h.0.attn.c_proj"], | ||
type=module_names_parser, | ||
help="""The module(s) name(s) to compile to FHE. | ||
Examples for GPT-2 model: |
Weird indent, but great, thanks!
I would try to see if it's possible to use argparse in serve.py, but if not, then it looks good to me! Huge work here, thanks a lot for that.
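As for argparse, the wiring around the diff above would look roughly like this sketch; `module_names_parser` is only assumed to split a comma-separated string, the real helper may differ.

```python
import argparse


def module_names_parser(value: str) -> list:
    """Assumed helper: turn 'a,b,c' into ['a', 'b', 'c']."""
    return [name.strip() for name in value.split(",") if name.strip()]


parser = argparse.ArgumentParser()
parser.add_argument(
    "--module-names",
    default=["transformer.h.0.attn.c_proj"],
    type=module_names_parser,
    help="The module(s) name(s) to compile to FHE, "
    "e.g. transformer.h.0.attn.c_proj for GPT-2.",
)
args = parser.parse_args()
print(args.module_names)
```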
Looks good to me. I am testing this with phi1.5 to see how it generalizes.
It's missing loguru, I think?
Force-pushed from 2c3cbb3 to 6ab7d93
Fixed @jfrery's issue and squashed commits.
Coverage passed ✅
Looks good, thanks.
Closes https://github.com/zama-ai/concrete-ml-internal/issues/3842
Closes https://github.com/zama-ai/concrete-ml-internal/issues/3869
Closes https://github.com/zama-ai/concrete-ml-internal/issues/3875
Closes https://github.com/zama-ai/concrete-ml-internal/issues/3855
Closes https://github.com/zama-ai/concrete-ml-internal/issues/3852