Address review comments and add process details to the output JSON
payalcha committed Nov 19, 2024
1 parent 37caa75 commit 1a22ee4
Showing 6 changed files with 34 additions and 29 deletions.
4 changes: 2 additions & 2 deletions openfl-workspace/torch_cnn_mnist/plan/plan.yaml
@@ -9,7 +9,7 @@ aggregator :
best_state_path : save/torch_cnn_mnist_best.pbuf
last_state_path : save/torch_cnn_mnist_last.pbuf
rounds_to_train : 10
- memleak_check : true
+ log_memory_usage : true
log_metric_callback :
template : src.utils.write_metric

@@ -20,7 +20,7 @@ collaborator :
settings :
delta_updates : false
opt_treatment : RESET
- memleak_check : true
+ log_memory_usage : true

data_loader :
defaults : plan/defaults/data_loader.yaml
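As an illustrative aside (not part of this commit), the renamed flag can be read back from the plan with PyYAML; the path is the workspace plan shown in this diff, and everything else is a minimal sketch:

import yaml

# Load the workspace plan and read the renamed flag from both sections.
with open("openfl-workspace/torch_cnn_mnist/plan/plan.yaml") as f:
    plan = yaml.safe_load(f)

print("aggregator:", plan["aggregator"]["settings"]["log_memory_usage"])
print("collaborator:", plan["collaborator"]["settings"]["log_memory_usage"])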
22 changes: 10 additions & 12 deletions openfl/component/aggregator/aggregator.py
@@ -78,7 +78,7 @@ def __init__(
compression_pipeline=None,
db_store_rounds=1,
write_logs=False,
- memleak_check=False,
+ log_memory_usage=False,
log_metric_callback=None,
**kwargs,
):
@@ -126,7 +126,7 @@ def __init__(
)
self._end_of_round_check_done = [False] * rounds_to_train
self.stragglers = []
- self.memleak_check = memleak_check
+ self.log_memory_usage = log_memory_usage
self.memory_details = []
self.rounds_to_train = rounds_to_train

@@ -673,8 +673,7 @@ def send_local_task_results(
self._end_of_round_with_stragglers_check()

def get_memory_usage(self, round_number, metric_origin):
- """
- Logs the memory usage statistics for the given round number.
+ """Logs the memory usage statistics for the given round number.
This method retrieves the current virtual and swap memory usage statistics
using the psutil library, formats them into a dictionary, and logs the
@@ -683,11 +682,14 @@ get_memory_usage(self, round_number, metric_origin):
Args:
round_number (int): The current round number for which memory usage is being logged.
"""
+ process = psutil.Process()
+ self.logger.info(f"{metric_origin} process id is {process}")
virtual_memory = psutil.virtual_memory()
swap_memory = psutil.swap_memory()
memory_usage = {
"round_number": round_number,
"metric_origin": metric_origin,
+ "process_memory": round(process.memory_info().rss / (1024 ** 2),2),
"virtual_memory": {
"total": round(virtual_memory.total / (1024 ** 2), 2),
"available": round(virtual_memory.available / (1024 ** 2), 2),
@@ -707,7 +709,10 @@
"percent": swap_memory.percent,
},
}
+ self.logger.info(f"*******************END OF ROUND CHECK: {metric_origin} LOGS*******************************")
+ self.logger.info("Memory Usage: %s", memory_usage)
+ self.logger.info("*************************************************************************************")

return memory_usage

def _end_of_round_with_stragglers_check(self):
@@ -1008,13 +1013,6 @@ def _end_of_round_check(self):
all_tasks = self.assigner.get_all_tasks_for_round(self.round_number)
for task_name in all_tasks:
self._compute_validation_related_task_metrics(task_name)

- self.logger.info("*******************END OF ROUND CHECK: AGGREGATOR LOGS*******************************")
- process = psutil.Process()
- process_mem = round(process.memory_info().rss / (1024 ** 2),2)
- self.logger.info(f"Aggregator Round: {self.round_number}")
- self.logger.info(f"Aggregator Process Mem: {process_mem}")
- self.logger.info("*************************************************************************************")
memory_detail = self.get_memory_usage(self.round_number, "aggregator")
self.memory_details.append(memory_detail)

@@ -1034,7 +1032,7 @@
# TODO This needs to be fixed!
if self._time_to_quit():
# Write self.memory_details to a file
- if self.memleak_check:
+ if self.log_memory_usage:
self.logger.info("Writing memory details to file...")
with open(AGG_MEM_FILE_NAME, "w") as f:
json.dump(self.memory_details, f, indent=4)
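For readers unfamiliar with psutil, here is a self-contained sketch of the pattern get_memory_usage follows above: process RSS plus virtual and swap memory, converted to MB and rounded. Field names mirror the diff; the subset of virtual-memory fields shown here is illustrative only.

import psutil

def memory_snapshot(round_number, metric_origin):
    """Collect process, virtual, and swap memory usage (MB), as in the diff above."""
    process = psutil.Process()
    virtual_memory = psutil.virtual_memory()
    swap_memory = psutil.swap_memory()
    return {
        "round_number": round_number,
        "metric_origin": metric_origin,
        "process_memory": round(process.memory_info().rss / (1024 ** 2), 2),
        "virtual_memory": {
            "total": round(virtual_memory.total / (1024 ** 2), 2),
            "available": round(virtual_memory.available / (1024 ** 2), 2),
            "percent": virtual_memory.percent,
        },
        "swap_memory": {
            "total": round(swap_memory.total / (1024 ** 2), 2),
            "used": round(swap_memory.used / (1024 ** 2), 2),
            "percent": swap_memory.percent,
        },
    }

print(memory_snapshot(0, "aggregator"))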
15 changes: 11 additions & 4 deletions openfl/component/collaborator/collaborator.py
@@ -81,7 +81,7 @@ def __init__(
delta_updates=False,
compression_pipeline=None,
db_store_rounds=1,
- memleak_check=False,
+ log_memory_usage=False,
**kwargs,
):
"""Initialize the Collaborator object.
@@ -125,7 +125,7 @@ def __init__(
self.delta_updates = delta_updates

self.client = client
- self.memleak_check = memleak_check
+ self.log_memory_usage = log_memory_usage
self.task_config = task_config

self.logger = getLogger(__name__)
@@ -174,10 +174,11 @@ def run(self):

# Cleaning tensor db
self.tensor_db.clean_up(self.db_store_rounds)
- if self.memleak_check:
+ if self.log_memory_usage:
# This is the place to check the memory usage of the collaborator
self.logger.info("*****************COLLABORATOR LOGS*******************************")
process = psutil.Process()
+ self.logger.info(process)
process_mem = round(process.memory_info().rss / (1024 ** 2),2)
self.logger.info("Collaborator Round: %s", round_number)
self.logger.info("Collaborator Process Mem: %s", process_mem)
@@ -187,7 +188,7 @@
memory_detail = self.get_memory_usage(round_number,
metric_origin=self.collaborator_name)
memory_details.append(memory_detail)
- if self.memleak_check:
+ if self.log_memory_usage:
# Write json file with memory usage details and collaborator name
with open(f"{self.collaborator_name}_mem_details.json", "w") as f:
json.dump(memory_details, f, indent=4)
@@ -620,11 +621,14 @@ def get_memory_usage(self, round_number, metric_origin):
Args:
round_number (int): The current round number for which memory usage is being logged.
"""
+ process = psutil.Process()
+ self.logger.info(f"{metric_origin} process id is {process}")
virtual_memory = psutil.virtual_memory()
swap_memory = psutil.swap_memory()
memory_usage = {
"round_number": round_number,
"metric_origin": metric_origin,
+ "process_memory": round(process.memory_info().rss / (1024 ** 2),2),
"virtual_memory": {
"total": round(virtual_memory.total / (1024 ** 2), 2),
"available": round(virtual_memory.available / (1024 ** 2), 2),
@@ -644,5 +648,8 @@
"percent": swap_memory.percent,
},
}
+ self.logger.info(f"*******************END OF ROUND CHECK: {metric_origin} LOGS*******************************")
+ self.logger.info("Memory Usage: %s", memory_usage)
+ self.logger.info("*************************************************************************************")

return memory_usage
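At the end of run(), each collaborator dumps its per-round records to <collaborator_name>_mem_details.json, and the aggregator writes its own list to the file named by AGG_MEM_FILE_NAME (value not shown in this diff). A hedged sketch of a post-run check on one collaborator's file, with the collaborator name assumed:

import json

collaborator_name = "collaborator1"  # assumed name, for illustration only
with open(f"{collaborator_name}_mem_details.json") as f:
    memory_details = json.load(f)

# Flag a possible leak if process memory grows every round.
rss_per_round = [entry["process_memory"] for entry in memory_details]
print("process_memory (MB) per round:", rss_per_round)
if all(later > earlier for earlier, later in zip(rss_per_round, rss_per_round[1:])):
    print("process_memory increased every round - possible leak")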
8 changes: 4 additions & 4 deletions tests/end_to_end/conftest.py
@@ -63,7 +63,7 @@ def pytest_addoption(parser):
help="Disable TLS for communication",
)
parser.addoption(
- "--memleak_check",
+ "--log_memory_usage",
action="store_true",
help="Enable memory log in collaborators and aggregator",
)
@@ -239,7 +239,7 @@ def fx_federation(request, pytestconfig):
num_rounds = args.num_rounds
disable_client_auth = args.disable_client_auth
disable_tls = args.disable_tls
- memleak_check = args.memleak_check
+ log_memory_usage = args.log_memory_usage

log.info(
f"Running federation setup using Task Runner API on single machine with below configurations:\n"
@@ -248,7 +248,7 @@
f"\tModel name: {model_name}\n"
f"\tClient authentication: {not disable_client_auth}\n"
f"\tTLS: {not disable_tls}\n"
- f"\tMemory Logs: {memleak_check}"
+ f"\tMemory Logs: {log_memory_usage}"
)

# Validate the model name and create the workspace name
@@ -258,7 +258,7 @@
workspace_name = f"workspace_{model_name}"

# Create model owner object and the workspace for the model
- model_owner = participants.ModelOwner(workspace_name, model_name, memleak_check)
+ model_owner = participants.ModelOwner(workspace_name, model_name, log_memory_usage)
try:
workspace_path = model_owner.create_workspace(results_dir=results_dir)
except Exception as e:
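For context on the --log_memory_usage option added above: a pytest flag registered with pytest_addoption is normally read back via config.getoption. This repository routes it through its own parse_arguments helper instead, so the fixture below is only a generic sketch of the mechanism:

import pytest

def pytest_addoption(parser):
    # Same flag as in the diff; False unless passed on the command line.
    parser.addoption(
        "--log_memory_usage",
        action="store_true",
        help="Enable memory log in collaborators and aggregator",
    )

@pytest.fixture
def log_memory_usage(pytestconfig):
    # e.g. `pytest --log_memory_usage` -> True, otherwise False
    return pytestconfig.getoption("log_memory_usage")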
10 changes: 5 additions & 5 deletions tests/end_to_end/models/participants.py
@@ -23,13 +23,13 @@ class ModelOwner:
4. Importing and exporting the workspace etc.
"""

- def __init__(self, workspace_name, model_name, memleak_check):
+ def __init__(self, workspace_name, model_name, log_memory_usage):
"""
Initialize the ModelOwner class
Args:
workspace_name (str): Workspace name
model_name (str): Model name
- memleak_check (bool): Memory Log flag
+ log_memory_usage (bool): Memory Log flag
"""
self.workspace_name = workspace_name
self.model_name = model_name
@@ -39,7 +39,7 @@ def __init__(self, workspace_name, model_name, memleak_check):
self.plan_path = None
self.num_collaborators = constants.NUM_COLLABORATORS
self.rounds_to_train = constants.NUM_ROUNDS
- self.memleak_check = memleak_check
+ self.log_memory_usage = log_memory_usage

def create_workspace(self, results_dir=None):
"""
@@ -135,8 +135,8 @@ def modify_plan(self, new_rounds=None, num_collaborators=None, disable_client_au

data["aggregator"]["settings"]["rounds_to_train"] = int(self.rounds_to_train)
# Memory Leak related
- data["aggregator"]["settings"]["memleak_check"] = self.memleak_check
- data["collaborator"]["settings"]["memleak_check"] = self.memleak_check
+ data["aggregator"]["settings"]["log_memory_usage"] = self.log_memory_usage
+ data["collaborator"]["settings"]["log_memory_usage"] = self.log_memory_usage

data["data_loader"]["settings"]["collaborator_count"] = int(self.num_collaborators)
data["network"]["settings"]["disable_client_auth"] = disable_client_auth
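A minimal sketch of what modify_plan does with the renamed keys above: load the plan, set the flag in both sections, and write it back. The plan path is assumed for illustration; PyYAML is required.

import yaml

plan_path = "workspace_torch_cnn_mnist/plan/plan.yaml"  # assumed location
with open(plan_path) as f:
    data = yaml.safe_load(f)

# Mirror the two assignments from modify_plan in the diff.
data["aggregator"]["settings"]["log_memory_usage"] = True
data["collaborator"]["settings"]["log_memory_usage"] = True

with open(plan_path, "w") as f:
    yaml.safe_dump(data, f)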
4 changes: 2 additions & 2 deletions tests/end_to_end/utils/conftest_helper.py
@@ -20,7 +20,7 @@ def parse_arguments():
- model_name (str, default="torch_cnn_mnist"): Model name
- disable_client_auth (bool): Disable client authentication
- disable_tls (bool): Disable TLS for communication
- - memleak_check (bool): Enable Memory leak logs
+ - log_memory_usage (bool): Enable Memory leak logs
Raises:
SystemExit: If the required arguments are not provided or if any argument parsing error occurs.
@@ -33,7 +33,7 @@
parser.add_argument("--model_name", type=str, help="Model name")
parser.add_argument("--disable_client_auth", action="store_true", help="Disable client authentication")
parser.add_argument("--disable_tls", action="store_true", help="Disable TLS for communication")
- parser.add_argument("--memleak_check", action="store_true", help="Enable Memory leak logs")
+ parser.add_argument("--log_memory_usage", action="store_true", help="Enable Memory leak logs")
args = parser.parse_known_args()[0]
return args

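Because parse_arguments uses parse_known_args, the renamed flag is picked up while pytest's own options are ignored rather than raising an error. A short sketch with a hypothetical argv:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--log_memory_usage", action="store_true", help="Enable Memory leak logs")

# Unrecognized options (e.g. pytest's own flags) are returned separately, not rejected.
args, unknown = parser.parse_known_args(["--log_memory_usage", "-k", "test_federation"])
print(args.log_memory_usage)  # True
print(unknown)                # pytest's own arguments, left untouched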
