Skip to content

Commit

Permalink
break(framework) Fix simulation arguments and tests (#4563)
Browse files Browse the repository at this point in the history
  • Loading branch information
chongshenng authored Nov 26, 2024
1 parent db87321 commit 7a02f7a
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 10 deletions.
7 changes: 6 additions & 1 deletion e2e/test_exec_api.sh
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ case "$3" in
executor_arg="--executor flwr.superexec.deployment:executor"
;;
simulation-engine)
executor_config="$executor_config num-supernodes=10"
executor_arg="--executor flwr.superexec.simulation:executor"
;;
esac
Expand Down Expand Up @@ -74,6 +73,10 @@ else
echo -e $"\n[tool.flwr.federations.e2e]\naddress = \"127.0.0.1:9093\"\nroot-certificates = \"../certificates/ca.crt\"" >> pyproject.toml
fi

if [ "$3" = "simulation-engine" ]; then
echo -e $"options.num-supernodes = 10" >> pyproject.toml
fi

# Combine the arguments into a single command for flower-superlink
combined_args="$server_arg $server_auth $exec_api_arg $executor_arg"

Expand Down Expand Up @@ -113,6 +116,7 @@ while [ "$found_success" = false ] && [ $elapsed -lt $timeout ]; do
kill $cl1_pid; kill $cl2_pid;
fi
sleep 1; kill $sl_pid;
exit 0;
else
echo "Waiting for training ... ($elapsed seconds elapsed)"
fi
Expand All @@ -127,4 +131,5 @@ if [ "$found_success" = false ]; then
kill $cl1_pid; kill $cl2_pid;
fi
kill $sl_pid;
exit 1;
fi
9 changes: 7 additions & 2 deletions src/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def run_superlink() -> None:
address=simulationio_address,
state_factory=state_factory,
ffs_factory=ffs_factory,
certificates=certificates,
certificates=None, # SimulationAppIo API doesn't support SSL yet
)
grpc_servers.append(simulationio_server)

Expand Down Expand Up @@ -389,6 +389,9 @@ def run_superlink() -> None:
io_address = (
f"{CLIENT_OCTET}:{_port}" if _octet == SERVER_OCTET else serverappio_address
)
address_arg = (
"--simulationio-api-address" if sim_exec else "--serverappio-api-address"
)
address = simulationio_address if sim_exec else io_address
cmd = "flwr-simulation" if sim_exec else "flwr-serverapp"

Expand All @@ -397,6 +400,7 @@ def run_superlink() -> None:
target=_flwr_scheduler,
args=(
state_factory,
address_arg,
address,
cmd,
),
Expand All @@ -422,6 +426,7 @@ def run_superlink() -> None:

def _flwr_scheduler(
state_factory: LinkStateFactory,
io_api_arg: str,
io_api_address: str,
cmd: str,
) -> None:
Expand All @@ -446,7 +451,7 @@ def _flwr_scheduler(
command = [
cmd,
"--run-once",
"--serverappio-api-address",
io_api_arg,
io_api_address,
"--insecure",
]
Expand Down
20 changes: 13 additions & 7 deletions src/py/flwr/simulation/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
get_project_dir,
unflatten_dict,
)
from flwr.common.constant import Status, SubStatus
from flwr.common.constant import (
SIMULATIONIO_API_DEFAULT_CLIENT_ADDRESS,
Status,
SubStatus,
)
from flwr.common.logger import (
log,
mirror_output_to_queue,
Expand Down Expand Up @@ -73,9 +77,11 @@ def flwr_simulation() -> None:
description="Run a Flower Simulation",
)
parser.add_argument(
"--superlink",
"--simulationio-api-address",
default=SIMULATIONIO_API_DEFAULT_CLIENT_ADDRESS,
type=str,
help="Address of SuperLink's SimulationIO API",
help="Address of SuperLink's SimulationIO API (IPv4, IPv6, or a domain name)."
f"By default, it is set to {SIMULATIONIO_API_DEFAULT_CLIENT_ADDRESS}.",
)
parser.add_argument(
"--run-once",
Expand Down Expand Up @@ -111,15 +117,15 @@ def flwr_simulation() -> None:
args = parser.parse_args()

log(INFO, "Starting Flower Simulation")
certificates = try_obtain_root_certificates(args, args.superlink)
certificates = try_obtain_root_certificates(args, args.simulationio_api_address)

log(
DEBUG,
"Staring isolated `Simulation` connected to SuperLink DriverAPI at %s",
args.superlink,
args.simulationio_api_address,
)
run_simulation_process(
simulationio_api_address=args.superlink,
simulationio_api_address=args.simulationio_api_address,
log_queue=log_queue,
run_once=args.run_once,
flwr_dir_=args.flwr_dir,
Expand Down Expand Up @@ -225,7 +231,7 @@ def run_simulation_process( # pylint: disable=R0914, disable=W0212, disable=R09
)
backend_config: BackendConfig = fed_opt.get("backend", {})
verbose: bool = fed_opt.get("verbose", False)
enable_tf_gpu_growth: bool = fed_opt.get("enable_tf_gpu_growth", True)
enable_tf_gpu_growth: bool = fed_opt.get("enable_tf_gpu_growth", False)

# Launch the simulation
_run_simulation(
Expand Down

0 comments on commit 7a02f7a

Please sign in to comment.