From 2668a6e24abf5cf17bb3bf24012115ca70d50c89 Mon Sep 17 00:00:00 2001
From: lzy-dev
Date: Wed, 27 Nov 2024 17:00:41 +0800
Subject: [PATCH] remove bi150 code

---
 .../core/extensions/transformer_engine.py | 19 ++++---------------
 1 file changed, 4 insertions(+), 15 deletions(-)

diff --git a/megatron/megatron/core/extensions/transformer_engine.py b/megatron/megatron/core/extensions/transformer_engine.py
index 0b5bc938..51c7344c 100644
--- a/megatron/megatron/core/extensions/transformer_engine.py
+++ b/megatron/megatron/core/extensions/transformer_engine.py
@@ -957,13 +957,8 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
     TEColumnParallelGroupedLinear = None
     TERowParallelGroupedLinear = None
 
-# To compatible with BI150
-try:
-    from transformer_engine.common.recipe import DelayedScaling
-    DelayedScalingBaseClass = DelayedScaling
-except ImportError:
-    DelayedScalingBaseClass = object
-class TEDelayedScaling(DelayedScalingBaseClass):
+
+class TEDelayedScaling(te.common.recipe.DelayedScaling):
     """
     Wrapper for the Transformer-Engine's `DelayedScaling` layer.
     """
@@ -991,14 +986,8 @@ def __init__(
             override_linear_precision=override_linear_precision,
             **extra_kwargs,
         )
-# To compatible with BI150
-try:
-    from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
-    CudaRNGStatesTrackerBaseClass = CudaRNGStatesTracker
-except ImportError:
-    from transformer_engine.pytorch.dist import CudaRNGStatesTracker
-    CudaRNGStatesTrackerBaseClass = CudaRNGStatesTracker
-class TECudaRNGStatesTracker(CudaRNGStatesTrackerBaseClass):
+
+class TECudaRNGStatesTracker(te.pytorch.distributed.CudaRNGStatesTracker):
     """Wraps TransformerEngine's CudaRNGStatesTracker so that it
     is interchangeable with Megatron's RNG tracker"""
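
For context, the hunks above delete an optional-import shim and restore direct inheritance from Transformer Engine. The removed pattern, reproduced as a self-contained sketch (comments added here; everything else mirrors the deleted lines):

    # Guarded import keeps the module importable on builds (e.g. BI150)
    # where transformer_engine does not expose DelayedScaling.
    try:
        from transformer_engine.common.recipe import DelayedScaling
        DelayedScalingBaseClass = DelayedScaling
    except ImportError:
        # Fall back to a plain base class: the module still imports,
        # but the wrapper loses TE's delayed-scaling behavior.
        DelayedScalingBaseClass = object

    class TEDelayedScaling(DelayedScalingBaseClass):
        """Wrapper for the Transformer-Engine's `DelayedScaling` layer."""

After the patch, `TEDelayedScaling` subclasses `te.common.recipe.DelayedScaling` directly (and `TECudaRNGStatesTracker` subclasses `te.pytorch.distributed.CudaRNGStatesTracker`), so a build that lacks these symbols now fails at module import time instead of silently degrading to `object`.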