Skip to content

Commit

Permalink
Revert "simplify code"
Browse files Browse the repository at this point in the history
This reverts commit 249a543.
  • Loading branch information
janosh committed Apr 9, 2024
1 parent 9273248 commit 553dba7
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 34 deletions.
18 changes: 9 additions & 9 deletions fireworks/core/firework.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,9 @@ def __init__(
not only to direct children, but to all dependent FireWorks
down to the Workflow's leaves.
"""
mod_spec = mod_spec or []
additions = additions or []
detours = detours or []
mod_spec = mod_spec if mod_spec is not None else []
additions = additions if additions is not None else []
detours = detours if detours is not None else []

self.stored_data = stored_data if stored_data else {}
self.exit = exit
Expand Down Expand Up @@ -267,13 +267,13 @@ def __init__(
NEGATIVE_FWID_CTR -= 1
self.fw_id = NEGATIVE_FWID_CTR

self.launches = launches or []
self.archived_launches = archived_launches or []
self.launches = launches if launches else []
self.archived_launches = archived_launches if archived_launches else []
self.created_on = created_on or datetime.utcnow()
self.updated_on = updated_on or datetime.utcnow()

parents = [parents] if isinstance(parents, Firework) else parents
self.parents = parents or []
self.parents = parents if parents else []

self._state = state

Expand Down Expand Up @@ -476,9 +476,9 @@ def __init__(
self.fworker = fworker or FWorker()
self.host = host or get_my_host()
self.ip = ip or get_my_ip()
self.trackers = trackers or []
self.trackers = trackers if trackers else []
self.action = action if action else None
self.state_history = state_history or []
self.state_history = state_history if state_history else []
self.state = state
self.launch_id = launch_id
self.fw_id = fw_id
Expand Down Expand Up @@ -643,7 +643,7 @@ def _update_state_history(self, state) -> None:
now_time = datetime.utcnow()
new_history_entry = {"state": state, "created_on": now_time}
if state != "COMPLETED" and last_checkpoint:
new_history_entry.update(checkpoint=last_checkpoint)
new_history_entry.update({"checkpoint": last_checkpoint})
self.state_history.append(new_history_entry)
if state in ["RUNNING", "RESERVED"]:
self.touch_history() # add updated_on key
Expand Down
47 changes: 29 additions & 18 deletions fireworks/core/launchpad.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,15 +185,15 @@ def __init__(
self.password = password
self.authsource = authsource or self.name
self.mongoclient_kwargs = mongoclient_kwargs or {}
self.uri_mode = bool(uri_mode)
self.uri_mode = uri_mode

# set up logger
self.logdir = logdir
self.strm_lvl = strm_lvl if strm_lvl else "INFO"
self.m_logger = get_fw_logger("launchpad", l_dir=self.logdir, stream_level=self.strm_lvl)

self.user_indices = user_indices or []
self.wf_user_indices = wf_user_indices or []
self.user_indices = user_indices if user_indices else []
self.wf_user_indices = wf_user_indices if wf_user_indices else []

# get connection
if uri_mode:
Expand Down Expand Up @@ -267,20 +267,31 @@ def update_spec(self, fw_ids, spec_document, mongo=False) -> None:
)

@classmethod
def from_dict(cls, dct):
def from_dict(cls, d):
port = d.get("port", None)
name = d.get("name", None)
username = d.get("username", None)
password = d.get("password", None)
logdir = d.get("logdir", None)
strm_lvl = d.get("strm_lvl", None)
user_indices = d.get("user_indices", [])
wf_user_indices = d.get("wf_user_indices", [])
authsource = d.get("authsource", None)
uri_mode = d.get("uri_mode", False)
mongoclient_kwargs = d.get("mongoclient_kwargs", None)
return LaunchPad(
dct["host"],
port=dct.get("port"),
name=dct.get("name"),
username=dct.get("username"),
password=dct.get("password"),
logdir=dct.get("logdir"),
strm_lvl=dct.get("strm_lvl"),
user_indices=dct.get("user_indices"),
wf_user_indices=dct.get("wf_user_indices"),
authsource=dct.get("authsource"),
uri_mode=dct.get("uri_mode", False),
mongoclient_kwargs=dct.get("mongoclient_kwargs"),
d["host"],
port,
name,
username,
password,
logdir,
strm_lvl,
user_indices,
wf_user_indices,
authsource,
uri_mode,
mongoclient_kwargs,
)

@classmethod
Expand Down Expand Up @@ -1659,7 +1670,7 @@ def rerun_fw(self, fw_id, rerun_duplicates=True, recover_launch=None, recover_mo
# Launch recovery
if recover_launch is not None:
recovery = self.get_recovery(fw_id, recover_launch)
recovery.update(_mode=recover_mode)
recovery.update({"_mode": recover_mode})
set_spec = recursive_dict({"$set": {"spec._recovery": recovery}})
if recover_mode == "prev_dir":
prev_dir = self.get_launch_by_id(recovery.get("_launch_id")).launch_dir
Expand Down Expand Up @@ -1703,7 +1714,7 @@ def get_recovery(self, fw_id, launch_id="last"):
m_fw = self.get_fw_by_id(fw_id)
launch = m_fw.launches[-1] if launch_id == "last" else self.get_launch_by_id(launch_id)
recovery = launch.state_history[-1].get("checkpoint")
recovery.update(_prev_dir=launch.launch_dir, _launch_id=launch.launch_id)
recovery.update({"_prev_dir": launch.launch_dir, "_launch_id": launch.launch_id})
return recovery

def _refresh_wf(self, fw_id) -> None:
Expand Down
10 changes: 5 additions & 5 deletions fireworks/features/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,8 @@ def group_fizzled_fireworks(
"created_on": self._query_datetime_range(start_time=query_start, end_time=query_end, **args),
}
if include_ids:
project_query.update(fw_id=1)
group_query.update(fw_id={"$push": "$fw_id"})
project_query.update({"fw_id": 1})
group_query.update({"fw_id": {"$push": "$fw_id"}})
if query:
match_query.update(query)
return self._aggregate(
Expand Down Expand Up @@ -306,11 +306,11 @@ def _get_summary(
}
match_query.update(query)
if runtime_stats:
project_query.update(runtime_secs=1)
project_query.update({"runtime_secs": 1})
group_query.update(RUNTIME_STATS)
if include_ids:
project_query.update({id_field: 1})
group_query.update(ids={"$push": "$" + id_field})
group_query.update({"ids": {"$push": "$" + id_field}})
return self._aggregate(
coll=coll,
match=match_query,
Expand Down Expand Up @@ -357,7 +357,7 @@ def _aggregate(
for arg in [match, project, unwind, group_op]:
if arg is None:
arg = {}
group_op.update(_id=f"${group_by}")
group_op.update({"_id": "$" + group_by})
if sort is None:
sort_query = ("_id", 1)
query = [{"$match": match}, {"$project": project}, {"$group": group_op}, {"$sort": SON([sort_query])}]
Expand Down
2 changes: 1 addition & 1 deletion fireworks/queue/queue_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def launch_rocket_to_queue(

# update qadapter job_name based on FW name
job_name = get_slug(fw.name)[0:QUEUE_JOBNAME_MAXLEN]
qadapter.update(job_name=job_name)
qadapter.update({"job_name": job_name})

if "_queueadapter" in fw.spec:
l_logger.debug("updating queue params using Firework spec..")
Expand Down
2 changes: 1 addition & 1 deletion fireworks/user_objects/queue_adapters/common_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self, q_type, q_name=None, template_file=None, timeout=None, **kwar
)
self.q_name = q_name or q_type
self.timeout = timeout or 5
self.update(kwargs)
self.update(dict(kwargs))

self.q_commands = copy.deepcopy(CommonAdapter.default_q_commands)
if "_q_commands_override" in self:
Expand Down

0 comments on commit 553dba7

Please sign in to comment.