[WIP] Optimize the sync overhead of DLPack and fix some bugs of dynamo+xla #27

Open · wants to merge 3 commits into base: acc
156 changes: 125 additions & 31 deletions torch_xla/_dynamo/dynamo_bridge.py
@@ -149,6 +149,8 @@ def _maybe_move_tensors_to_device(tensors: tuple,
target_device: torch.device) -> tuple:
assert target_device, "Moving tensors to None device not supported"

device_id = None

moved_tensors = []
for tensor in tensors:
if not isinstance(tensor, torch.Tensor):
@@ -159,17 +161,19 @@ def _maybe_move_tensors_to_device(tensors: tuple,
moved_tensors.append(tensor)
continue

if dynamo_debug:
print("Moving Tensor {} to device {}".format(tensor, target_device))
# if dynamo_debug:
# print("Moving Tensor {} to device {}".format(tensor, target_device))

zero_copy_enabled = xu.getenv_as(xenv.ZERO_COPY_ENABLED, bool, defval=False)
if zero_copy_enabled and tensor.device.type == 'cuda' and target_device.type == 'xla':
# If the input CUDA tensor requires gradient, we need to call detach(); otherwise we get "RuntimeError: Can't export tensors that require gradient, use tensor.detach()"
device_type, device_id = tensor.__dlpack_device__()
moved_tensor = torch_xla_dlpack.from_dlpack(tensor.detach())
elif zero_copy_enabled and tensor.device.type == 'xla' and target_device.type == 'cuda':
# mark_step is needed to make sure the PJRT buffer is valid.
xm.mark_step()
moved_tensor = torch_xla_dlpack.from_xla_cuda_to_cuda(tensor)
# HACK: The `torch_xla._XLAC._get_stream_for_cuda_device` requires a local device index, while the device index for xla tensors is always 0.
# Meanwhile, dlpack uses the actual device index, so we use the device index of the converted CUDA tensor.
device_id = moved_tensor.device.index
else:
# Have to move to CPU before moving it to target device.
cpu_device: torch.device = torch.device("cpu")
@@ -181,6 +185,17 @@ def _maybe_move_tensors_to_device(tensors: tuple,
moved_tensor.requires_grad = tensor.requires_grad
moved_tensors.append(moved_tensor)

if zero_copy_enabled and device_id is not None:
stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id)
stream = 1 if stream == 0 else stream
assert stream is None or type(stream) is int
external_stream = torch.cuda.ExternalStream(stream)
current_stream = torch.cuda.current_stream()
if external_stream != current_stream:
event = torch.cuda.Event()
event.record(current_stream)
external_stream.wait_event(event)

return tuple(moved_tensors)
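
For reference, a minimal self-contained sketch of the DLPack hand-off and the event-based stream synchronization this hunk implements; the helper name `cuda_to_xla_zero_copy` and the tensor names are illustrative, not part of the PR:

    import torch
    import torch_xla
    import torch_xla.utils.dlpack as torch_xla_dlpack

    def cuda_to_xla_zero_copy(t: torch.Tensor) -> torch.Tensor:
        # Import the CUDA tensor into XLA through DLPack without a host copy;
        # detach() because DLPack refuses tensors that require grad.
        xla_t = torch_xla_dlpack.from_dlpack(t.detach())
        # Make the XLA compute stream wait for the producing CUDA stream so the
        # shared buffer is not read before the kernel that wrote `t` finishes.
        producer = torch.cuda.current_stream(t.device)
        handle = torch_xla._XLAC._get_stream_for_cuda_device(t.device.index)
        consumer = torch.cuda.ExternalStream(1 if handle == 0 else handle)
        if consumer != producer:
            event = torch.cuda.Event()
            event.record(producer)
            consumer.wait_event(event)
        return xla_t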


@@ -253,17 +268,15 @@ class SpecialReturnHandler:

def __init__(self, trace_inputs, trace_outputs,
trace_inputs_inplace_update_bool, constant_outputs_and_indexes):
self.trace_inputs = trace_inputs
self.trace_outputs = trace_outputs
self.constant_outputs_and_indexes = constant_outputs_and_indexes

# dedup the traced outputs first
self.deduper = Deduper()
self.deduped_trace_outputs = self.deduper.dedup(self.trace_outputs)
self.deduped_trace_outputs = self.deduper.dedup(trace_outputs)

# record the output that is also a input
trace_inputs_id2pos = {
id(x): pos for pos, x in enumerate(self.trace_inputs)
id(x): pos for pos, x in enumerate(trace_inputs)
}
self.trace_outputs_pos_to_inputs_pos = []
for out_pos, out in enumerate(self.deduped_trace_outputs):
@@ -466,11 +479,14 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule,

with alias_with_buffer_donor_config() as saved_config:
# calculate graph hash
graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out_tensor_only)
if len(args_and_out_tensor_only) == 0:
graph_hash = None
else:
graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out_tensor_only)
# compiles and cache graph rooted at tensors in 'args_and_out_tensor_only'
torch_xla._XLAC._xla_warm_up_cache(args_and_out_tensor_only, [])
if dynamo_debug:
print("Graph Hash: ", graph_hash)
# compiles and cache graph rooted at tensors in 'args_and_out_tensor_only'
torch_xla._XLAC._xla_warm_up_cache(args_and_out_tensor_only, [])

# Restore the original `xla_args`. Dynamo passed the real tensors as
# `xla_args` and we performed the tracing on them. During the tracing,
@@ -490,10 +506,17 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule,
# mistakenly update the input tensors.
torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))

vars_to_return = (xla_args_sharding_spec, args_and_out, graph_hash,
xla_args_dtype = []
for arg in xla_args:
if isinstance(arg, torch.Tensor):
xla_args_dtype.append(arg.dtype)
else:
xla_args_dtype.append(None)

vars_to_return = (xla_args_sharding_spec, len(args_and_out), graph_hash,
arg_index_to_need_update_index, none_remover,
graph_input_matcher, special_return_handler,
xla_args_need_update)
xla_args_need_update, xla_args_dtype)
# populate the cache
sym_constants_to_graph_vars[sym_constants] = vars_to_return
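
A minimal sketch of the dtype bookkeeping added above, with hypothetical helper names: record each traced argument's dtype once at trace time, then cast any runtime argument that arrives with a different dtype before the cached graph runs.

    from typing import List, Optional
    import torch

    def record_arg_dtypes(args) -> List[Optional[torch.dtype]]:
        # One entry per positional arg; None marks non-tensor args (e.g. symbolic scalars).
        return [a.dtype if isinstance(a, torch.Tensor) else None for a in args]

    def cast_args_to_recorded_dtypes(args, dtypes):
        out = list(args)
        for i, dt in enumerate(dtypes):
            if isinstance(out[i], torch.Tensor) and dt is not None and out[i].dtype != dt:
                # Cast so the cached XLA graph sees the dtypes it was traced with.
                out[i] = out[i].to(dt)
        return out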

@@ -523,10 +546,10 @@ def extract_internal(xla_model: torch.fx.GraphModule):
sym_constants_to_graph_vars: Dict[Tuple[Union[int, float], ...],
Tuple[Any, ...]] = {}

(xla_args_sharding_spec, args_and_out, graph_hash,
(xla_args_sharding_spec, len_args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
special_return_handler,
xla_args_need_update) = extract_graph_helper(xla_model,
xla_args_need_update, xla_args_dtype) = extract_graph_helper(xla_model,
sym_constants_to_graph_vars)
skip_checking_input_sharding_threashold = xu.getenv_as(
'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)
@@ -535,30 +558,42 @@ def optimized_mod(*args: tuple):
nonlocal xla_model
nonlocal skip_checking_input_sharding_threashold
nonlocal sym_constants_to_graph_vars
nonlocal graph_hash

if graph_hash is None:
return xla_model(*args)

original_device: torch.device = _get_input_arg_device(args)
is_cuda_args: bool = False
if original_device:
is_cuda_args = original_device.type == "cuda"


# See [Note: Dynamo real-time input-shape cache look-up] above.
xla_args_tensor_only, sym_constants = _split_xla_args_tensor_sym_constant(
args)
if sym_constants in sym_constants_to_graph_vars:
(xla_args_sharding_spec, args_and_out, graph_hash,
(xla_args_sharding_spec, len_args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
special_return_handler,
xla_args_need_update) = sym_constants_to_graph_vars[sym_constants]
xla_args_need_update, xla_args_dtype) = sym_constants_to_graph_vars[sym_constants]
else:
xla_model.xla_args = args
(xla_args_sharding_spec, args_and_out, graph_hash,
(xla_args_sharding_spec, len_args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
special_return_handler, xla_args_need_update) = extract_graph_helper(
special_return_handler, xla_args_need_update, xla_args_dtype) = extract_graph_helper(
xla_model, sym_constants_to_graph_vars)
if hasattr(xla_model, 'xla_args'):
delattr(xla_model, 'xla_args')

original_device: torch.device = _get_input_arg_device(args)
is_cuda_args: bool = False
if original_device:
is_cuda_args = original_device.type == "cuda"

args = list(args)
for index, arg in enumerate(args):
if isinstance(arg, torch.Tensor) and arg.dtype != xla_args_dtype[index]:
args[index] = arg.to(xla_args_dtype[index])
if is_cuda_args:
args = _maybe_move_tensors_to_device(args, xm.xla_device())

xla_args_tensor_only, sym_constants = _split_xla_args_tensor_sym_constant(
args)
if not config.skip_input_data_check:
# mark_step needs to be blocking since we want to access args's XLADatas
# and they can't be placeholder.
@@ -593,15 +628,16 @@ def optimized_mod(*args: tuple):
else:
skip_checking_input_sharding_threashold -= 1

if len(args_and_out) == 0:
if len_args_and_out == 0:
return ()

# graph input should be tensor only
graph_input = graph_input_matcher(xla_args_tensor_only)
res = torch_xla._XLAC._run_cached_graph(graph_hash, graph_input)
xm.wait_device_ops()
res = special_return_handler.addDumbReturn(xla_args_tensor_only, res)

assert len(res) == len(args_and_out), f"{len(res)} v.s. {len(args_and_out)}"
assert len(res) == len_args_and_out, f"{len(res)} v.s. {len_args_and_out}"
ncopy = 0

for arg_index, res_index in arg_index_to_need_update_index.items():
Expand All @@ -611,14 +647,22 @@ def optimized_mod(*args: tuple):
result = res[len(xla_args_need_update):]

none_remover.add_nones(result)
if is_cuda_args:
result = _maybe_move_tensors_to_device(tuple(result), original_device)

# TODO: fix this properly; the input may not be a CUDA tensor even though the output is a CUDA tensor
# if is_cuda_args:
original_device = torch.device(torch.cuda.current_device())
result = _maybe_move_tensors_to_device(tuple(result), original_device)

if len(result) == 1:
return result[0]
else:
return result

if hasattr(xla_model, 'xla_args'):
delattr(xla_model, 'xla_args')

torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))

if dynamo_debug:
print(
'=================== OpenXLA Dynamo Compile Debug End =====================\n'
@@ -652,6 +696,14 @@ def all_tensors_on_xla_device(value):
# Not a tensor nor a container.
return True

def have_any_tensor(value):
if isinstance(value, torch.Tensor):
return True
if isinstance(value, (list, tuple)):
return any(have_any_tensor(v) for v in value)
# Not a tensor nor a container.
return False
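# For intuition, illustrative values only (not part of this change):
#   have_any_tensor(3.14)                      -> False
#   have_any_tensor([1, (2, torch.ones(2))])   -> True, a tensor is nested inside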

# Check whether the current node is supported or not.
#
# A supported node has the following characteristics:
Expand All @@ -668,8 +720,14 @@ def all_tensors_on_xla_device(value):

# If the current node is NOT supported, we add it to
# the _unsupported_nodes list.
result_have_tensor = have_any_tensor(result)
args_have_tensor = any(
have_any_tensor(v)
for v in itertools.chain(args, kwargs.values()))
if not (result_is_supported and args_are_supported):
self._unsupported_nodes.append(n)
elif not (result_have_tensor or args_have_tensor):
self._unsupported_nodes.append(n)

# Restore this metric counter
torch_xla._XLAC._xla_increment_counter(
@@ -715,12 +773,47 @@ def allow_cpu_device(self, node: torch.fx.Node):
device = node.kwargs.get("device")
return (device is not None and device.type == self.target)

def move_cuda_to_xla(self, graph: torch.fx.Graph):
constructors = []
for node in graph.nodes:
device = node.kwargs.get("device")
if device is None or device.type != "cuda":
continue

constructors.append(node)

for node in constructors:
kwargs = node.kwargs.copy()
kwargs["device"] = self.target
node.kwargs = kwargs

def move_xla_to_cuda(self, graph: torch.fx.Graph):
constructors = []
for node in graph.nodes:
device = node.kwargs.get("device")
if device is None or device != self.target:
continue
constructors.append(node)

for node in constructors:
kwargs = node.kwargs.copy()
kwargs["device"] = "cuda"
node.kwargs = kwargs


def __call__(self, graph: torch.fx.Graph, move_xla_to_cuda=False) -> None:
if move_xla_to_cuda:
self.move_xla_to_cuda(graph)
else:
self.move_cuda_to_xla(graph)
super().__call__(graph)
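
A standalone sketch of the device-kwarg rewrite that the move_cuda_to_xla / move_xla_to_cuda passes above perform on an FX graph; the toy module and the helper name are illustrative, and nothing here needs torch_xla at trace time:

    import torch
    import torch.fx

    def retarget_constructor_devices(graph: torch.fx.Graph, new_device) -> None:
        # Rewrite the `device=` kwarg of constructor-style nodes (torch.ones, torch.zeros, ...).
        for node in graph.nodes:
            if node.kwargs.get("device") is None:
                continue
            kwargs = dict(node.kwargs)
            kwargs["device"] = new_device
            node.kwargs = kwargs

    class Toy(torch.nn.Module):
        def forward(self, x):
            # Using `x.shape` keeps the constructor symbolic, so nothing is allocated on CUDA at trace time.
            return x + torch.ones(x.shape, device="cuda")

    gm = torch.fx.symbolic_trace(Toy())
    retarget_constructor_devices(gm.graph, "xla")  # constructors now target the XLA device
    gm.recompile()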


def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
torch_xla._XLAC._xla_increment_counter('DynamoExtractCompiledGraph', 1)

with torch_xla.experimental.eager_mode_context(False):
return extract_compiled_graph_helper(xla_model, xla_args)
# with torch_xla.experimental.eager_mode_context(False):
return extract_compiled_graph_helper(xla_model, xla_args)


def _clear_pending_irs_on_args(args_tensor_only, cloned_args):
@@ -791,6 +884,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
node.replace_all_uses_with(new_node)
partitioned_graph.graph.erase_node(node)

XLAConstructorMoverPass()(partitioned_graph.graph, move_xla_to_cuda=True)
partitioned_graph.recompile()

return partitioned_graph
2 changes: 1 addition & 1 deletion torch_xla/csrc/dl_convertor.cpp
@@ -138,7 +138,7 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
auto external_ref = pjrt_buffer->AcquireExternalReference();
XLA_CHECK_OK(external_ref.status());
pack->external_reference = std::move(external_ref.value());
XLA_CHECK_OK(pjrt_buffer->GetReadyFuture().Await());
// XLA_CHECK_OK(pjrt_buffer->GetReadyFuture().Await());
}
pack->buffer_reference = pjrt_buffer;

2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -110,7 +110,7 @@ class PjRtComputationClient : public ComputationClient {
xla::PjRtLocalDeviceId(local_device_id));
XLA_CHECK(pjrt_device.ok()) << "Failed to get a PjRt device.";
absl::StatusOr<std::intptr_t> stream =
pjrt_device.value()->GetStreamForExternalReadyEvents();
pjrt_device.value()->GetLocalComputeStream();
XLA_CHECK(stream.ok()) << "Failed to get a stream.";
return stream.value();
}
23 changes: 12 additions & 11 deletions torch_xla/utils/dlpack.py
@@ -15,7 +15,8 @@ def from_dlpack(ext_tensor: Any):
ext_tensor, '__dlpack__'):
device_type, device_id = ext_tensor.__dlpack_device__()
if device_type == DLDeviceType.kDLGPU:
stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id)
# stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id)
stream = None
dlpack = ext_tensor.__dlpack__(stream=stream)
else:
dlpack = ext_tensor.__dlpack__()
@@ -37,16 +38,16 @@ def from_xla_cuda_to_cuda(tensor):
# https://github.com/pytorch/pytorch/blob/b0ef363972203b163cddc95e4c6054b8221c2300/torch/utils/dlpack.py#L114-L115
# The array API specifies that the default legacy stream must be passed
# with a value of 1 for CUDA
device_id = tensor.device.index
stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id)
stream = 1 if stream == 0 else stream
assert stream is None or type(stream) is int
external_stream = torch.cuda.ExternalStream(stream)
current_stream = torch.cuda.current_stream()
if external_stream != current_stream:
event = torch.cuda.Event()
event.record(current_stream)
external_stream.wait_event(event)
# device_id = tensor.device.index
# stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id)
# stream = 1 if stream == 0 else stream
# assert stream is None or type(stream) is int
# external_stream = torch.cuda.ExternalStream(stream)
# current_stream = torch.cuda.current_stream()
# if external_stream != current_stream:
# event = torch.cuda.Event()
# event.record(current_stream)
# external_stream.wait_event(event)
dlpack = to_dlpack(tensor)
cuda_tensor = torch.utils.dlpack.from_dlpack(dlpack)
return cuda_tensor
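
A typical round trip these helpers enable, as a sketch; it assumes a CUDA-enabled torch_xla build:

    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.utils.dlpack as torch_xla_dlpack

    cuda_t = torch.arange(4, dtype=torch.float32, device="cuda")
    xla_t = torch_xla_dlpack.from_dlpack(cuda_t)          # CUDA -> XLA, no host copy
    xm.mark_step()                                        # make sure the PJRT buffer is valid
    back = torch_xla_dlpack.from_xla_cuda_to_cuda(xla_t)  # XLA:CUDA -> torch CUDA
    assert back.device.type == "cuda"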