Skip to content

Commit

Permalink
Merge pull request #5 from payalcha/fix-global-writer-1
Browse files Browse the repository at this point in the history
Remove get_writer function and initialize writer variable in calling function
  • Loading branch information
payalcha authored Dec 11, 2024
2 parents fe48944 + db2730e commit 8e7a9c3
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,8 @@
"source": [
"from torch.utils.tensorboard import SummaryWriter\n",
"\n",
"writer = None\n",
"\n",
"def get_writer():\n",
" global writer\n",
" if not writer:\n",
" writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)\n",
"\n",
"def write_metric(node_name, task_name, metric_name, metric, round_number):\n",
" get_writer()\n",
" writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)\n",
" writer.add_scalar(\"{}/{}/{}\".format(node_name, task_name, metric_name),\n",
" metric, round_number)"
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,7 @@
from torch.utils.tensorboard import SummaryWriter


writer = None


def get_writer():
"""Create global writer object."""
global writer
if not writer:
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
return writer


def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback."""
writer = get_writer()
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number)
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,7 @@

from tensorflow.summary import SummaryWriter

writer = None

def get_writer():
"""Create global writer object."""
global writer
if not writer:
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
return writer

def write_metric(node_name, task_name, metric_name, metric, round_number):
"""Write metric callback."""
writer = get_writer()
writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5)
writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number)

0 comments on commit 8e7a9c3

Please sign in to comment.