Skip to content

Commit

Permalink
remove bi150 code
Browse files Browse the repository at this point in the history
  • Loading branch information
lzy-dev committed Nov 27, 2024
1 parent 6e39154 commit 2668a6e
Showing 1 changed file with 4 additions and 15 deletions.
19 changes: 4 additions & 15 deletions megatron/megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,13 +957,8 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
TEColumnParallelGroupedLinear = None
TERowParallelGroupedLinear = None

# To be compatible with BI150
try:
from transformer_engine.common.recipe import DelayedScaling
DelayedScalingBaseClass = DelayedScaling
except ImportError:
DelayedScalingBaseClass = object
class TEDelayedScaling(DelayedScalingBaseClass):

class TEDelayedScaling(te.common.recipe.DelayedScaling):
"""
Wrapper for the Transformer-Engine's `DelayedScaling` layer.
"""
Expand Down Expand Up @@ -991,14 +986,8 @@ def __init__(
override_linear_precision=override_linear_precision,
**extra_kwargs,
)
# To be compatible with BI150
try:
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
CudaRNGStatesTrackerBaseClass = CudaRNGStatesTracker
except ImportError:
from transformer_engine.pytorch.dist import CudaRNGStatesTracker
CudaRNGStatesTrackerBaseClass = CudaRNGStatesTracker
class TECudaRNGStatesTracker(CudaRNGStatesTrackerBaseClass):

class TECudaRNGStatesTracker(te.pytorch.distributed.CudaRNGStatesTracker):
"""Wraps TransformerEngine's CudaRNGStatesTracker so that it is
interchangeable with Megatron's RNG tracker"""

Expand Down

0 comments on commit 2668a6e

Please sign in to comment.