
Commit

More nits
Waino committed Nov 4, 2024
1 parent 0794dc6 commit b8e5735
Showing 5 changed files with 30 additions and 27 deletions.
16 changes: 7 additions & 9 deletions mammoth/distributed/components.py
@@ -101,8 +101,8 @@ def get_name(self) -> str:
 
     def get_module(self, model: NMTModel) -> nn.Module:
         parent = model.encoder if self.side == Side.encoder else model.decoder
-        tw = parent[self.task_id]
-        return tw
+        transformer_wrapper = parent[self.task_id]
+        return transformer_wrapper
 
     def named_parameters(self, model: NMTModel):
         module = self.get_module(model)
@@ -137,7 +137,7 @@ def load_state_dict(self, model: NMTModel, state_dict: Dict[str, Any]):
 
 @dataclass  # type: ignore
 class DistributedAttentionLayersBlock(DistributedComponent, ABC):
-    """Represents a distributed AttentionLayers object from x-transformers"""
+    """Represents a distributed AdaptedAttentionLayers object"""
     layer_stack_index: int
     xcoder_id: str
 
@@ -180,7 +180,7 @@ def load_state_dict(self, model: NMTModel, state_dict: Dict[str, Any]):
 
 @dataclass
 class DistributedEncoderAttentionLayersBlock(DistributedAttentionLayersBlock):
-    """Represents a distributed encoder-side AttentionLayers object from x-transformers"""
+    """Represents a distributed encoder-side AdaptedAttentionLayers"""
     @property
     def side(self) -> Side:
         return Side.encoder
@@ -190,13 +190,12 @@ def encoder_id(self) -> str:
         return self.xcoder_id
 
     def get_module(self, model: NMTModel) -> nn.Module:
-        aal = model.encoder.get_attention_layers_by_xcoder_id(self.layer_stack_index, self.xcoder_id)
-        return aal
+        return model.encoder.get_attention_layers_by_xcoder_id(self.layer_stack_index, self.xcoder_id)
 
 
 @dataclass
 class DistributedDecoderAttentionLayersBlock(DistributedAttentionLayersBlock):
-    """Represents a distributed decoder-side AttentionLayers object from x-transformers"""
+    """Represents a distributed decoder-side AdaptedAttentionLayers"""
     @property
     def side(self) -> Side:
         return Side.decoder
@@ -206,8 +205,7 @@ def decoder_id(self) -> str:
         return self.xcoder_id
 
     def get_module(self, model: NMTModel) -> nn.Module:
-        aal = model.decoder.get_attention_layers_by_xcoder_id(self.layer_stack_index, self.xcoder_id)
-        return aal
+        return model.decoder.get_attention_layers_by_xcoder_id(self.layer_stack_index, self.xcoder_id)
 
 
 @dataclass
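For context, a minimal runnable sketch of the get_module / named_parameters pattern visible in this file. ToyComponent and ToyModel are hypothetical stand-ins, not mammoth classes; only the shape of the two methods mirrors the diff above.

# Illustrative sketch only; ToyComponent and ToyModel are stand-ins.
import torch.nn as nn


class ToyComponent:
    def __init__(self, task_id):
        self.task_id = task_id

    def get_module(self, model):
        # As in the updated code: return the per-task module directly,
        # without binding it to a throwaway local name first.
        return model.encoder[self.task_id]

    def named_parameters(self, model):
        module = self.get_module(model)
        yield from module.named_parameters()


class ToyModel:
    def __init__(self):
        self.encoder = nn.ModuleDict({'task_a': nn.Linear(4, 4)})


component = ToyComponent('task_a')
for name, param in component.named_parameters(ToyModel()):
    print(name, tuple(param.shape))  # weight (4, 4), then bias (4,)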
4 changes: 2 additions & 2 deletions mammoth/modules/attention_bridge.py
@@ -78,7 +78,7 @@ def from_opts(cls, opts):
             opts.model_dim,
             opts.hidden_ab_size,
             opts.ab_fixed_length,
-            opts.heads,
+            opts.ab_heads,
             opts.attention_dropout[0],
             opts.max_relative_positions,
             opts.ab_layer_norm,
@@ -279,7 +279,7 @@ def forward(self, intermediate_output, encoder_output, mask=None):
     def from_opts(cls, opts):
         return cls(
             opts.model_dim,
-            opts.heads,
+            opts.ab_heads,
             opts.hidden_ab_size,  # d_ff
             # TODO: that list indexing things seems suspicious to me...
             opts.dropout[0],
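The two hunks above are part of the --heads to --ab_heads rename (see mammoth/opts.py below): the attention-bridge layers now read opts.ab_heads instead of opts.heads. A hedged sketch with a hand-built placeholder namespace, not the full mammoth option set:

# Hypothetical opts namespace; only the field names come from this diff,
# the values are placeholders.
from types import SimpleNamespace

opts = SimpleNamespace(
    model_dim=512,
    hidden_ab_size=2048,
    ab_fixed_length=50,
    ab_heads=8,                 # formerly opts.heads
    attention_dropout=[0.1],
    max_relative_positions=0,
    ab_layer_norm='layernorm',
)

# A from_opts classmethod in the style of the first hunk would now unpack:
#     cls(opts.model_dim, opts.hidden_ab_size, opts.ab_fixed_length,
#         opts.ab_heads, opts.attention_dropout[0],
#         opts.max_relative_positions, opts.ab_layer_norm)
print(opts.ab_heads)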
2 changes: 1 addition & 1 deletion mammoth/modules/transformer_encoder.py
@@ -153,7 +153,7 @@ def from_opts(cls, opts, embeddings, is_on_top=False):
         return cls(
             opts.enc_layers,
             opts.model_dim,
-            opts.heads,
+            opts.ab_heads,
             opts.transformer_ff,
             opts.dropout[0] if isinstance(opts.dropout, list) else opts.dropout,
             opts.attention_dropout[0] if isinstance(opts.attention_dropout, list) else opts.attention_dropout,
25 changes: 13 additions & 12 deletions mammoth/opts.py
@@ -255,11 +255,6 @@ def model_opts(parser):
         "For more detailed information, see: "
         "https://arxiv.org/pdf/1803.02155.pdf",
     )
-    group.add(
-        '--heads', '-heads', type=int, default=8,
-        help='Number of heads for transformer self-attention. '
-        ' Semi-obsolete: not used for x-transformers, only used for some attention bridge configuations.'
-    )
     group.add(
         "-x_transformers_opts",
         "--x_transformers_opts",
@@ -287,7 +282,7 @@ def model_opts(parser):
         '--loss_scale',
         '-loss_scale',
         type=float,
-        default=0,
+        default=0.0,
         help="For FP16 training, the static loss scale to use. If not set, the loss scale is dynamically computed.",
     )
     group.add(
@@ -327,6 +322,11 @@ def model_opts(parser):
         choices=['none', 'rmsnorm', 'layernorm'],
         help="""Use layer normalization after lin, simple and feedforward bridge layers""",
     )
+    group.add(
+        '--ab_heads', '-ab_heads', type=int, default=8,
+        help='Number of heads for transformer self-attention. '
+        ' Semi-obsolete: not used for x-transformers, only used for some attention bridge configuations.'
+    )
 
     # adapter options are in a dict "adapters", and in the corpus options
     group = parser.add_argument_group("Adapters")
@@ -371,7 +371,7 @@ def _add_train_general_opts(parser):
         help='Criteria to use for early stopping.',
     )
     group.add(
-        '--max_nan_batches', '-max_nan_batches', type=int, default=5,
+        '--max_nan_batches', '-max_nan_batches', type=int, default=0,
        help='Number of batches that may be skipped due to loss blowout.'
     )
 
@@ -524,7 +524,7 @@ def _add_train_general_opts(parser):
         '--adagrad_accumulator_init',
         '-adagrad_accumulator_init',
         type=float,
-        default=0,
+        default=0.0,
         help="Initializes the accumulator values in adagrad. "
         "Mirrors the initial_accumulator_value option "
         "in the tensorflow adagrad (use 0.1 for their default).",
@@ -533,7 +533,7 @@ def _add_train_general_opts(parser):
         '--max_grad_norm',
         '-max_grad_norm',
         type=float,
-        default=1,
+        default=1.0,
         help="If the norm of the gradient vector exceeds this, "
         "renormalize it to have the norm equal to "
         "max_grad_norm",
@@ -542,7 +542,7 @@ def _add_train_general_opts(parser):
         '--weight_decay',
         '-weight_decay',
         type=float,
-        default=0,
+        default=0.0,
         help="L2 penalty (weight decay) regularizer",
     )
     # FIXME, mentions LSTM
@@ -609,7 +609,7 @@ def _add_train_general_opts(parser):
         '--average_decay',
         '-average_decay',
         type=float,
-        default=0,
+        default=0.0,
         help="Moving average decay. "
         "Set to other than 0 (e.g. 1e-4) to activate. "
         "Similar to Marian NMT implementation: "
@@ -656,7 +656,8 @@ def _add_train_general_opts(parser):
         '-learning_rate',
         type=float,
         default=1.0,
-        help="Starting learning rate. Recommended settings: sgd = TBD, adagrad = TBD, adadelta = TBD, adam = TBD",
+        help="Starting learning rate. ",
+        # "Recommended settings: sgd = TBD, adagrad = TBD, adadelta = TBD, adam = TBD",
     )
     group.add(
         '--learning_rate_decay',
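For illustration, a stand-alone parser with the renamed and retyped options from this file. mammoth/opts.py uses its own group.add wrapper on top of the parser, so this plain-argparse version is only a sketch; flag names, types, and defaults mirror the diff, everything else is assumed.

# Stand-alone sketch using plain argparse.
import argparse

parser = argparse.ArgumentParser()
bridge = parser.add_argument_group("Attention bridge")
bridge.add_argument(
    '--ab_heads', '-ab_heads', type=int, default=8,
    help='Number of heads for attention-bridge self-attention (was --heads).',
)
train = parser.add_argument_group("Training")
train.add_argument(
    '--max_nan_batches', '-max_nan_batches', type=int, default=0,
    help='Number of batches that may be skipped due to loss blowout.',
)
train.add_argument(
    '--max_grad_norm', '-max_grad_norm', type=float, default=1.0,
    help='Renormalize gradients whose norm exceeds this value.',
)

args = parser.parse_args(['--ab_heads', '16', '--max_nan_batches', '5'])
print(args.ab_heads, args.max_nan_batches, args.max_grad_norm)  # 16 5 1.0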
10 changes: 7 additions & 3 deletions mammoth/trainer.py
@@ -23,6 +23,10 @@
 from mammoth.utils.statistics import Statistics
 
 
+class NanLossException(Exception):
+    pass
+
+
 def iter_on_device(iterator, device_context):
     if device_context.is_gpu():
         device = torch.device(f'cuda:{device_context.local_rank}')
@@ -494,7 +498,7 @@ def _gradient_accumulation(
             try:
                 if loss is not None:
                     if torch.isnan(loss):
-                        raise Exception('Loss blowout')
+                        raise NanLossException('Loss blowout')
                     # loss /= normalization
                     self.optim.backward(loss)
 
@@ -516,12 +520,12 @@ def _gradient_accumulation(
                 total_stats.update(batch_stats)
                 report_stats.update(batch_stats)
                 report_stats.update_task_loss(batch_stats.loss, metadata)
-            except Exception:
+            except NanLossException:
                 traceback.print_exc()
                 logger.info("At step %d, we removed a batch - accum %d", self.optim.training_step, k)
                 self.nan_batches += 1
                 if self.nan_batches >= self.max_nan_batches:
-                    raise Exception('Exceeded allowed --max_nan_batches.')
+                    raise NanLossException('Exceeded allowed --max_nan_batches.')
 
         if len(seen_comm_batches) != 1:
             logger.warning('Communication batches out of synch with batch accumulation')
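To make the new control flow concrete, a self-contained sketch of the pattern this file now uses: NaN losses raise a dedicated NanLossException, a limited number of them are tolerated, and any other exception propagates unchanged instead of being swallowed by a bare except. Only the NanLossException name and the counter check mirror the diff; the rest is simplified stand-in code.

# Simplified stand-in for the accumulation loop in mammoth/trainer.py.
import torch


class NanLossException(Exception):
    pass


def accumulate(losses, max_nan_batches=0):
    nan_batches = 0
    for k, loss in enumerate(losses):
        try:
            if loss is not None and torch.isnan(loss):
                raise NanLossException('Loss blowout')
            # optimizer.backward(loss) and statistics updates would go here
        except NanLossException:
            nan_batches += 1
            if nan_batches >= max_nan_batches:
                raise NanLossException('Exceeded allowed --max_nan_batches.')
            print(f'Removed batch {k} due to NaN loss')


# One NaN batch is tolerated because max_nan_batches is 2 here; with the new
# default of 0 the first NaN would abort immediately.
accumulate(
    [torch.tensor(1.0), torch.tensor(float('nan')), torch.tensor(0.5)],
    max_nan_batches=2,
)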
