Skip to content

Commit

Permalink
remove machines aliases
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Jan 25, 2021
1 parent 481a5b4 commit 63387db
Showing 1 changed file with 35 additions and 29 deletions.
64 changes: 35 additions & 29 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2180,35 +2180,41 @@ def __init__(self, params=None, train_set=None, model_file=None, model_str=None,
if not isinstance(train_set, Dataset):
raise TypeError('Training data should be Dataset instance, met {}'
.format(type(train_set).__name__))
# set network if necessary
for alias in _ConfigAliases.get("machines"):
if alias in params:
machines = params[alias]
if isinstance(machines, str):
num_machines_from_machine_list = len(machines.split(','))
elif isinstance(machines, (list, set)):
num_machines_from_machine_list = len(machines)
machines = ','.join(machines)
else:
raise ValueError("Invalid machines in params.")

params = _choose_param_value(
main_param_name="num_machines",
params=params,
default_value=num_machines_from_machine_list
)
params = _choose_param_value(
main_param_name="local_listen_port",
params=params,
default_value=12400
)
self.set_network(
machines=machines,
local_listen_port=params["local_listen_port"],
listen_time_out=params.get("listen_time_out", 120),
num_machines=params["num_machines"]
)
break
params = _choose_param_value(
main_param_name="machines",
params=params,
default_value=None
)
# if "machines" is given, assume user wants to do distributed learning, and set up network
if params["machines"] is None:
params.pop("machines", None)
else:
machines = params["machines"]
if isinstance(machines, str):
num_machines_from_machine_list = len(machines.split(','))
elif isinstance(machines, (list, set)):
num_machines_from_machine_list = len(machines)
machines = ','.join(machines)
else:
raise ValueError("Invalid machines in params.")

params = _choose_param_value(
main_param_name="num_machines",
params=params,
default_value=num_machines_from_machine_list
)
params = _choose_param_value(
main_param_name="local_listen_port",
params=params,
default_value=12400
)
self.set_network(
machines=machines,
local_listen_port=params["local_listen_port"],
listen_time_out=params.get("listen_time_out", 120),
num_machines=params["num_machines"]
)
break
# construct booster object
train_set.construct()
# copy the parameters from train_set
Expand Down

0 comments on commit 63387db

Please sign in to comment.