Skip to content

Commit

Permalink
fix(lint): fix lint warning
Browse files Browse the repository at this point in the history
  • Loading branch information
volchyt2024 committed Jul 5, 2024
1 parent 00abc79 commit 3eaecd1
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 5 deletions.
3 changes: 2 additions & 1 deletion example/wide_n_deep/follower.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ def model_fn(model, features, labels, mode):
train_op = model.minimize(
optimizer, act1_f, grad_loss=gact1_f, global_step=global_step)
final_ops = final_fn(model=model, tensor_name='reflux_embedding',
is_send=False, assignee=peer_embeddings, shape=[num_slot,fid_size,embed_size])
is_send=False, assignee=peer_embeddings,
shape=[num_slot, fid_size, embed_size])
embedding_hook = tf.train.FinalOpsHook(final_ops=final_ops)
return model.make_spec(mode, loss=tf.math.reduce_mean(act1_f),
training_chief_hooks=[embedding_hook],
Expand Down
5 changes: 3 additions & 2 deletions example/wide_n_deep/leader.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def final_fn(model, tensor_name, is_send, tensor=None, shape=None):
if is_send:
assert tensor, "Please specify tensor to send"
if DEBUG_PRINT:
ops.append(tf.print(tensor))
ops.append(tf.print(tensor))
ops.append(model.send_no_deps(tensor_name, tensor))
return ops

Expand Down Expand Up @@ -161,7 +161,8 @@ def model_fn(model, features, labels, mode):
{"loss" : loss}, every_n_iter=10)
metric_hook = flt.GlobalStepMetricTensorHook(tensor_dict={"loss": loss},
every_steps=10)
final_ops = final_fn(model=model, tensor_name='reflux_embedding',is_send=True,tensor=embeddings)
final_ops = final_fn(model=model, tensor_name='reflux_embedding',
is_send=True, tensor=embeddings)
embedding_hook = tf.train.FinalOpsHook(final_ops=final_ops)

optimizer = tf.train.GradientDescentOptimizer(0.1)
Expand Down
3 changes: 2 additions & 1 deletion fedlearner/trainer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ def send_no_deps(self, name, tensor):
self._sends.append((name, tensor, False))
return send_op

def recv_no_deps(self, name, dtype=tf.float32, require_grad=False, shape=None):
def recv_no_deps(self,
name, dtype=tf.float32, require_grad=False, shape=None):
receive_op = self._bridge.receive_op(name, dtype)
if shape:
receive_op = tf.ensure_shape(receive_op, shape)
Expand Down
2 changes: 1 addition & 1 deletion fedlearner/trainer/run_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,4 +243,4 @@ def _parse_op_label(self, label):
inputs = []
else:
inputs = inputs.split(', ')
return nn, op, inputs
return nn, op, inputs

0 comments on commit 3eaecd1

Please sign in to comment.