diff --git a/torch_xla/_dynamo/dynamo_bridge.py b/torch_xla/_dynamo/dynamo_bridge.py
index 596b2c5efca..c0599f0ad96 100644
--- a/torch_xla/_dynamo/dynamo_bridge.py
+++ b/torch_xla/_dynamo/dynamo_bridge.py
@@ -149,6 +149,8 @@ def _maybe_move_tensors_to_device(tensors: tuple,
                                    target_device: torch.device) -> tuple:
   assert target_device, "Moving tensors to None device not supported"
 
+  device_id = None
+
   moved_tensors = []
   for tensor in tensors:
     if not isinstance(tensor, torch.Tensor):
@@ -159,17 +161,19 @@ def _maybe_move_tensors_to_device(tensors: tuple,
       moved_tensors.append(tensor)
       continue
 
-    if dynamo_debug:
-      print("Moving Tensor {} to device {}".format(tensor, target_device))
+    # if dynamo_debug:
+    #   print("Moving Tensor {} to device {}".format(tensor, target_device))
 
     zero_copy_enabled = xu.getenv_as(xenv.ZERO_COPY_ENABLED, bool, defval=False)
     if zero_copy_enabled and tensor.device.type == 'cuda' and target_device.type == 'xla':
       # If the input cuda tensor requires gradient, we need to call detach. Otherwise, we'd get the error "RuntimeError: Can't export tensors that require gradient, use tensor.detach()"
+      device_type, device_id = tensor.__dlpack_device__()
       moved_tensor = torch_xla_dlpack.from_dlpack(tensor.detach())
     elif zero_copy_enabled and tensor.device.type == 'xla' and target_device.type == 'cuda':
-      # mark_step is need to make sure the pjrt buffer is valid.
-      xm.mark_step()
       moved_tensor = torch_xla_dlpack.from_xla_cuda_to_cuda(tensor)
+      # HACK: The `torch_xla._XLAC._get_stream_for_cuda_device` requires a local device index, while the device index for xla tensors is always 0.
+      # Meanwhile, dlpack uses the actual device index, so we use the device index of the converted CUDA tensor.
+      device_id = moved_tensor.device.index
     else:
       # Have to move to CPU before moving it to target device.
       cpu_device: torch.device = torch.device("cpu")
@@ -181,6 +185,17 @@ def _maybe_move_tensors_to_device(tensors: tuple,
     moved_tensor.requires_grad = tensor.requires_grad
     moved_tensors.append(moved_tensor)
 
+  if zero_copy_enabled and device_id is not None:
+    stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id)
+    stream = 1 if stream == 0 else stream
+    assert stream is None or type(stream) is int
+    external_stream = torch.cuda.ExternalStream(stream)
+    current_stream = torch.cuda.current_stream()
+    if external_stream != current_stream:
+      event = torch.cuda.Event()
+      event.record(current_stream)
+      external_stream.wait_event(event)
+
   return tuple(moved_tensors)
 
 
@@ -253,17 +268,15 @@ class SpecialReturnHandler:
 
   def __init__(self, trace_inputs, trace_outputs,
                trace_inputs_inplace_update_bool, constant_outputs_and_indexes):
-    self.trace_inputs = trace_inputs
-    self.trace_outputs = trace_outputs
     self.constant_outputs_and_indexes = constant_outputs_and_indexes
     # dedup the traced outputs first
     self.deduper = Deduper()
-    self.deduped_trace_outputs = self.deduper.dedup(self.trace_outputs)
+    self.deduped_trace_outputs = self.deduper.dedup(trace_outputs)
 
     # record the output that is also a input
     trace_inputs_id2pos = {
-        id(x): pos for pos, x in enumerate(self.trace_inputs)
+        id(x): pos for pos, x in enumerate(trace_inputs)
     }
 
     self.trace_outputs_pos_to_inputs_pos = []
     for out_pos, out in enumerate(self.deduped_trace_outputs):
@@ -466,11 +479,14 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule,
 
   with alias_with_buffer_donor_config() as saved_config:
     # calculate graph hash
-    graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out_tensor_only)
+    if len(args_and_out_tensor_only) == 0:
+      graph_hash = None
+    else:
+      graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out_tensor_only)
+      # compiles and cache graph rooted at tensors in 'args_and_out_tensor_only'
+      torch_xla._XLAC._xla_warm_up_cache(args_and_out_tensor_only, [])
     if dynamo_debug:
       print("Graph Hash: ", graph_hash)
-    # compiles and cache graph rooted at tensors in 'args_and_out_tensor_only'
-    torch_xla._XLAC._xla_warm_up_cache(args_and_out_tensor_only, [])
 
   # Restore the origional `xla_args`. Dynamo passed the real tensor as
   # `xla_args`` and we performend the tracing on them. During the tracing,
@@ -490,10 +506,17 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule,
   # mistakenlly update the input tensors.
   torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
 
-  vars_to_return = (xla_args_sharding_spec, args_and_out, graph_hash,
+  xla_args_dtype = []
+  for arg in xla_args:
+    if isinstance(arg, torch.Tensor):
+      xla_args_dtype.append(arg.dtype)
+    else:
+      xla_args_dtype.append(None)
+
+  vars_to_return = (xla_args_sharding_spec, len(args_and_out), graph_hash,
                     arg_index_to_need_update_index, none_remover,
                     graph_input_matcher, special_return_handler,
-                    xla_args_need_update)
+                    xla_args_need_update, xla_args_dtype)
   # populate the cache
   sym_constants_to_graph_vars[sym_constants] = vars_to_return
@@ -523,10 +546,10 @@ def extract_internal(xla_model: torch.fx.GraphModule):
   sym_constants_to_graph_vars: Dict[Tuple[Union[int, float], ...],
                                     Tuple[Any, ...]] = {}
 
-  (xla_args_sharding_spec, args_and_out, graph_hash,
+  (xla_args_sharding_spec, len_args_and_out, graph_hash,
    arg_index_to_need_update_index, none_remover, graph_input_matcher,
    special_return_handler,
-   xla_args_need_update) = extract_graph_helper(xla_model,
+   xla_args_need_update, xla_args_dtype) = extract_graph_helper(xla_model,
                                                 sym_constants_to_graph_vars)
   skip_checking_input_sharding_threashold = xu.getenv_as(
       'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)
@@ -535,30 +558,42 @@ def optimized_mod(*args: tuple):
     nonlocal xla_model
     nonlocal skip_checking_input_sharding_threashold
     nonlocal sym_constants_to_graph_vars
+    nonlocal graph_hash
+
+    if graph_hash is None:
+      return xla_model(*args)
+
+    original_device: torch.device = _get_input_arg_device(args)
+    is_cuda_args: bool = False
+    if original_device:
+      is_cuda_args = original_device.type == "cuda"
+
     # See [Note: Dynamo real-time input-shape cache look-up] above.
     xla_args_tensor_only, sym_constants = _split_xla_args_tensor_sym_constant(
         args)
     if sym_constants in sym_constants_to_graph_vars:
-      (xla_args_sharding_spec, args_and_out, graph_hash,
+      (xla_args_sharding_spec, len_args_and_out, graph_hash,
        arg_index_to_need_update_index, none_remover, graph_input_matcher,
        special_return_handler,
-       xla_args_need_update) = sym_constants_to_graph_vars[sym_constants]
+       xla_args_need_update, xla_args_dtype) = sym_constants_to_graph_vars[sym_constants]
     else:
       xla_model.xla_args = args
-      (xla_args_sharding_spec, args_and_out, graph_hash,
+      (xla_args_sharding_spec, len_args_and_out, graph_hash,
        arg_index_to_need_update_index, none_remover, graph_input_matcher,
-       special_return_handler, xla_args_need_update) = extract_graph_helper(
+       special_return_handler, xla_args_need_update, xla_args_dtype) = extract_graph_helper(
            xla_model, sym_constants_to_graph_vars)
+      if hasattr(xla_model, 'xla_args'):
+        delattr(xla_model, 'xla_args')
 
-    original_device: torch.device = _get_input_arg_device(args)
-    is_cuda_args: bool = False
-    if original_device:
-      is_cuda_args = original_device.type == "cuda"
-
+    args = list(args)
+    for index, arg in enumerate(args):
+      if isinstance(arg, torch.Tensor) and arg.dtype != xla_args_dtype[index]:
+        args[index] = arg.to(xla_args_dtype[index])
     if is_cuda_args:
       args = _maybe_move_tensors_to_device(args, xm.xla_device())
-
+    xla_args_tensor_only, sym_constants = _split_xla_args_tensor_sym_constant(
+        args)
     if not config.skip_input_data_check:
       # mark_step needs to be blocking since we want to access args's XLADatas
       # and they can't be placeholder.
@@ -593,15 +628,16 @@ def optimized_mod(*args: tuple):
     else:
       skip_checking_input_sharding_threashold -= 1
 
-    if len(args_and_out) == 0:
+    if len_args_and_out == 0:
       return ()
 
     # graph input should be tensor only
     graph_input = graph_input_matcher(xla_args_tensor_only)
     res = torch_xla._XLAC._run_cached_graph(graph_hash, graph_input)
+    xm.wait_device_ops()
     res = special_return_handler.addDumbReturn(xla_args_tensor_only, res)
 
-    assert len(res) == len(args_and_out), f"{len(res)} v.s. {len(args_and_out)}"
+    assert len(res) == len_args_and_out, f"{len(res)} v.s. {len_args_and_out}"
     ncopy = 0
 
     for arg_index, res_index in arg_index_to_need_update_index.items():
@@ -611,14 +647,22 @@ def optimized_mod(*args: tuple):
     result = res[len(xla_args_need_update):]
     none_remover.add_nones(result)
-    if is_cuda_args:
-      result = _maybe_move_tensors_to_device(tuple(result), original_device)
+
+    # TODO: better fix this, input is not cuda tensor, output is cuda tensor
+    # if is_cuda_args:
+    original_device = torch.device(torch.cuda.current_device())
+    result = _maybe_move_tensors_to_device(tuple(result), original_device)
 
     if len(result) == 1:
       return result[0]
     else:
       return result
 
+  if hasattr(xla_model, 'xla_args'):
+    delattr(xla_model, 'xla_args')
+
+  torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
+
   if dynamo_debug:
     print(
         '=================== OpenXLA Dynamo Compile Debug End =====================\n'
@@ -652,6 +696,14 @@ def all_tensors_on_xla_device(value):
       # Not a tensor nor a container.
       return True
 
+    def have_any_tensor(value):
+      if isinstance(value, torch.Tensor):
+        return True
+      if isinstance(value, (list, tuple)):
+        return any(have_any_tensor(v) for v in value)
+      # Not a tensor nor a container.
+      return False
+
     # Check whether the current node is supported or not.
     #
     # A supported node has the following characteristics:
@@ -668,8 +720,14 @@ def all_tensors_on_xla_device(value):
     # If the current node is NOT supported, we add it to
     # the _unsupported_nodes list.
+    result_have_tensor = have_any_tensor(result)
+    args_have_tensor = any(
+        have_any_tensor(v)
+        for v in itertools.chain(args, kwargs.values()))
     if not (result_is_supported and args_are_supported):
       self._unsupported_nodes.append(n)
+    elif not (result_have_tensor or args_have_tensor):
+      self._unsupported_nodes.append(n)
 
     # Restore this metric counter
     torch_xla._XLAC._xla_increment_counter(
@@ -715,12 +773,47 @@ def allow_cpu_device(self, node: torch.fx.Node):
     device = node.kwargs.get("device")
     return (device is not None and device.type == self.target)
 
+  def move_cuda_to_xla(self, graph: torch.fx.Graph):
+    constructors = []
+    for node in graph.nodes:
+      device = node.kwargs.get("device")
+      if device is None or device.type != "cuda":
+        continue
+
+      constructors.append(node)
+
+    for node in constructors:
+      kwargs = node.kwargs.copy()
+      kwargs["device"] = self.target
+      node.kwargs = kwargs
+
+  def move_xla_to_cuda(self, graph: torch.fx.Graph):
+    constructors = []
+    for node in graph.nodes:
+      device = node.kwargs.get("device")
+      if device is None or device != self.target:
+        continue
+      constructors.append(node)
+
+    for node in constructors:
+      kwargs = node.kwargs.copy()
+      kwargs["device"] = "cuda"
+      node.kwargs = kwargs
+
+
+  def __call__(self, graph: torch.fx.Graph, move_xla_to_cuda=False) -> None:
+    if move_xla_to_cuda:
+      self.move_xla_to_cuda(graph)
+    else:
+      self.move_cuda_to_xla(graph)
+    super().__call__(graph)
+
 
 def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
   torch_xla._XLAC._xla_increment_counter('DynamoExtractCompiledGraph', 1)
-  with torch_xla.experimental.eager_mode_context(False):
-    return extract_compiled_graph_helper(xla_model, xla_args)
+  # with torch_xla.experimental.eager_mode_context(False):
+  return extract_compiled_graph_helper(xla_model, xla_args)
 
 
 def _clear_pending_irs_on_args(args_tensor_only, cloned_args):
@@ -791,6 +884,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
       node.replace_all_uses_with(new_node)
       partitioned_graph.graph.erase_node(node)
 
+  XLAConstructorMoverPass()(partitioned_graph.graph, move_xla_to_cuda=True)
   partitioned_graph.recompile()
 
   return partitioned_graph
diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp
index b6e0d23655a..b7fd66af73c 100644
--- a/torch_xla/csrc/dl_convertor.cpp
+++ b/torch_xla/csrc/dl_convertor.cpp
@@ -138,7 +138,7 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
     auto external_ref = pjrt_buffer->AcquireExternalReference();
     XLA_CHECK_OK(external_ref.status());
     pack->external_reference = std::move(external_ref.value());
-    XLA_CHECK_OK(pjrt_buffer->GetReadyFuture().Await());
+    // XLA_CHECK_OK(pjrt_buffer->GetReadyFuture().Await());
   }
 
   pack->buffer_reference = pjrt_buffer;
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h
index 8f607a51403..b5851e3c84b 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.h
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -110,7 +110,7 @@ class PjRtComputationClient : public ComputationClient {
             xla::PjRtLocalDeviceId(local_device_id));
     XLA_CHECK(pjrt_device.ok()) << "Failed to get a PjRt device.";
     absl::StatusOr<std::intptr_t> stream =
-        pjrt_device.value()->GetStreamForExternalReadyEvents();
+        pjrt_device.value()->GetLocalComputeStream();
     XLA_CHECK(stream.ok()) << "Failed to get a stream.";
     return stream.value();
   }
diff --git a/torch_xla/utils/dlpack.py b/torch_xla/utils/dlpack.py
index d66bafe749d..d4066d60fc8 100644
--- a/torch_xla/utils/dlpack.py
+++ b/torch_xla/utils/dlpack.py
@@ -15,7 +15,8 @@ def from_dlpack(ext_tensor: Any):
       ext_tensor, '__dlpack__'):
     device_type, device_id = ext_tensor.__dlpack_device__()
     if device_type == DLDeviceType.kDLGPU:
-      stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id)
+      # stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id)
+      stream = None
       dlpack = ext_tensor.__dlpack__(stream=stream)
     else:
       dlpack = ext_tensor.__dlpack__()
@@ -37,16 +38,16 @@ def from_xla_cuda_to_cuda(tensor):
   # https://github.com/pytorch/pytorch/blob/b0ef363972203b163cddc95e4c6054b8221c2300/torch/utils/dlpack.py#L114-L115
   # The array API specify that the default legacy stream must be passed
   # with a value of 1 for CUDA
-  device_id = tensor.device.index
-  stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id)
-  stream = 1 if stream == 0 else stream
-  assert stream is None or type(stream) is int
-  external_stream = torch.cuda.ExternalStream(stream)
-  current_stream = torch.cuda.current_stream()
-  if external_stream != current_stream:
-    event = torch.cuda.Event()
-    event.record(current_stream)
-    external_stream.wait_event(event)
+  # device_id = tensor.device.index
+  # stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id)
+  # stream = 1 if stream == 0 else stream
+  # assert stream is None or type(stream) is int
+  # external_stream = torch.cuda.ExternalStream(stream)
+  # current_stream = torch.cuda.current_stream()
+  # if external_stream != current_stream:
+  #   event = torch.cuda.Event()
+  #   event.record(current_stream)
+  #   external_stream.wait_event(event)
   dlpack = to_dlpack(tensor)
   cuda_tensor = torch.utils.dlpack.from_dlpack(dlpack)
   return cuda_tensor
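
Note (not part of the patch): the stream-ordering idiom this change hoists out of `from_xla_cuda_to_cuda` and into `_maybe_move_tensors_to_device` can be read in isolation. The sketch below is illustrative only, assuming a CUDA-enabled torch_xla build with ZERO_COPY_ENABLED set; the helper name `wait_on_pjrt_stream` is hypothetical. It wraps the PJRT CUDA stream in a torch.cuda.ExternalStream and makes it wait on an event recorded on the current PyTorch stream, so buffers produced on the PyTorch side are complete before the zero-copy DLPack exchange reads them.

    import torch
    import torch_xla


    def wait_on_pjrt_stream(device_id: int) -> None:
      # Hypothetical helper, mirroring the synchronization this patch adds to
      # _maybe_move_tensors_to_device. Look up the raw CUDA stream handle that
      # the PJRT runtime uses for this local device.
      stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id)
      # The array API reserves 0; the default legacy CUDA stream is passed as 1.
      stream = 1 if stream == 0 else stream
      external_stream = torch.cuda.ExternalStream(stream)
      current_stream = torch.cuda.current_stream()
      if external_stream != current_stream:
        # Record an event on the producing stream and make the PJRT stream wait
        # on it, ordering the DLPack consumer after pending PyTorch kernels.
        event = torch.cuda.Event()
        event.record(current_stream)
        external_stream.wait_event(event)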