Skip to content

Commit

Permalink
chore: add gpu option
Browse files Browse the repository at this point in the history
  • Loading branch information
jfrery committed Jun 20, 2024
1 parent 11bc753 commit fa0b73a
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions use_case_examples/resnet/run_resnet18_fhe.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ def evaluate_model(model, processor, images, labels):
print(f"Top-5 Accuracy of the ResNet18 model on the images: {topk_accuracy*100:.2f}%")


def compile_model(model, images):
def compile_model(model, images, use_gpu=False):
# Enable TLU fusing to optimize the number of TLUs in the residual connections
config = Configuration(enable_tlu_fusing=True, print_tlu_fusing=False)
config = Configuration(enable_tlu_fusing=True, print_tlu_fusing=False, use_gpu=use_gpu)
print("Compiling the model...")
return compile_torch_model(
model,
Expand Down Expand Up @@ -111,14 +111,17 @@ def main():
parser.add_argument(
"--export_statistics", action="store_true", help="Export the circuit statistics."
)
parser.add_argument(
"--use_gpu", action="store_true", help="Use the available GPU at FHE runtime."
)
args = parser.parse_args()

resnet18 = load_model()
processor, calib_images, images, labels = load_data()

evaluate_model(resnet18, processor, images, labels)

q_module = compile_model(resnet18, calib_images)
q_module = compile_model(resnet18, calib_images, use_gpu=args.use_gpu)

if args.export_statistics:
export_statistics(q_module)
Expand Down

0 comments on commit fa0b73a

Please sign in to comment.