Skip to content

Commit

Permalink
Fixes for YOLOv8 (#144)
Browse files Browse the repository at this point in the history
* fixed links on speedster readme & added save/load example to all notebooks

* fix bugs in torchscript and tensorrt compilers

* added notebook for yolov8

* change input size in notebook
  • Loading branch information
valeriosofi authored Jan 10, 2023
1 parent e50d9ca commit 7dc5b8a
Show file tree
Hide file tree
Showing 7 changed files with 383 additions and 9 deletions.
2 changes: 1 addition & 1 deletion nebullvm/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from nebullvm.optional_modules.torch import torch


VERSION = "0.7.0"
VERSION = "0.7.1"
LEARNER_METADATA_FILENAME = "metadata.json"
ONNX_OPSET_VERSION = 13
NEBULLVM_DEBUG_FILE = "nebullvm_debug.json"
Expand Down
10 changes: 8 additions & 2 deletions nebullvm/operations/optimizations/compilers/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,22 @@ def execute(
model, quantization_type, input_tfms, train_input_data
)

self.compiled_model = self._compile_model(model, input_data)
self.compiled_model = self._compile_model(
model, input_data, quantization_type
)

def _compile_model(
self,
model: Union[Module, GraphModule],
input_data: DataManager,
quantization_type: QuantizationType,
) -> ScriptModule:
input_sample = input_data.get_list(1)[0]
if self.device is Device.GPU:
input_sample = [t.cuda() for t in input_sample]
if quantization_type is QuantizationType.HALF:
input_sample = [t.cuda().half() for t in input_sample]
else:
input_sample = [t.cuda() for t in input_sample]

if not isinstance(model, torch.fx.GraphModule):
model.eval()
Expand Down
5 changes: 4 additions & 1 deletion nebullvm/operations/optimizations/compilers/tensor_rt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import copy
import os
import subprocess
from pathlib import Path
Expand Down Expand Up @@ -153,7 +154,9 @@ def _compile_model(

with torch_tensorrt.logging.errors():
trt_model = torch_tensorrt.compile(
model,
model
if dtype is not torch.half
else copy.deepcopy(model).half(),
inputs=[
torch_tensorrt.Input(
tensor.shape,
Expand Down
Loading

0 comments on commit 7dc5b8a

Please sign in to comment.