From 3c8817171cd77c07db9a9674beba02874749df4b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jordan=20Fr=C3=A9ry?=
Date: Mon, 25 Sep 2023 16:53:42 +0200
Subject: [PATCH] chore: fix server to client transfer time + add Q K V
 projection

---
 .../hybrid_model/compile_hybrid_llm.py        |  4 +-
 use_case_examples/hybrid_model/serve_model.py | 64 +++++++++++++------
 2 files changed, 47 insertions(+), 21 deletions(-)

diff --git a/use_case_examples/hybrid_model/compile_hybrid_llm.py b/use_case_examples/hybrid_model/compile_hybrid_llm.py
index 652cb45eb..67c7a0510 100644
--- a/use_case_examples/hybrid_model/compile_hybrid_llm.py
+++ b/use_case_examples/hybrid_model/compile_hybrid_llm.py
@@ -47,14 +47,14 @@ def module_names_parser(string: str) -> List[str]:
 arg_parser.add_argument(
     "--module-names",
     dest="module_names",
-    default=["transformer.h.0.attn.c_proj"],
+    default=["transformer.h.0.attn.c_attn"],
     type=module_names_parser,
     help="""The module(s) name(s) to compile to FHE.
 Examples for GPT-2 model:
 "transformer.h.0.mlp" for a full MLP
 "transformer.h.0.mlp, "transformer.h.1.mlp" for two full MLPs
 "transformer.h.0.mlp.c_proj" for only one projection in MLP
-"transformer.h.0.attn.c_proj" for only one projection in attention
+"transformer.h.0.attn.c_attn" for the Q, K and V projections in attention
 
 These names might vary according to your model.
 """,
diff --git a/use_case_examples/hybrid_model/serve_model.py b/use_case_examples/hybrid_model/serve_model.py
index a72974d03..28baebdfe 100644
--- a/use_case_examples/hybrid_model/serve_model.py
+++ b/use_case_examples/hybrid_model/serve_model.py
@@ -15,7 +15,7 @@
 from collections import defaultdict
 from functools import lru_cache
 from pathlib import Path
-from typing import Optional, Tuple, Union
+from typing import Generator, Optional, Tuple, Union
 
 import uvicorn
 from fastapi import FastAPI, Form, HTTPException, UploadFile
@@ -167,6 +167,23 @@ async def add_key(
         dump_key(key_bytes, uid)
         return {"uid": uid}
 
+    def stream_response(
+        encrypted_results: bytes, chunk_size: int = 1024 * 1024
+    ) -> Generator[bytes, None, None]:
+        """Yield chunks of the encrypted results.
+
+        Args:
+            encrypted_results (bytes): The byte data to be streamed.
+            chunk_size (int): The size of the chunks in which encrypted_results
+                should be streamed. Defaults to 1 MB.
+
+        Yields:
+            bytes: Chunks of encrypted_results of size chunk_size.
+        """
+        buffer = io.BytesIO(encrypted_results)
+        while chunk := buffer.read(chunk_size):
+            yield chunk
+
     @app.post("/compute")
     async def compute(
         model_input: UploadFile,
@@ -174,38 +191,47 @@ async def compute(
         model_name: str = Form(),
         module_name: str = Form(),
         input_shape: str = Form(),
-    ):  # noqa: B008
-        """Compute the circuit over encrypted input.
+    ):
+        """
+        Compute the circuit over encrypted input.
 
-        Arguments:
-            model_input (UploadFile): input of the circuit
-            uid (str): uid of the public key to use
+        Args:
+            model_input (UploadFile): Input of the circuit.
+            uid (str): The UID of the public key to use for computations.
+            model_name (str): The name of the model to be used.
+            module_name (str): The name of the module containing the computation circuit.
+            input_shape (str): The shape of the input data.
 
         Returns:
-            StreamingResponse: the result of the circuit
+            StreamingResponse: The result of the computation, streamed back in chunks.
         """
         check_inputs(model_name, module_name, input_shape)
+
+        # Read the uploaded file first so this I/O is not counted in the FHE inference time
+        logger.info("Reading uploaded data...")
+        start_read = time.time()
+        uploaded_data = await model_input.read()
+        logger.info(f"Uploaded data read in {time.time() - start_read} seconds")
+
+        logger.info("Loading key...")
         start = time.time()
         key_bytes = load_key(uid)
-        end = time.time()
-        logger.info(f"It took {end - start} seconds to load the key")
+        logger.info(f"Key loaded in {time.time() - start} seconds")
 
+        logger.info("Loading circuit...")
         start = time.time()
         fhe = get_circuit(model_name, module_name, input_shape)
-        end = time.time()
-        logger.info(f"It took {end - start} seconds to load the circuit")
+        logger.info(f"Circuit loaded in {time.time() - start} seconds")
 
+        logger.info("Running FHE inference...")
         start = time.time()
         encrypted_results = fhe.run(
-            serialized_encrypted_quantized_data=await model_input.read(),
+            serialized_encrypted_quantized_data=uploaded_data,
             serialized_evaluation_keys=key_bytes,
         )
-        end = time.time()
-        logger.info(f"fhe inference of input of shape {input_shape} took {end - start}")
-        logger.info(f"Results size is {len(encrypted_results)/(1024**2)} Mb")
-        start = time.time()
-        return StreamingResponse(
-            io.BytesIO(encrypted_results),
-        )
+        logger.info(f"FHE inference completed in {time.time() - start} seconds")
+        logger.info(f"Results size is {len(encrypted_results) / (1024 ** 2)} MB")
+
+        return StreamingResponse(stream_response(encrypted_results))
 
     uvicorn.run(app, host="0.0.0.0", port=int(PORT))