
Commit

remove old code
mvpatel2000 committed Aug 9, 2024
1 parent 13ef790 commit 1b7ba47
Showing 8 changed files with 143 additions and 447 deletions.
8 changes: 1 addition & 7 deletions composer/callbacks/memory_snapshot.py
@@ -94,13 +94,7 @@ def __init__(
             _, _, self.remote_path_in_bucket = parse_uri(remote_file_name)
         else:
             self.remote_path_in_bucket = None

-        if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.1.0'):  # type: ignore
-            # MemorySnapshot is only supported in torch v2.1.0-rc1 or higher
-            self._enabled = True
-        else:
-            self._enabled = False
-            warnings.warn('Memory snapshot is supported after PyTorch 2.1.0. Skipping memory snapshot callback.')
+        self._enabled = True

     def init(self, state: State, logger: Logger) -> None:
         if not self._enabled:
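For background, a callback like this wraps PyTorch's memory snapshot recording, which only exists in torch 2.1.0+ (hence the version guard being dropped above). A minimal standalone sketch of that underlying API, with a placeholder workload and assuming a CUDA device is available:

import torch

# Start recording allocator events (available in PyTorch 2.1.0+).
torch.cuda.memory._record_memory_history(max_entries=100_000)

# Placeholder workload whose allocations we want to inspect.
x = torch.randn(4096, 4096, device='cuda')
y = x @ x

# Dump a snapshot viewable at https://pytorch.org/memory_viz, then stop recording.
torch.cuda.memory._dump_snapshot('memory_snapshot.pickle')
torch.cuda.memory._record_memory_history(enabled=None)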
8 changes: 1 addition & 7 deletions composer/callbacks/oom_observer.py
@@ -113,13 +113,7 @@ def __init__(
         else:
             self.remote_path_in_bucket = None

-        if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.1.0'):  # type: ignore
-            # OOMObserver is only supported in torch v2.1.0 or higher
-            self._enabled = True
-        else:
-            self._enabled = False
-            warnings.warn('OOMObserver is supported after PyTorch 2.1.0. Disabling OOMObserver callback.')
-
+        self._enabled = True
         self.filename_config: Optional[SnapshotFileNameConfig] = None

     def init(self, state: State, logger: Logger) -> None:
333 changes: 104 additions & 229 deletions composer/distributed/dist_strategy.py

Large diffs are not rendered by default.

6 changes: 2 additions & 4 deletions composer/distributed/mosaic_parallelism.py
@@ -27,12 +27,10 @@
     'NO_SHARD': ShardingStrategy.NO_SHARD,
     'SHARD_GRAD_OP': ShardingStrategy.SHARD_GRAD_OP,
     'FULL_SHARD': ShardingStrategy.FULL_SHARD,
+    '_HYBRID_SHARD_ZERO2': ShardingStrategy._HYBRID_SHARD_ZERO2,
+    'HYBRID_SHARD': ShardingStrategy.HYBRID_SHARD,
 }

-if version.parse(torch.__version__) >= version.parse('2.1.0'):
-    SHARDING_MAP['_HYBRID_SHARD_ZERO2'] = ShardingStrategy._HYBRID_SHARD_ZERO2
-    SHARDING_MAP['HYBRID_SHARD'] = ShardingStrategy.HYBRID_SHARD
-
 BACKWARD_PREFETCH_MAP = {
     'NONE': None,
     'BACKWARD_PRE': BackwardPrefetch.BACKWARD_PRE,
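As context for the map above, a minimal sketch of how a string-keyed table like SHARDING_MAP typically resolves a user-facing config value into an FSDP ShardingStrategy; the get_sharding_strategy helper below is illustrative, not Composer's actual API:

from torch.distributed.fsdp import ShardingStrategy

SHARDING_MAP = {
    'NO_SHARD': ShardingStrategy.NO_SHARD,
    'SHARD_GRAD_OP': ShardingStrategy.SHARD_GRAD_OP,
    'FULL_SHARD': ShardingStrategy.FULL_SHARD,
    '_HYBRID_SHARD_ZERO2': ShardingStrategy._HYBRID_SHARD_ZERO2,
    'HYBRID_SHARD': ShardingStrategy.HYBRID_SHARD,
}

def get_sharding_strategy(name: str) -> ShardingStrategy:
    """Resolve a config string (e.g. 'FULL_SHARD') to the FSDP enum, with a clear error."""
    try:
        return SHARDING_MAP[name.upper()]
    except KeyError:
        raise ValueError(f'Unknown sharding strategy {name!r}; expected one of {sorted(SHARDING_MAP)}') from None

# Example: strategy = get_sharding_strategy('hybrid_shard')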
68 changes: 32 additions & 36 deletions composer/profiler/torch_profiler.py
@@ -27,6 +27,7 @@
     format_name_with_dist,
     format_name_with_dist_and_time,
 )
+from composer.profiler.utils import export_memory_timeline_html

 if TYPE_CHECKING:
     from composer.core import State

@@ -296,44 +297,39 @@ def handler_fn(prof: torch.profiler.profiler.profile):
                 f'PyTorch memory timeline profiler enabled: {self.memory_filename if self.memory_filename else False}',
             )
             if self.memory_filename is not None:
-                if version.parse(torch.__version__) > version.parse('2.1.0.dev'):  # type: ignore
-                    # memory timeline profiling is only supported in torch v2.1.0-rc1 or higher
-                    memory_trace_file_name = os.path.join(
-                        folder_name,
-                        format_name_with_dist_and_time(
-                            self.memory_filename,
-                            run_name=state.run_name,
-                            timestamp=timestamp,
-                        ),
-                    )
-                    log.debug(f'Saving memory trace to {memory_trace_file_name}')
-                    memory_trace_file_dirname = os.path.dirname(memory_trace_file_name)
-                    if memory_trace_file_dirname:
-                        os.makedirs(memory_trace_file_dirname, exist_ok=True)
-                    from composer.profiler.utils import export_memory_timeline_html
-                    export_memory_timeline_html(
-                        prof,
-                        memory_trace_file_name,
-                        torch.cuda.current_device(),  # type: ignore
-                    )
-                    log.debug(f'Uploaded memory trace to {self.memory_remote_file_name}')
-                    if self.memory_remote_file_name is not None:
-                        memory_trace_remote_file_name = format_name_with_dist_and_time(
-                            self.memory_remote_file_name,
-                            run_name=state.run_name,
-                            timestamp=timestamp,
-                        )
-                        memory_trace_remote_file_name = memory_trace_remote_file_name.lstrip('/')
-                        log.debug(
-                            f'Uploading memory trace to {memory_trace_remote_file_name} from {memory_trace_file_name}',
-                        )
-                        logger.upload_file(
-                            remote_file_name=memory_trace_remote_file_name,
-                            file_path=memory_trace_file_name,
-                            overwrite=self.overwrite,
-                        )
-                else:
-                    log.warning('Memory timeline is supported after PyTorch 2.1.0. Skipping memory trace.')
+                memory_trace_file_name = os.path.join(
+                    folder_name,
+                    format_name_with_dist_and_time(
+                        self.memory_filename,
+                        run_name=state.run_name,
+                        timestamp=timestamp,
+                    ),
+                )
+                log.debug(f'Saving memory trace to {memory_trace_file_name}')
+                memory_trace_file_dirname = os.path.dirname(memory_trace_file_name)
+                if memory_trace_file_dirname:
+                    os.makedirs(memory_trace_file_dirname, exist_ok=True)
+                export_memory_timeline_html(
+                    prof,
+                    memory_trace_file_name,
+                    torch.cuda.current_device(),  # type: ignore
+                )
+                log.debug(f'Uploaded memory trace to {self.memory_remote_file_name}')
+                if self.memory_remote_file_name is not None:
+                    memory_trace_remote_file_name = format_name_with_dist_and_time(
+                        self.memory_remote_file_name,
+                        run_name=state.run_name,
+                        timestamp=timestamp,
+                    )
+                    memory_trace_remote_file_name = memory_trace_remote_file_name.lstrip('/')
+                    log.debug(
+                        f'Uploading memory trace to {memory_trace_remote_file_name} from {memory_trace_file_name}',
+                    )
+                    logger.upload_file(
+                        remote_file_name=memory_trace_remote_file_name,
+                        file_path=memory_trace_file_name,
+                        overwrite=self.overwrite,
+                    )

             if self.num_traces_to_keep >= 0:
                 while len(self.saved_traces) > self.num_traces_to_keep:
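For comparison, a minimal sketch of producing a memory timeline with stock torch.profiler (PyTorch 2.1.0+), using torch's built-in export_memory_timeline rather than Composer's custom export_memory_timeline_html; the model and batch are placeholders and a CUDA device is assumed:

import torch
from torch import nn
from torch.profiler import ProfilerActivity, profile

model = nn.Linear(1024, 1024).cuda()
batch = torch.randn(64, 1024, device='cuda')

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,   # needed to build the memory timeline
    record_shapes=True,
    with_stack=True,
) as prof:
    model(batch).sum().backward()

# Writes an HTML memory timeline for the given device.
prof.export_memory_timeline('memory_timeline.html', device='cuda:0')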
7 changes: 2 additions & 5 deletions composer/profiler/utils.py
@@ -10,6 +10,8 @@
 from tempfile import NamedTemporaryFile
 from typing import Any, Optional, Union

+from torch.profiler._memory_profiler import _CATEGORY_TO_COLORS, _CATEGORY_TO_INDEX, MemoryProfileTimeline
+
 import numpy as np
 import torch
 import torch.cuda
@@ -29,11 +31,6 @@ def export_memory_timeline_html(
     return_fig: bool = False,
 ) -> Optional[Union[None, Any]]:
     """Exports a memory timeline to an HTML file. Similar to the PyTorch plotting function, but with adjusted axis tickers and grids."""
-    if version.parse(torch.__version__) <= version.parse('2.1.0.dev'):
-        log.warning('export_memory_timeline_html failed because memory timeline is supported after PyTorch 2.1.0.')
-        return
-
-    from torch.profiler._memory_profiler import _CATEGORY_TO_COLORS, _CATEGORY_TO_INDEX, MemoryProfileTimeline

     # Default to device 0, if unset. Fallback on cpu.
     if device is None and prof.use_device and prof.use_device != 'cuda':
158 changes: 1 addition & 157 deletions composer/trainer/_patch_pytorch.py
@@ -47,29 +47,7 @@ def patch_unshard_for_automicrobatching(auto_microbatch_size_found=False):

 def patch_pytorch():
     """Monkey patches pytorch functions based on pytorch version."""
-    if version.parse(torch.__version__) < version.parse('2.1.1'):
-        # Monkey patch for torch < 2.1.1 ie torch == 2.1.0
-
-        # Monkey patch sharding method
-        ChunkShardingSpec.build_metadata = build_metadata
-
-        # Monkey patch partial state dict handling
-        from torch.distributed.fsdp import _state_dict_utils
-
-        _state_dict_utils._sharded_pre_load_state_dict_hook = (_sharded_pre_load_state_dict_hook)
-
-        # Allow 2D HSDP
-        from torch.distributed.fsdp import _runtime_utils
-        _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None
-
-    elif version.parse(torch.__version__) < version.parse('2.1.3'):
-        # Monkey patch for torch < 2.1.3 ie torch == 2.1.1, 2.1.2
-
-        # Allow 2D HSDP
-        from torch.distributed.fsdp import _runtime_utils
-        _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None
-
-    elif version.parse(torch.__version__) < version.parse('2.2.1'):
+    if version.parse(torch.__version__) < version.parse('2.2.1'):
         # Monkey patch for torch < 2.2.1 ie torch == 2.2.0

         # Allow 2D HSDP
@@ -140,140 +118,6 @@ def patch_pytorch():
         pass


-def build_metadata(
-    self,
-    tensor_sizes: torch.Size,
-    tensor_properties: sharded_tensor_meta.TensorProperties,
-) -> sharded_tensor_meta.ShardedTensorMetadata:
-    """Adds nightly change for ChunkShardingSpec.
-    Change implemented in https://github.com/pytorch/pytorch/pull/108915
-    """
-    tensor_num_dim = len(tensor_sizes)
-
-    self._verify_dim(self.dim)
-    if self.dim >= tensor_num_dim or self.dim < -tensor_num_dim:  # type: ignore[operator]
-        raise ValueError(f'Invalid sharding dim: {self.dim}')
-
-    shards_metadata = []
-    sharding_dim_size = tensor_sizes[self.dim]  # type: ignore[index]
-    chunks = len(self.placements)
-    split_size = get_split_size(sharding_dim_size, chunks)
-    for idx, placement in enumerate(self.placements):
-        # generate ShardMetadata for each placement device
-        chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
-        shard_size = list(tensor_sizes)
-        current_offsets = [0] * tensor_num_dim
-        current_offsets[self.dim] = split_size * idx  # type: ignore[index]
-        shard_size[self.dim] = chunked_dim_size  # type: ignore[index]
-
-        shard_metadata = ShardMetadata(
-            shard_offsets=current_offsets,
-            shard_sizes=shard_size,
-            placement=placement,
-        )
-        shards_metadata.append(shard_metadata)
-
-    return sharded_tensor_meta.ShardedTensorMetadata(shards_metadata, tensor_sizes, tensor_properties)
-
-
-@no_type_check
-def _sharded_pre_load_state_dict_hook(
-    module: nn.Module,
-    fsdp_state,
-    state_dict: dict[str, Any],
-    prefix: str,
-) -> None:
-    """Adds nightly change for partial state dict error handling.
-    https://github.com/pytorch/pytorch/blob/0511df0ee9edeb5c2613805ccfb49beb323b87f9/torch/distributed/fsdp/_state_dict_utils.py#L607-L615
-    The hook combines the unflattened, sharded parameters (ShardedTensor) to
-    a new FlatParameter and shards the new FlatParameter to the local chunk.
-    """
-    from torch.distributed._tensor import Replicate
-    from torch.distributed.distributed_c10d import _get_pg_default_device
-    from torch.distributed.fsdp._common_utils import FSDP_PREFIX, _has_fsdp_params, _is_composable, _module_handle
-    from torch.distributed.fsdp._runtime_utils import _lazy_init
-    from torch.distributed.fsdp._state_dict_utils import _enter_unshard_params_ctx, _param_name_infos
-
-    _lazy_init(fsdp_state, module)
-    if not _is_composable(fsdp_state):
-        _replace_by_prefix(state_dict, prefix, prefix + f'{FSDP_PREFIX}')
-    if not _has_fsdp_params(fsdp_state, module):
-        return
-
-    handle = _module_handle(fsdp_state, module)
-    if not handle.uses_sharded_strategy:  # type: ignore
-        raise RuntimeError(
-            'load_sharded_state_dict can only be called when parameters '
-            'are flattened and sharded.',
-        )
-
-    device = fsdp_state.compute_device
-    for fqn, _, _ in _param_name_infos(module, fsdp_state):
-        if not _is_composable(fsdp_state):
-            fqn_from_global_root = f'{prefix}{FSDP_PREFIX}{fqn}'
-        else:
-            fqn_from_global_root = f'{prefix}{fqn}'
-        try:
-            param = state_dict.pop(fqn_from_global_root)
-        except KeyError:
-            log.warning(
-                f'Did not find param with FQN {fqn_from_global_root}, skipping it. '  # noqa: G004
-                'The weight will not be filled if you expect it to be.',
-            )
-            continue  # TODO: Improve unittesting for state_dict finetuning
-            # cases: https://github.com/pytorch/pytorch/issues/109134
-
-        if not fsdp_state._state_dict_config.use_dtensor:
-            # All-gather the param (ShardedTensor)
-            param, shards = _ext_pre_load_state_dict_transform(param)
-
-            assert len(shards) < 2, (
-                'Expects 0 or 1 shard per rank '
-                f'but got {len(shards)} shards on rank {fsdp_state.rank}.'
-            )
-            param_numel = param.size().numel()
-            dim_0_size = param.size()[0]
-            chunk_size = (math.ceil(dim_0_size / fsdp_state.world_size) * param_numel // dim_0_size)
-            if len(shards) == 1:
-                local_tensor = shards[0].tensor.flatten()
-                pg_device = _get_pg_default_device(fsdp_state.process_group)
-                if local_tensor.device.type != pg_device.type:
-                    local_tensor = local_tensor.to(pg_device)
-                num_padding = chunk_size - local_tensor.numel()
-                if num_padding > 0:
-                    local_tensor = F.pad(local_tensor, [0, num_padding])
-            else:
-                local_tensor = torch.zeros(chunk_size, dtype=param.dtype, device=device)
-            tensor = torch.empty(
-                chunk_size * fsdp_state.world_size,
-                dtype=local_tensor.dtype,
-                device=device,
-            )
-            if local_tensor.is_cpu:
-                # Tensor could be on FSDP GPU compute device, while local_tensor is on CPU.
-                # Convert to CPU so all_gather can work.
-                tensor_dev = tensor.device
-                tensor = tensor.cpu()
-                tensor_list = list(torch.chunk(tensor, torch.distributed.get_world_size(fsdp_state.process_group)))
-                torch.distributed.all_gather(tensor_list, local_tensor, group=fsdp_state.process_group)
-                tensor.to(tensor_dev)
-            else:
-                torch.distributed.all_gather_into_tensor(tensor, local_tensor, group=fsdp_state.process_group)
-            tensor = tensor.narrow(0, 0, param_numel).reshape(param.size())
-            state_dict[fqn_from_global_root] = tensor
-        else:
-            if param.device != fsdp_state._device_mesh.device_type:  # type: ignore
-                param = param.to(fsdp_state._device_mesh.device_type)  # type: ignore
-
-            param = param.redistribute(device_mesh=param.device_mesh, placements=[Replicate()])
-            state_dict[fqn_from_global_root] = param.to_local()
-
-    _enter_unshard_params_ctx(module, fsdp_state, writeback=True)


 if version.parse(torch.__version__) >= version.parse('2.2.1') and version.parse(
     torch.__version__,) < version.parse('2.2.3'):

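For context, the surviving branches above all follow the same version-gated monkey-patching pattern. A schematic sketch of that pattern, where the patched helper is a hypothetical stand-in rather than one of Composer's real patches:

import torch
from packaging import version


def _patched_helper(*args, **kwargs):
    # Hypothetical replacement implementation.
    return None


def apply_patches():
    """Swap library internals only for the torch versions that need the fix."""
    if version.parse(torch.__version__) < version.parse('2.2.1'):
        # Applies to torch == 2.2.0 only; later versions ship the fix upstream.
        from torch.distributed.fsdp import _runtime_utils
        _runtime_utils._validate_and_get_hybrid_shard_state = _patched_helper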
2 changes: 0 additions & 2 deletions composer/utils/dist.py
@@ -579,8 +579,6 @@ def initialize_dist(device: Union[str, Device], timeout: float = 300.0) -> None:
                 'PyTorch XLA package not found. In order to use XLA based devices '
                 'PyTorch XLA must be installed.',
             )
-        if version.parse(torch_xla.__version__) < version.parse('2.1.0'):
-            raise RuntimeError(f'PyTorch XLA version must be at least 2.1.0, found {torch_xla.__version__}.')
         # XLA initialization requires the init_method to be set
         dist.init_process_group(device_obj.dist_backend, init_method='xla://')
     elif dist_env_vars_match_defaults:
