Skip to content

Commit

Permalink
chore: fix server to client transfer time + add Q K V projection
Browse files Browse the repository at this point in the history
  • Loading branch information
jfrery authored Sep 25, 2023
1 parent 5c916ba commit 3c88171
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 21 deletions.
4 changes: 2 additions & 2 deletions use_case_examples/hybrid_model/compile_hybrid_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ def module_names_parser(string: str) -> List[str]:
arg_parser.add_argument(
"--module-names",
dest="module_names",
default=["transformer.h.0.attn.c_proj"],
default=["transformer.h.0.attn.c_attn"],
type=module_names_parser,
help="""The module(s) name(s) to compile to FHE.
Examples for GPT-2 model:
"transformer.h.0.mlp" for a full MLP
"transformer.h.0.mlp, "transformer.h.1.mlp" for two full MLPs
"transformer.h.0.mlp.c_proj" for only one projection in MLP
"transformer.h.0.attn.c_proj" for only one projection in attention
"transformer.h.0.attn.c_attn" for the Q, K, V projections in the attention
These names might vary according to your model.
""",
Expand Down
64 changes: 45 additions & 19 deletions use_case_examples/hybrid_model/serve_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from typing import Optional, Tuple, Union
from typing import Generator, Optional, Tuple, Union

import uvicorn
from fastapi import FastAPI, Form, HTTPException, UploadFile
Expand Down Expand Up @@ -167,45 +167,71 @@ async def add_key(
dump_key(key_bytes, uid)
return {"uid": uid}

def stream_response(
encrypted_results: bytes, chunk_size: int = 1024 * 1024
) -> Generator[bytes, None, None]:
"""Yields chunks of encrypted results.
Args:
encrypted_results (bytes): The byte data to be streamed.
chunk_size (int): The size of the chunks in which encrypted_results should be streamed.
Defaults to 1MB.
Returns:
bytes: Chunks of encrypted_results of size chunk_size.
"""
buffer = io.BytesIO(encrypted_results)
while chunk := buffer.read(chunk_size):
yield chunk

@app.post("/compute")
async def compute(
model_input: UploadFile,
uid: str = Form(),
model_name: str = Form(),
module_name: str = Form(),
input_shape: str = Form(),
): # noqa: B008
"""Compute the circuit over encrypted input.
):
"""
Computes the circuit over encrypted input.
Arguments:
model_input (UploadFile): input of the circuit
uid (str): uid of the public key to use
Args:
model_input (UploadFile): Input of the circuit.
uid (str): The UID of the public key to use for computations.
model_name (str): The name of the model to be used.
module_name (str): The name of the module containing the computation circuit.
input_shape (str): The shape of the input data.
Returns:
StreamingResponse: the result of the circuit
StreamingResponse: The result of the computation, streamed back in chunks.
"""
check_inputs(model_name, module_name, input_shape)

# Read the uploaded file first to avoid including this I/O time in FHE inference runtime measurement.
logger.info("Reading uploaded data...")
start_read = time.time()
uploaded_data = await model_input.read()
logger.info(f"Uploaded data read in {time.time() - start_read} seconds")

logger.info("Loading key...")
start = time.time()
key_bytes = load_key(uid)
end = time.time()
logger.info(f"It took {end - start} seconds to load the key")
logger.info(f"Key loaded in {time.time() - start} seconds")

logger.info("Loading circuit...")
start = time.time()
fhe = get_circuit(model_name, module_name, input_shape)
end = time.time()
logger.info(f"It took {end - start} seconds to load the circuit")
logger.info(f"Circuit loaded in {time.time() - start} seconds")

logger.info("Running FHE inference...")
start = time.time()
encrypted_results = fhe.run(
serialized_encrypted_quantized_data=await model_input.read(),
serialized_encrypted_quantized_data=uploaded_data,
serialized_evaluation_keys=key_bytes,
)
end = time.time()
logger.info(f"fhe inference of input of shape {input_shape} took {end - start}")
logger.info(f"Results size is {len(encrypted_results)/(1024**2)} Mb")
start = time.time()
return StreamingResponse(
io.BytesIO(encrypted_results),
)
logger.info(f"FHE inference completed in {time.time() - start} seconds")
logger.info(f"Results size is {len(encrypted_results) / (1024 ** 2)} Mb")

return StreamingResponse(stream_response(encrypted_results))

uvicorn.run(app, host="0.0.0.0", port=int(PORT))

0 comments on commit 3c88171

Please sign in to comment.