Skip to content

Commit

Permalink
fix type when tensorflow is not installed (#97)
Browse files Browse the repository at this point in the history
* fix type when tensorflow is not installed

* fix compressor bug

* fix torch tensorrt

Co-authored-by: Valerio Sofi <[email protected]>
  • Loading branch information
valeriosofi authored Sep 12, 2022
1 parent d013d2c commit db379e9
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
Empty file.
10 changes: 8 additions & 2 deletions nebullvm/optimizers/tensor_rt.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def optimize_from_torch(
dataset,
batch_size=dataset.batch_size,
shuffle=False,
num_workers=1,
num_workers=0,
)

calibrator = torch_tensorrt.ptq.DataLoaderCalibrator(
Expand All @@ -350,8 +350,14 @@ def optimize_from_torch(
):
return None # Dynamic quantization is not supported on tensorRT

try:
torch.jit.script(torch_model.eval())
model = torch_model
except Exception:
model = torch.jit.trace(torch_model, input_data.get_list(1)[0])

trt_model = torch_tensorrt.compile(
torch_model.eval(),
model.eval(),
inputs=[
torch_tensorrt.Input(
(model_params.batch_size, *input_info.size),
Expand Down
8 changes: 5 additions & 3 deletions nebullvm/utils/optional_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
from nebullvm.installers.installers import install_tf2onnx, install_tensorflow
from nebullvm.utils.general import check_module_version

NoneType = type(None)


class Keras:
Model = None
Model = NoneType


class Tensorflow:
Module = None
Tensor = None
Module = NoneType
Tensor = NoneType
keras = Keras()

@staticmethod
Expand Down

0 comments on commit db379e9

Please sign in to comment.