don't initialize cuda at import time (#3244)
technillogue authored Nov 8, 2024
1 parent 5129688 commit 233d0bf
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions py/torch_tensorrt/ts/_compile_spec.py
@@ -307,7 +307,7 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec:
 def TensorRTCompileSpec(
     inputs: Optional[List[torch.Tensor | Input]] = None,
     input_signature: Optional[Any] = None,
-    device: torch.device | Device = Device._current_device(),
+    device: Optional[torch.device | Device] = None,
     disable_tf32: bool = False,
     sparse_weights: bool = False,
     enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
@@ -365,7 +365,7 @@ def TensorRTCompileSpec(
     compile_spec = {
         "inputs": inputs if inputs is not None else [],
         # "input_signature": input_signature,
-        "device": device,
+        "device": Device._current_device() if device is None else device,
         "disable_tf32": disable_tf32,  # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
         "sparse_weights": sparse_weights,  # Enable sparsity for convolution and fully connected layers.
         "enabled_precisions": (
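
The point of the change is when the default device is computed. Python evaluates default-argument expressions once, when the def statement runs, so `device: torch.device | Device = Device._current_device()` queried the current device (and thereby initialized CUDA) as a side effect of importing the module. With a None default, the query happens only when TensorRTCompileSpec is actually called without an explicit device. A minimal, self-contained sketch of the pattern follows; the names _current_device, compile_spec_eager, and compile_spec_lazy are stand-ins for illustration, not torch_tensorrt APIs.

from typing import Optional

def _current_device() -> str:
    # Stand-in for Device._current_device(); in the real code this call
    # is what ends up initializing CUDA.
    print("querying current device")
    return "cuda:0"

# Eager default: the default expression is evaluated when the `def`
# statement runs, i.e. at import time, even if the function is never called.
def compile_spec_eager(device: str = _current_device()) -> str:
    return device

# Lazy default: a None sentinel defers the device query to call time, and
# only when the caller did not supply a device.
def compile_spec_lazy(device: Optional[str] = None) -> str:
    return _current_device() if device is None else device

Importing a module that defines compile_spec_eager triggers the device query immediately; the lazy variant, which is what this commit switches to, defers it to the first call that omits the device argument.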
