Skip to content

Commit

Permalink
fix: running yapf again with 0.32, earlier using 0.43
Browse files Browse the repository at this point in the history
  • Loading branch information
init-22 committed Dec 3, 2024
1 parent 1b587b7 commit c65d93e
Show file tree
Hide file tree
Showing 16 changed files with 54 additions and 40 deletions.
4 changes: 2 additions & 2 deletions algorithmic_efficiency/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def _make_report(
float(np.std(d)),
len(d),
float(np.sum(d)),
100.0 * float(np.sum(d)) / total_duration)
for a, d in self.recorded_durations.items()]
100.0 * float(np.sum(d)) / total_duration) for a,
d in self.recorded_durations.items()]
report.sort(key=lambda x: x[5], reverse=True)
total_calls = sum(x[3] for x in report)
return report, total_calls, total_duration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,9 @@ def _eval_model_on_split(self,
for _ in range(num_batches):
batch = next(self._eval_iters[split])
batch_metrics = self._eval_model(params, batch, model_rng)
total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
total_metrics = {
k: v + batch_metrics[k] for k, v in total_metrics.items()
}
if USE_PYTORCH_DDP:
for metric in total_metrics.values():
dist.all_reduce(metric)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,8 @@ def build_lut(histo, step):
# If step is zero, return the original image. Otherwise, build
# lut from the full histogram and step and then index from it.
result = tf.cond(
tf.equal(step, 0), lambda: im,
tf.equal(step, 0),
lambda: im,
lambda: tf.gather(build_lut(histo, step), im))

return tf.cast(result, tf.uint8)
Expand Down Expand Up @@ -548,6 +549,7 @@ def distort_image_with_randaugment(image, num_layers, magnitude, key):
translate_const=100)
image = tf.cond(
tf.equal(i, op_to_select),
lambda selected_func=func, selected_args=args: selected_func(
image, *selected_args), lambda: image)
lambda selected_func=func,
selected_args=args: selected_func(image, *selected_args),
lambda: image)
return image
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,9 @@ def _eval_model_on_split(self,
update_batch_norm=False)
weights = batch.get('weights')
batch_metrics = self._compute_metrics(logits, batch['targets'], weights)
total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
total_metrics = {
k: v + batch_metrics[k] for k, v in total_metrics.items()
}
if USE_PYTORCH_DDP:
for metric in total_metrics.values():
dist.all_reduce(metric)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ def setup(self):
self.kernel = self.param('kernel',
nn.initializers.xavier_uniform(),
self.filter_shape)
self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32),
self.output_channels)
self.bias = self.param(
'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels)

@nn.compact
def __call__(self, inputs, paddings):
Expand Down Expand Up @@ -442,10 +442,12 @@ def setup(self):
dtype = self.config.dtype

self.ra_mean = self.variable('batch_stats',
'mean', lambda s: jnp.zeros(s, dtype),
'mean',
lambda s: jnp.zeros(s, dtype),
dim)
self.ra_var = self.variable('batch_stats',
'var', lambda s: jnp.ones(s, dtype),
'var',
lambda s: jnp.ones(s, dtype),
dim)

self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def _get_mask(self,
jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0),
[batch_size, 1])
multiplicity_tensor = masks_per_frame * choose_range
multiplicity_weights = (multiplicity_weights
< multiplicity_tensor).astype(jnp.int32)
multiplicity_weights = (multiplicity_weights <
multiplicity_tensor).astype(jnp.int32)
pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights)
else:
pre_mask = jnp.einsum('bmt->bt', pre_mask)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,8 @@ def greedy_decode(
idxs = torch.arange(
fin_result.numel(), device=result.device).view(*fin_result.shape)
mask = torch.arange(
fin_result.shape[1],
device=result.device).view(1, -1) < result.count_nonzero(dim=1).view(
-1, 1)
fin_result.shape[1], device=result.device).view(
1, -1) < result.count_nonzero(dim=1).view(-1, 1)
fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id]
padding = fin_result == 0
return fin_result, padding
Expand Down Expand Up @@ -330,7 +329,9 @@ def _eval_model_on_split(self,
'word_errors': word_errors,
'num_words': num_words,
}
total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
total_metrics = {
k: v + batch_metrics[k] for k, v in total_metrics.items()
}
if USE_PYTORCH_DDP:
for metric in total_metrics.values():
dist.all_reduce(metric)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ def setup(self):
self.kernel = self.param('kernel',
nn.initializers.xavier_uniform(),
self.filter_shape)
self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32),
self.output_channels)
self.bias = self.param(
'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels)

@nn.compact
def __call__(self, inputs, paddings, train):
Expand Down Expand Up @@ -273,10 +273,12 @@ def setup(self):
dtype = self.dtype

self.ra_mean = self.variable('batch_stats',
'mean', lambda s: jnp.zeros(s, dtype),
'mean',
lambda s: jnp.zeros(s, dtype),
dim)
self.ra_var = self.variable('batch_stats',
'var', lambda s: jnp.ones(s, dtype),
'var',
lambda s: jnp.ones(s, dtype),
dim)

self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype)
Expand Down
7 changes: 5 additions & 2 deletions algorithmic_efficiency/workloads/mnist/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def _build_mnist_dataset(
ds = ds.map(
lambda x: {
'inputs': _normalize(x['image'], train_mean, train_stddev),
'targets': x['label'],})
'targets': x['label'],
})
is_train = split == 'train'

if cache:
Expand Down Expand Up @@ -213,6 +214,8 @@ def _eval_model_on_split(self,
batch,
model_state,
per_device_model_rngs)
total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
total_metrics = {
k: v + batch_metrics[k] for k, v in total_metrics.items()
}

return self._normalize_eval_metrics(num_examples, total_metrics)
4 changes: 2 additions & 2 deletions algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,8 +942,8 @@ def forward(self,
# not the remaining zero elements.
if attn_mask is not None:
raise ValueError('Attention mask has to be None for decode == True.')
attn_mask = (torch.arange(max_len, device=k.device)
>= cache_index).reshape(1, max_len)
attn_mask = (torch.arange(max_len, device=k.device) >=
cache_index).reshape(1, max_len)

# Update sequence length to account for complete sequence.
seq_len = k.size(1)
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ dev =
isort==5.13.2
pylint==3.3.1
pytest==8.3.3
yapf==0.43.0
yapf==0.32.0
pre-commit==4.0.1

# Workloads #
Expand Down
4 changes: 2 additions & 2 deletions submission_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,8 +377,8 @@ def train_once(
train_state['is_time_remaining'] = (
train_state['accumulated_submission_time'] < max_allowed_runtime_sec)
# Check if submission is eligible for an untimed eval.
if ((train_step_end_time - train_state['last_eval_time'])
>= workload.eval_period_time_sec or train_state['training_complete']):
if ((train_step_end_time - train_state['last_eval_time']) >=
workload.eval_period_time_sec or train_state['training_complete']):
with profiler.profile('Evaluation'):
del batch
_reset_cuda_mem()
Expand Down
6 changes: 3 additions & 3 deletions tests/modeldiffs/wmt/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def sd_transform(sd):
out = {
tuple(
k.replace('SelfAttention', 'MultiHeadDotProductAttention')
for k in key):
value
for key, value in out.items()
for k in key): value
for key,
value in out.items()
}
elif 'Dense' in k_str:
new_key = (*k[:2], 'MlpBlock_0', *k[2:])
Expand Down
6 changes: 3 additions & 3 deletions tests/modeldiffs/wmt_attention_temp/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def sd_transform(sd):
out = {
tuple(
k.replace('SelfAttention', 'MultiHeadDotProductAttention')
for k in key):
value
for key, value in out.items()
for k in key): value
for key,
value in out.items()
}
elif 'Dense' in k_str:
new_key = (*k[:2], 'MlpBlock_0', *k[2:])
Expand Down
6 changes: 3 additions & 3 deletions tests/modeldiffs/wmt_glu_tanh/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def sd_transform(sd):
out = {
tuple(
k.replace('SelfAttention', 'MultiHeadDotProductAttention')
for k in key):
value
for key, value in out.items()
for k in key): value
for key,
value in out.items()
}
elif 'Dense' in k_str:
new_key = (*k[:2], 'MlpBlock_0', *k[2:])
Expand Down
6 changes: 3 additions & 3 deletions tests/modeldiffs/wmt_post_ln/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def sd_transform(sd):
out = {
tuple(
k.replace('SelfAttention', 'MultiHeadDotProductAttention')
for k in key):
value
for key, value in out.items()
for k in key): value
for key,
value in out.items()
}
elif 'Dense' in k_str:
new_key = (*k[:2], 'MlpBlock_0', *k[2:])
Expand Down

0 comments on commit c65d93e

Please sign in to comment.