Commit: lint

mvpatel2000 committed Aug 9, 2024
1 parent 1b7ba47 commit 6da77c1
Showing 8 changed files with 7 additions and 30 deletions.
1 change: 0 additions & 1 deletion composer/callbacks/memory_snapshot.py
@@ -9,7 +9,6 @@
from typing import Optional, Union

import torch.cuda
from packaging import version

from composer import State
from composer.core import Callback, State, Time, TimeUnit
1 change: 0 additions & 1 deletion composer/callbacks/oom_observer.py
@@ -14,7 +14,6 @@
from typing import Optional

import torch.cuda
from packaging import version

from composer.core import Callback, State
from composer.loggers import Logger
12 changes: 4 additions & 8 deletions composer/distributed/dist_strategy.py
@@ -3,7 +3,6 @@

"""Helpers for running distributed data parallel training."""

import collections
import logging
import warnings
from contextlib import contextmanager, nullcontext
@@ -15,18 +14,17 @@
CheckpointImpl,
apply_activation_checkpointing,
checkpoint_wrapper,
offload_wrapper,
)
from torch.distributed.fsdp.wrap import CustomPolicy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper
from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy
from torch.distributed.fsdp._common_utils import clean_tensor_name
from torch.distributed.fsdp.wrap import CustomPolicy
from torch.nn.parallel import DistributedDataParallel
from torchmetrics import Metric, MetricCollection

from composer.core import Precision, State
from composer.core.precision import _validate_precision
from composer.devices import Device, DeviceGPU
from composer.distributed.meta_safe_apply import meta_safe_apply
from composer.distributed.mosaic_parallelism import (
BACKWARD_PREFETCH_MAP,
SHARDING_MAP,
@@ -431,7 +429,7 @@ def _param_init_fn(module: torch.nn.Module) -> None:
# It is assumed that whatever process moved the parameters off of meta device initialized them.
# We expect this to occur if we have tied weights, as the second module will already have the weights initialized.
is_meta = any(param.is_meta for param in module.parameters(recurse=False)
) or any(buffer.is_meta for buffer in module.buffers(recurse=False))
) or any(buffer.is_meta for buffer in module.buffers(recurse=False))
if not is_meta:
return

@@ -543,9 +541,7 @@ def lambda_fn(module: torch.nn.Module) -> Union[bool, dict]:
try:
import transformer_engine.pytorch as te
except ModuleNotFoundError:
raise ModuleNotFoundError(
'Please install transformer-engine to use TE checkpoint wrapper',
)
raise ModuleNotFoundError('Please install transformer-engine to use TE checkpoint wrapper',)

# RNG state tracker for checkpointing
CUDA_RNG_STATES_TRACKER = te.distributed.CudaRNGStatesTracker()
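
Note on the _param_init_fn hunk above: the is_meta guard skips re-initialization of modules whose parameters have already been materialized, e.g. tied weights that a sibling module initialized first. A minimal sketch of that pattern follows, assuming a CPU target device and an illustrative helper name (init_if_meta is not part of Composer and is simplified relative to the actual code):

import torch
import torch.nn as nn

def init_if_meta(module: nn.Module) -> None:
    # Only inspect this module's own params/buffers, not its children.
    is_meta = any(p.is_meta for p in module.parameters(recurse=False)) or \
        any(b.is_meta for b in module.buffers(recurse=False))
    if not is_meta:
        # Already materialized (e.g. by a tied-weight sibling); nothing to do.
        return
    # Allocate real storage, then run the module's own init logic if present.
    module.to_empty(device='cpu', recurse=False)
    if hasattr(module, 'reset_parameters'):
        module.reset_parameters()

with torch.device('meta'):
    layer = nn.Linear(8, 8)
print(layer.weight.is_meta)  # True: no storage allocated yet
init_if_meta(layer)
print(layer.weight.is_meta)  # False: materialized and re-initialized
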
1 change: 0 additions & 1 deletion composer/distributed/mosaic_parallelism.py
@@ -8,7 +8,6 @@
from typing import Any, Union

import torch
from packaging import version
from torch import distributed
from torch.distributed import ProcessGroup
from torch.distributed.fsdp import (
3 changes: 1 addition & 2 deletions composer/profiler/torch_profiler.py
@@ -13,12 +13,12 @@

import torch.cuda
import torch.profiler
from packaging import version
from torch.profiler.profiler import ProfilerAction as TorchProfilerAction

from composer.core.callback import Callback
from composer.loggers import Logger
from composer.profiler.profiler_action import ProfilerAction
from composer.profiler.utils import export_memory_timeline_html
from composer.utils import (
FORMAT_NAME_WITH_DIST_AND_TIME_TABLE,
FORMAT_NAME_WITH_DIST_TABLE,
@@ -27,7 +27,6 @@
format_name_with_dist,
format_name_with_dist_and_time,
)
from composer.profiler.utils import export_memory_timeline_html

if TYPE_CHECKING:
from composer.core import State
5 changes: 1 addition & 4 deletions composer/profiler/utils.py
@@ -10,12 +10,10 @@
from tempfile import NamedTemporaryFile
from typing import Any, Optional, Union

from torch.profiler._memory_profiler import _CATEGORY_TO_COLORS, _CATEGORY_TO_INDEX, MemoryProfileTimeline

import numpy as np
import torch
import torch.cuda
from packaging import version
from torch.profiler._memory_profiler import _CATEGORY_TO_COLORS, _CATEGORY_TO_INDEX, MemoryProfileTimeline
from torch.profiler.profiler import profile as TorchProfile

log = logging.getLogger(__name__)
@@ -31,7 +29,6 @@ def export_memory_timeline_html(
return_fig: bool = False,
) -> Optional[Union[None, Any]]:
"""Exports a memory timeline to an HTML file. Similar to the PyTorch plotting function, but with adjusted axis tickers and grids."""

# Default to device 0, if unset. Fallback on cpu.
if device is None and prof.use_device and prof.use_device != 'cuda':
device = prof.use_device + ':0'
8 changes: 0 additions & 8 deletions composer/trainer/_patch_pytorch.py
@@ -11,7 +11,6 @@
"""PyTorch, especially PyTorch Distributed, monkeypatches."""

import logging
import math
import functools
import contextlib
from dataclasses import asdict
@@ -20,16 +19,9 @@


import torch
import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
import torch.nn as nn
import torch.nn.functional as F
from packaging import version
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed._shard.sharding_spec._internals import get_chunked_dim_size, get_split_size
from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy
from torch.distributed.fsdp._fsdp_extensions import _ext_pre_load_state_dict_transform
from torch.distributed.utils import _replace_by_prefix

from composer.utils import dist

6 changes: 1 addition & 5 deletions composer/utils/dist.py
@@ -47,12 +47,8 @@
import torch
import torch.distributed as dist
import torch.utils.data
from packaging import version

from composer.utils.device import get_device, is_hpu_installed, is_xla_installed

if is_xla_installed():
import torch_xla
from composer.utils.device import get_device, is_hpu_installed

if TYPE_CHECKING:
from composer.devices import Device
