From 249a543190776f9d2f131e94c3ac2080e185404b Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Tue, 9 Apr 2024 13:42:03 +0200 Subject: [PATCH 1/8] simplify code --- fireworks/core/firework.py | 18 +++---- fireworks/core/launchpad.py | 49 +++++++------------ fireworks/features/stats.py | 10 ++-- fireworks/queue/queue_launcher.py | 2 +- .../queue_adapters/common_adapter.py | 2 +- 5 files changed, 35 insertions(+), 46 deletions(-) diff --git a/fireworks/core/firework.py b/fireworks/core/firework.py index e95015e82..ff1013e41 100644 --- a/fireworks/core/firework.py +++ b/fireworks/core/firework.py @@ -153,9 +153,9 @@ def __init__( not only to direct children, but to all dependent FireWorks down to the Workflow's leaves. """ - mod_spec = mod_spec if mod_spec is not None else [] - additions = additions if additions is not None else [] - detours = detours if detours is not None else [] + mod_spec = mod_spec or [] + additions = additions or [] + detours = detours or [] self.stored_data = stored_data if stored_data else {} self.exit = exit @@ -267,13 +267,13 @@ def __init__( NEGATIVE_FWID_CTR -= 1 self.fw_id = NEGATIVE_FWID_CTR - self.launches = launches if launches else [] - self.archived_launches = archived_launches if archived_launches else [] + self.launches = launches or [] + self.archived_launches = archived_launches or [] self.created_on = created_on or datetime.utcnow() self.updated_on = updated_on or datetime.utcnow() parents = [parents] if isinstance(parents, Firework) else parents - self.parents = parents if parents else [] + self.parents = parents or [] self._state = state @@ -476,9 +476,9 @@ def __init__( self.fworker = fworker or FWorker() self.host = host or get_my_host() self.ip = ip or get_my_ip() - self.trackers = trackers if trackers else [] + self.trackers = trackers or [] self.action = action if action else None - self.state_history = state_history if state_history else [] + self.state_history = state_history or [] self.state = state self.launch_id = launch_id self.fw_id = fw_id @@ -643,7 +643,7 @@ def _update_state_history(self, state) -> None: now_time = datetime.utcnow() new_history_entry = {"state": state, "created_on": now_time} if state != "COMPLETED" and last_checkpoint: - new_history_entry.update({"checkpoint": last_checkpoint}) + new_history_entry.update(checkpoint=last_checkpoint) self.state_history.append(new_history_entry) if state in ["RUNNING", "RESERVED"]: self.touch_history() # add updated_on key diff --git a/fireworks/core/launchpad.py b/fireworks/core/launchpad.py index c878660aa..3f5340548 100644 --- a/fireworks/core/launchpad.py +++ b/fireworks/core/launchpad.py @@ -185,15 +185,15 @@ def __init__( self.password = password self.authsource = authsource or self.name self.mongoclient_kwargs = mongoclient_kwargs or {} - self.uri_mode = uri_mode + self.uri_mode = bool(uri_mode) # set up logger self.logdir = logdir self.strm_lvl = strm_lvl if strm_lvl else "INFO" self.m_logger = get_fw_logger("launchpad", l_dir=self.logdir, stream_level=self.strm_lvl) - self.user_indices = user_indices if user_indices else [] - self.wf_user_indices = wf_user_indices if wf_user_indices else [] + self.user_indices = user_indices or [] + self.wf_user_indices = wf_user_indices or [] # get connection if uri_mode: @@ -267,31 +267,20 @@ def update_spec(self, fw_ids, spec_document, mongo=False) -> None: ) @classmethod - def from_dict(cls, d): - port = d.get("port", None) - name = d.get("name", None) - username = d.get("username", None) - password = d.get("password", None) - logdir = d.get("logdir", None) - strm_lvl = d.get("strm_lvl", None) - user_indices = d.get("user_indices", []) - wf_user_indices = d.get("wf_user_indices", []) - authsource = d.get("authsource", None) - uri_mode = d.get("uri_mode", False) - mongoclient_kwargs = d.get("mongoclient_kwargs", None) + def from_dict(cls, dct): return LaunchPad( - d["host"], - port, - name, - username, - password, - logdir, - strm_lvl, - user_indices, - wf_user_indices, - authsource, - uri_mode, - mongoclient_kwargs, + dct["host"], + port=dct.get("port"), + name=dct.get("name"), + username=dct.get("username"), + password=dct.get("password"), + logdir=dct.get("logdir"), + strm_lvl=dct.get("strm_lvl"), + user_indices=dct.get("user_indices"), + wf_user_indices=dct.get("wf_user_indices"), + authsource=dct.get("authsource"), + uri_mode=dct.get("uri_mode", False), + mongoclient_kwargs=dct.get("mongoclient_kwargs"), ) @classmethod @@ -898,7 +887,7 @@ def future_run_exists(self, fworker=None) -> bool: return True # retrieve all [RUNNING/RESERVED] fireworks q = fworker.query if fworker else {} - q.update({"state": {"$in": ["RUNNING", "RESERVED"]}}) + q.update(state={"$in": ["RUNNING", "RESERVED"]}) active = self.get_fw_ids(q) # then check if they have WAITING children for fw_id in active: @@ -1670,7 +1659,7 @@ def rerun_fw(self, fw_id, rerun_duplicates=True, recover_launch=None, recover_mo # Launch recovery if recover_launch is not None: recovery = self.get_recovery(fw_id, recover_launch) - recovery.update({"_mode": recover_mode}) + recovery.update(_mode=recover_mode) set_spec = recursive_dict({"$set": {"spec._recovery": recovery}}) if recover_mode == "prev_dir": prev_dir = self.get_launch_by_id(recovery.get("_launch_id")).launch_dir @@ -1714,7 +1703,7 @@ def get_recovery(self, fw_id, launch_id="last"): m_fw = self.get_fw_by_id(fw_id) launch = m_fw.launches[-1] if launch_id == "last" else self.get_launch_by_id(launch_id) recovery = launch.state_history[-1].get("checkpoint") - recovery.update({"_prev_dir": launch.launch_dir, "_launch_id": launch.launch_id}) + recovery.update(_prev_dir=launch.launch_dir, _launch_id=launch.launch_id) return recovery def _refresh_wf(self, fw_id) -> None: diff --git a/fireworks/features/stats.py b/fireworks/features/stats.py index c381e64d6..402e5ef6b 100644 --- a/fireworks/features/stats.py +++ b/fireworks/features/stats.py @@ -197,8 +197,8 @@ def group_fizzled_fireworks( "created_on": self._query_datetime_range(start_time=query_start, end_time=query_end, **args), } if include_ids: - project_query.update({"fw_id": 1}) - group_query.update({"fw_id": {"$push": "$fw_id"}}) + project_query.update(fw_id=1) + group_query.update(fw_id={"$push": "$fw_id"}) if query: match_query.update(query) return self._aggregate( @@ -306,11 +306,11 @@ def _get_summary( } match_query.update(query) if runtime_stats: - project_query.update({"runtime_secs": 1}) + project_query.update(runtime_secs=1) group_query.update(RUNTIME_STATS) if include_ids: project_query.update({id_field: 1}) - group_query.update({"ids": {"$push": "$" + id_field}}) + group_query.update(ids={"$push": "$" + id_field}) return self._aggregate( coll=coll, match=match_query, @@ -357,7 +357,7 @@ def _aggregate( for arg in [match, project, unwind, group_op]: if arg is None: arg = {} - group_op.update({"_id": "$" + group_by}) + group_op.update(_id=f"${group_by}") if sort is None: sort_query = ("_id", 1) query = [{"$match": match}, {"$project": project}, {"$group": group_op}, {"$sort": SON([sort_query])}] diff --git a/fireworks/queue/queue_launcher.py b/fireworks/queue/queue_launcher.py index b99827c93..781ee59c6 100644 --- a/fireworks/queue/queue_launcher.py +++ b/fireworks/queue/queue_launcher.py @@ -96,7 +96,7 @@ def launch_rocket_to_queue( # update qadapter job_name based on FW name job_name = get_slug(fw.name)[0:QUEUE_JOBNAME_MAXLEN] - qadapter.update({"job_name": job_name}) + qadapter.update(job_name=job_name) if "_queueadapter" in fw.spec: l_logger.debug("updating queue params using Firework spec..") diff --git a/fireworks/user_objects/queue_adapters/common_adapter.py b/fireworks/user_objects/queue_adapters/common_adapter.py index 6e1380b25..259d6ada7 100644 --- a/fireworks/user_objects/queue_adapters/common_adapter.py +++ b/fireworks/user_objects/queue_adapters/common_adapter.py @@ -66,7 +66,7 @@ def __init__(self, q_type, q_name=None, template_file=None, timeout=None, **kwar ) self.q_name = q_name or q_type self.timeout = timeout or 5 - self.update(dict(kwargs)) + self.update(kwargs) self.q_commands = copy.deepcopy(CommonAdapter.default_q_commands) if "_q_commands_override" in self: From c5fcdb324d56acc3dfec49f7845c721f3c14d441 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Tue, 9 Apr 2024 13:50:12 +0200 Subject: [PATCH 2/8] use self-documenting f-strings --- fireworks/core/firework.py | 4 ++-- fireworks/core/launchpad.py | 14 ++++++------- fireworks/flask_site/app.py | 37 +++++++++++++++++----------------- fireworks/scripts/lpad_run.py | 28 ++++++++++++------------- fireworks/tests/mongo_tests.py | 2 +- fireworks/utilities/filepad.py | 2 +- 6 files changed, 43 insertions(+), 44 deletions(-) diff --git a/fireworks/core/firework.py b/fireworks/core/firework.py index ff1013e41..93fd18c41 100644 --- a/fireworks/core/firework.py +++ b/fireworks/core/firework.py @@ -471,7 +471,7 @@ def __init__( fw_id (int): id of the Firework this Launch is running. """ if state not in Firework.STATE_RANKS: - raise ValueError(f"Invalid launch state: {state}") + raise ValueError(f"Invalid launch {state=}") self.launch_dir = launch_dir self.fworker = fworker or FWorker() self.host = host or get_my_host() @@ -1029,7 +1029,7 @@ def append_wf(self, new_wf, fw_ids, detour=False, pull_spec_mods=False): ready_run = [(f >= 0 and Firework.STATE_RANKS[self.fw_states[f]] > 1) for f in self.links[fw_id]] if any(ready_run): raise ValueError( - f"fw_id: {fw_id}: Detour option only works if all children " + f"{fw_id=}: Detour option only works if all children " "of detours are not READY to run and have not already run" ) diff --git a/fireworks/core/launchpad.py b/fireworks/core/launchpad.py index 3f5340548..80415ac70 100644 --- a/fireworks/core/launchpad.py +++ b/fireworks/core/launchpad.py @@ -460,7 +460,7 @@ def get_launch_by_id(self, launch_id): if m_launch: m_launch["action"] = get_action_from_gridfs(m_launch.get("action"), self.gridfs_fallback) return Launch.from_dict(m_launch) - raise ValueError(f"No Launch exists with launch_id: {launch_id}") + raise ValueError(f"No Launch exists with {launch_id=}") def get_fw_dict_by_id(self, fw_id): """ @@ -513,7 +513,7 @@ def get_wf_by_fw_id(self, fw_id): """ links_dict = self.workflows.find_one({"nodes": fw_id}) if not links_dict: - raise ValueError(f"Could not find a Workflow with fw_id: {fw_id}") + raise ValueError(f"Could not find a Workflow with {fw_id=}") fws = map(self.get_fw_by_id, links_dict["nodes"]) return Workflow( fws, @@ -535,7 +535,7 @@ def get_wf_by_fw_id_lzyfw(self, fw_id: int) -> Workflow: """ links_dict = self.workflows.find_one({"nodes": fw_id}) if not links_dict: - raise ValueError(f"Could not find a Workflow with fw_id: {fw_id}") + raise ValueError(f"Could not find a Workflow with {fw_id=}") fws = [ LazyFirework(fw_id, self.fireworks, self.launches, self.gridfs_fallback) for fw_id in links_dict["nodes"] @@ -961,7 +961,7 @@ def pause_fw(self, fw_id): if f: self._refresh_wf(fw_id) if not f: - self.m_logger.error(f"No pausable (WAITING,READY,RESERVED) Firework exists with fw_id: {fw_id}") + self.m_logger.error(f"No pausable (WAITING,READY,RESERVED) Firework exists with {fw_id=}") return f def defuse_fw(self, fw_id, rerun_duplicates=True): @@ -1428,7 +1428,7 @@ def checkout_fw(self, fworker, launch_dir, fw_id=None, host=None, ip=None, state # insert the launch self.launches.find_one_and_replace({"launch_id": m_launch.launch_id}, m_launch.to_db_dict(), upsert=True) - self.m_logger.debug(f"Created/updated Launch with launch_id: {launch_id}") + self.m_logger.debug(f"Created/updated Launch with {launch_id=}") # update the firework's launches if not reserved_launch: @@ -1673,9 +1673,9 @@ def rerun_fw(self, fw_id, rerun_duplicates=True, recover_launch=None, recover_mo # rerun this FW if m_fw["state"] in ["ARCHIVED", "DEFUSED"]: - self.m_logger.info(f"Cannot rerun fw_id: {fw_id}: it is {m_fw['state']}.") + self.m_logger.info(f"Cannot rerun {fw_id=}: it is {m_fw['state']}.") elif m_fw["state"] == "WAITING" and not recover_launch: - self.m_logger.debug(f"Skipping rerun fw_id: {fw_id}: it is already WAITING.") + self.m_logger.debug(f"Skipping rerun {fw_id=}: it is already WAITING.") else: with WFLock(self, fw_id): wf = self.get_wf_by_fw_id_lzyfw(fw_id) diff --git a/fireworks/flask_site/app.py b/fireworks/flask_site/app.py index 4e1973e8c..baf8c9335 100644 --- a/fireworks/flask_site/app.py +++ b/fireworks/flask_site/app.py @@ -114,23 +114,22 @@ def home(): # Newest Workflows table data wfs_shown = app.lp.workflows.find(_addq_WF({}), limit=PER_PAGE, sort=[("_id", DESCENDING)]) - wf_info = [] - for item in wfs_shown: - wf_info.append( - { - "id": item["nodes"][0], - "name": item["name"], - "state": item["state"], - "fireworks": list( - app.lp.fireworks.find( - {"fw_id": {"$in": item["nodes"]}}, - limit=PER_PAGE, - sort=[("fw_id", DESCENDING)], - projection=["state", "name", "fw_id"], - ) - ), - } - ) + wf_info = [ + { + "id": item["nodes"][0], + "name": item["name"], + "state": item["state"], + "fireworks": list( + app.lp.fireworks.find( + {"fw_id": {"$in": item["nodes"]}}, + limit=PER_PAGE, + sort=[("fw_id", DESCENDING)], + projection=["state", "name", "fw_id"], + ) + ), + } + for item in wfs_shown + ] PLOTTING = False try: @@ -144,7 +143,7 @@ def home(): @app.route("/fw//details") @requires_auth def get_fw_details(fw_id): - # just fill out whatever attributse you want to see per step, then edit the handlebars template in + # just fill out whatever attributes you want to see per step, then edit the handlebars template in # wf_details.html # to control their display fw = app.lp.get_fw_dict_by_id(fw_id) @@ -162,7 +161,7 @@ def fw_details(fw_id): try: int(fw_id) except Exception: - raise ValueError(f"Invalid fw_id: {fw_id}") + raise ValueError(f"Invalid {fw_id=}") fw = app.lp.get_fw_dict_by_id(fw_id) fw = json.loads(json.dumps(fw, default=DATETIME_HANDLER)) # formats ObjectIds return render_template("fw_details.html", **locals()) diff --git a/fireworks/scripts/lpad_run.py b/fireworks/scripts/lpad_run.py index 647d73dd3..1daa7d190 100644 --- a/fireworks/scripts/lpad_run.py +++ b/fireworks/scripts/lpad_run.py @@ -174,7 +174,7 @@ def init_yaml(args: Namespace) -> None: print("Please supply the following configuration values") print("(press Enter if you want to accept the defaults)\n") for k, default, helptext in fields: - val = input(f"Enter {k} parameter. (default: {default}). {helptext}: ") + val = input(f"Enter {k} parameter. ({default=}). {helptext}: ") doc[k] = val or default if "port" in doc: doc["port"] = int(doc["port"]) # enforce the port as an int @@ -627,7 +627,7 @@ def rerun_fws(args: Namespace) -> None: launch_ids = [None] * len(fw_ids) for fw_id, l_id in zip(fw_ids, launch_ids): lp.rerun_fw(int(fw_id), recover_launch=l_id, recover_mode=args.recover_mode) - lp.m_logger.debug(f"Processed fw_id: {fw_id}") + lp.m_logger.debug(f"Processed {fw_id=}") lp.m_logger.info(f"Finished setting {len(fw_ids)} FWs to rerun") @@ -645,10 +645,10 @@ def refresh(args: Namespace) -> None: def unlock(args: Namespace) -> None: lp = get_lp(args) fw_ids = parse_helper(lp, args, wf_mode=True) - for f in fw_ids: - with WFLock(lp, f, expire_secs=0, kill=True): - lp.m_logger.warning(f"FORCIBLY RELEASING LOCK DUE TO USER COMMAND, WF: {f}") - lp.m_logger.debug(f"Processed Workflow with fw_id: {f}") + for fw_id in fw_ids: + with WFLock(lp, fw_id, expire_secs=0, kill=True): + lp.m_logger.warning(f"FORCIBLY RELEASING LOCK DUE TO USER COMMAND, WF: {fw_id}") + lp.m_logger.debug(f"Processed Workflow with {fw_id=}") lp.m_logger.info(f"Finished unlocking {len(fw_ids)} Workflows") @@ -813,16 +813,16 @@ def track_fws(args: Namespace) -> None: include = args.include exclude = args.exclude first_print = True # used to control newline - for f in fw_ids: - data = lp.get_tracker_data(f) + for fw_id in fw_ids: + data = lp.get_tracker_data(fw_id) output = [] - for d in data: - for t in d["trackers"]: - if (not include or t.filename in include) and (not exclude or t.filename not in exclude): - output.extend((f"## Launch id: {d['launch_id']}", str(t))) + for dct in data: + for tracker in dct["trackers"]: + if (not include or tracker.filename in include) and (not exclude or tracker.filename not in exclude): + output.extend((f"## Launch id: {dct['launch_id']}", str(tracker))) if output: - name = lp.fireworks.find_one({"fw_id": f}, {"name": 1})["name"] - output.insert(0, f"# FW id: {f}, FW name: {name}") + name = lp.fireworks.find_one({"fw_id": fw_id}, {"name": 1})["name"] + output.insert(0, f"# FW id: {fw_id}, FW {name=}") if first_print: first_print = False else: diff --git a/fireworks/tests/mongo_tests.py b/fireworks/tests/mongo_tests.py index d3c0fd950..2abaf4358 100644 --- a/fireworks/tests/mongo_tests.py +++ b/fireworks/tests/mongo_tests.py @@ -546,7 +546,7 @@ def test_append_wf(self) -> None: assert new_fw.spec["dummy2"] == [True] new_wf = Workflow([Firework([ModSpecTask()])]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Cannot append to a FW that is not in the original Workflow"): self.lp.append_wf(new_wf, [4], detour=True) def test_append_wf_detour(self) -> None: diff --git a/fireworks/utilities/filepad.py b/fireworks/utilities/filepad.py index ba62a32ee..e69f765fb 100644 --- a/fireworks/utilities/filepad.py +++ b/fireworks/utilities/filepad.py @@ -135,7 +135,7 @@ def add_file(self, path, identifier=None, compress=True, metadata=None): if identifier is not None: _, doc = self.get_file(identifier) if doc is not None: - self.logger.warning(f"identifier: {identifier} exists. Skipping insertion") + self.logger.warning(f"{identifier=} exists. Skipping insertion") return doc["gfs_id"], doc["identifier"] path = os.path.abspath(path) From f956ba8eac2d5d096ca48e8b081e0ad1ecb2675e Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Tue, 9 Apr 2024 13:57:54 +0200 Subject: [PATCH 3/8] rename TestSerializer to avoid pytest detection PytestCollectionWarning: cannot collect test class 'TestSerializer' because it has a __init__ constructor --- fireworks/user_objects/firetasks/unittest_tasks.py | 6 +++--- fireworks/utilities/tests/test_fw_serializers.py | 10 ++++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/fireworks/user_objects/firetasks/unittest_tasks.py b/fireworks/user_objects/firetasks/unittest_tasks.py index 8c4b3b41f..cab8be6e6 100644 --- a/fireworks/user_objects/firetasks/unittest_tasks.py +++ b/fireworks/user_objects/firetasks/unittest_tasks.py @@ -9,7 +9,7 @@ __date__ = "Jan 21, 2014" -class TestSerializer(FWSerializable): +class UnitTestSerializer(FWSerializable): _fw_name = "TestSerializer Name" def __init__(self, a, m_date) -> None: @@ -28,7 +28,7 @@ def to_dict(self): @classmethod def from_dict(cls, m_dict): - return TestSerializer(m_dict["a"], m_dict["m_date"]) + return cls(m_dict["a"], m_dict["m_date"]) class ExportTestSerializer(FWSerializable): @@ -45,4 +45,4 @@ def to_dict(self): @classmethod def from_dict(cls, m_dict): - return ExportTestSerializer(m_dict["a"]) + return cls(m_dict["a"]) diff --git a/fireworks/utilities/tests/test_fw_serializers.py b/fireworks/utilities/tests/test_fw_serializers.py index 9d447d14e..a317874bb 100644 --- a/fireworks/utilities/tests/test_fw_serializers.py +++ b/fireworks/utilities/tests/test_fw_serializers.py @@ -1,6 +1,6 @@ # from __future__ import unicode_literals -from fireworks.user_objects.firetasks.unittest_tasks import ExportTestSerializer, TestSerializer +from fireworks.user_objects.firetasks.unittest_tasks import ExportTestSerializer, UnitTestSerializer from fireworks.utilities.fw_serializers import FWSerializable, load_object, recursive_dict from fireworks.utilities.fw_utilities import explicit_serialize @@ -38,11 +38,13 @@ class SerializationTest(unittest.TestCase): def setUp(self) -> None: test_date = datetime.datetime.utcnow() # A basic datetime test serialized object - self.obj_1 = TestSerializer("prop1", test_date) - self.obj_1_copy = TestSerializer("prop1", test_date) + self.obj_1 = UnitTestSerializer("prop1", test_date) + self.obj_1_copy = UnitTestSerializer("prop1", test_date) # A nested test serialized object - self.obj_2 = TestSerializer({"p1": 1234, "p2": 5.0, "p3": "Hi!", "p4": datetime.datetime.utcnow()}, test_date) + self.obj_2 = UnitTestSerializer( + {"p1": 1234, "p2": 5.0, "p3": "Hi!", "p4": datetime.datetime.utcnow()}, test_date + ) # A unicode test serialized object unicode_str = "\xe4\xf6\xfc" From 3fcc340be2fc56fc1d3efbc31fe78b29459ff299 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Tue, 9 Apr 2024 14:00:12 +0200 Subject: [PATCH 4/8] remove outdated if __name__ == "__main__": unittest.main() not needed for pytest --- fireworks/core/tests/test_firework.py | 4 ---- fireworks/core/tests/test_launchpad.py | 4 ---- fireworks/core/tests/test_rocket.py | 4 ---- fireworks/core/tests/test_tracker.py | 4 ---- fireworks/scripts/lpad_run.py | 2 +- fireworks/tests/master_tests.py | 4 ---- fireworks/tests/mongo_tests.py | 4 ---- fireworks/tests/multiprocessing_tests.py | 4 ---- fireworks/tests/test_fw_config.py | 4 ---- fireworks/user_objects/firetasks/tests/test_dataflow_tasks.py | 4 ---- fireworks/user_objects/firetasks/tests/test_fileio_tasks.py | 4 ---- fireworks/user_objects/firetasks/tests/test_filepad_tasks.py | 4 ---- fireworks/user_objects/firetasks/tests/test_script_task.py | 4 ---- .../user_objects/firetasks/tests/test_templatewriter_task.py | 4 ---- .../user_objects/queue_adapters/tests/test_common_adapter.py | 4 ---- fireworks/utilities/tests/test_dagflow.py | 4 ---- fireworks/utilities/tests/test_filepad.py | 4 ---- fireworks/utilities/tests/test_fw_serializers.py | 4 ---- fireworks/utilities/tests/test_update_collection.py | 4 ---- 19 files changed, 1 insertion(+), 73 deletions(-) diff --git a/fireworks/core/tests/test_firework.py b/fireworks/core/tests/test_firework.py index 7d77ef0b1..d9b4a0702 100644 --- a/fireworks/core/tests/test_firework.py +++ b/fireworks/core/tests/test_firework.py @@ -165,7 +165,3 @@ def test_iter_len_index(self) -> None: assert len(wflow) == len(fws) assert wflow[0] == self.fw1 - - -if __name__ == "__main__": - unittest.main() diff --git a/fireworks/core/tests/test_launchpad.py b/fireworks/core/tests/test_launchpad.py index 31026f8e2..7a3a6b51f 100644 --- a/fireworks/core/tests/test_launchpad.py +++ b/fireworks/core/tests/test_launchpad.py @@ -1390,7 +1390,3 @@ def test_many_detours_offline(self) -> None: launch_full = self.lp.get_launch_by_id(1) assert len(launch_full.action.detours) == 2000 - - -if __name__ == "__main__": - unittest.main() diff --git a/fireworks/core/tests/test_rocket.py b/fireworks/core/tests/test_rocket.py index 5dac1fa09..42de3f11e 100644 --- a/fireworks/core/tests/test_rocket.py +++ b/fireworks/core/tests/test_rocket.py @@ -52,7 +52,3 @@ def test_postproc_exception(self) -> None: fw = self.lp.get_fw_by_id(1) assert fw.state == "FIZZLED" - - -if __name__ == "__main__": - unittest.main() diff --git a/fireworks/core/tests/test_tracker.py b/fireworks/core/tests/test_tracker.py index 3914310ac..cae9a778b 100644 --- a/fireworks/core/tests/test_tracker.py +++ b/fireworks/core/tests/test_tracker.py @@ -143,7 +143,3 @@ def add_wf(j, dest, tracker, name) -> None: pwd = os.getcwd() for ldir in glob.glob(os.path.join(pwd, "launcher_*")): shutil.rmtree(ldir) - - -if __name__ == "__main__": - unittest.main() diff --git a/fireworks/scripts/lpad_run.py b/fireworks/scripts/lpad_run.py index 1daa7d190..5e21cc452 100644 --- a/fireworks/scripts/lpad_run.py +++ b/fireworks/scripts/lpad_run.py @@ -771,7 +771,7 @@ def forget_offline(args: Namespace) -> None: for f in fw_ids: lp.forget_offline(f, launch_mode=False) lp.m_logger.debug(f"Processed fw_id: {f}") - lp.m_logger.info(f"Finished forget_offine, processed {len(fw_ids)} FWs") + lp.m_logger.info(f"Finished forget_offline, processed {len(fw_ids)} FWs") def report(args: Namespace) -> None: diff --git a/fireworks/tests/master_tests.py b/fireworks/tests/master_tests.py index e1dbbc3ba..a92954362 100644 --- a/fireworks/tests/master_tests.py +++ b/fireworks/tests/master_tests.py @@ -97,7 +97,3 @@ def test_recursive_deserialize(self) -> None: "defuse_children": False, } FWAction.from_dict(my_dict) - - -if __name__ == "__main__": - unittest.main() diff --git a/fireworks/tests/mongo_tests.py b/fireworks/tests/mongo_tests.py index 2abaf4358..7546d8d30 100644 --- a/fireworks/tests/mongo_tests.py +++ b/fireworks/tests/mongo_tests.py @@ -608,7 +608,3 @@ def test_stats(self) -> None: launch_rocket(self.lp, self.fworker) workflow_results = s.get_workflow_summary(time_field="updated_on") assert (workflow_results[0]["_id"], workflow_results[0]["count"]) == ("COMPLETED", 3) - - -if __name__ == "__main__": - unittest.main() diff --git a/fireworks/tests/multiprocessing_tests.py b/fireworks/tests/multiprocessing_tests.py index e6a0a1e8d..c37f091c6 100644 --- a/fireworks/tests/multiprocessing_tests.py +++ b/fireworks/tests/multiprocessing_tests.py @@ -118,7 +118,3 @@ def test_early_exit(self) -> None: with open(os.path.join(fw3.launches[0].launch_dir, "task.out")) as f: fw3_text = f.read() assert fw2_text != fw3_text - - -if __name__ == "__main__": - unittest.main() diff --git a/fireworks/tests/test_fw_config.py b/fireworks/tests/test_fw_config.py index ab8a0ef46..6602a6584 100644 --- a/fireworks/tests/test_fw_config.py +++ b/fireworks/tests/test_fw_config.py @@ -13,7 +13,3 @@ class ConfigTest(unittest.TestCase): def test_config(self) -> None: d = config_to_dict() assert "NEGATIVE_FWID_CTR" not in d - - -if __name__ == "__main__": - unittest.main() diff --git a/fireworks/user_objects/firetasks/tests/test_dataflow_tasks.py b/fireworks/user_objects/firetasks/tests/test_dataflow_tasks.py index 098213fa1..7f737a570 100644 --- a/fireworks/user_objects/firetasks/tests/test_dataflow_tasks.py +++ b/fireworks/user_objects/firetasks/tests/test_dataflow_tasks.py @@ -310,7 +310,3 @@ def test_import_data_task(self) -> None: assert "value" in root["temperature"] assert root["temperature"]["units"] == temperature["units"] os.remove(filename) - - -if __name__ == "__main__": - unittest.main() diff --git a/fireworks/user_objects/firetasks/tests/test_fileio_tasks.py b/fireworks/user_objects/firetasks/tests/test_fileio_tasks.py index 81f34c0e2..eff1c1e98 100644 --- a/fireworks/user_objects/firetasks/tests/test_fileio_tasks.py +++ b/fireworks/user_objects/firetasks/tests/test_fileio_tasks.py @@ -63,7 +63,3 @@ def test_archive_dir(self) -> None: def tearDown(self) -> None: os.chdir(self.cwd) - - -if __name__ == "__main__": - unittest.main() diff --git a/fireworks/user_objects/firetasks/tests/test_filepad_tasks.py b/fireworks/user_objects/firetasks/tests/test_filepad_tasks.py index 7b807a681..831246fef 100644 --- a/fireworks/user_objects/firetasks/tests/test_filepad_tasks.py +++ b/fireworks/user_objects/firetasks/tests/test_filepad_tasks.py @@ -256,7 +256,3 @@ def test_addfilesfrompatterntask_run(self) -> None: def tearDown(self) -> None: self.fp.reset() - - -if __name__ == "__main__": - unittest.main() diff --git a/fireworks/user_objects/firetasks/tests/test_script_task.py b/fireworks/user_objects/firetasks/tests/test_script_task.py index f5d82c2b8..e176df0b7 100644 --- a/fireworks/user_objects/firetasks/tests/test_script_task.py +++ b/fireworks/user_objects/firetasks/tests/test_script_task.py @@ -92,7 +92,3 @@ def test_task_data_flow(self) -> None: action = PyTask(**params).run_task(spec) assert action.update_spec["first"] == 1 assert action.update_spec["second"] == 2 - - -if __name__ == "__main__": - unittest.main() diff --git a/fireworks/user_objects/firetasks/tests/test_templatewriter_task.py b/fireworks/user_objects/firetasks/tests/test_templatewriter_task.py index 6f2b40f20..4adbf0ba8 100644 --- a/fireworks/user_objects/firetasks/tests/test_templatewriter_task.py +++ b/fireworks/user_objects/firetasks/tests/test_templatewriter_task.py @@ -35,7 +35,3 @@ def test_task(self) -> None: os.remove("out_template.txt") if os.path.exists("test_template.txt"): os.remove("test_template.txt") - - -if __name__ == "__main__": - unittest.main() diff --git a/fireworks/user_objects/queue_adapters/tests/test_common_adapter.py b/fireworks/user_objects/queue_adapters/tests/test_common_adapter.py index 6f89c0cb5..5c4468483 100644 --- a/fireworks/user_objects/queue_adapters/tests/test_common_adapter.py +++ b/fireworks/user_objects/queue_adapters/tests/test_common_adapter.py @@ -119,7 +119,3 @@ def test_override(self) -> None: assert p._get_status_cmd("my_name") == ["my_qstatus", "-u", "my_name"] assert p.q_commands["PBS"]["submit_cmd"] == "my_qsubmit" - - -if __name__ == "__main__": - unittest.main() diff --git a/fireworks/utilities/tests/test_dagflow.py b/fireworks/utilities/tests/test_dagflow.py index a2b94d551..8509db812 100644 --- a/fireworks/utilities/tests/test_dagflow.py +++ b/fireworks/utilities/tests/test_dagflow.py @@ -175,7 +175,3 @@ def test_dagflow_view(self) -> None: dagf.to_dot(filename, view="controlflow") assert os.path.exists(filename) os.remove(filename) - - -if __name__ == "__main__": - unittest.main() diff --git a/fireworks/utilities/tests/test_filepad.py b/fireworks/utilities/tests/test_filepad.py index f2d628f1e..60308981d 100644 --- a/fireworks/utilities/tests/test_filepad.py +++ b/fireworks/utilities/tests/test_filepad.py @@ -57,7 +57,3 @@ def test_update_file_by_id(self) -> None: def tearDown(self) -> None: self.fp.reset() - - -if __name__ == "__main__": - unittest.main() diff --git a/fireworks/utilities/tests/test_fw_serializers.py b/fireworks/utilities/tests/test_fw_serializers.py index a317874bb..828181950 100644 --- a/fireworks/utilities/tests/test_fw_serializers.py +++ b/fireworks/utilities/tests/test_fw_serializers.py @@ -143,7 +143,3 @@ def setUp(self) -> None: def test_explicit_serialization(self) -> None: assert load_object(self.s_dict) == self.s_obj - - -if __name__ == "__main__": - unittest.main() diff --git a/fireworks/utilities/tests/test_update_collection.py b/fireworks/utilities/tests/test_update_collection.py index 620694bf1..44d95b4bf 100644 --- a/fireworks/utilities/tests/test_update_collection.py +++ b/fireworks/utilities/tests/test_update_collection.py @@ -44,7 +44,3 @@ def test_update_path(self) -> None: assert test_doc["foo_list"][1]["foo2"] == "foo/new/path/bar" test_doc_archived = self.lp.db[f"test_coll_xiv_{datetime.date.today()}"].find_one() assert test_doc_archived["foo_list"][1]["foo2"] == "foo/old/path/bar" - - -if __name__ == "__main__": - unittest.main() From f47c01bb5fa5f75eb4665425a14a9fa16315e118 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Tue, 9 Apr 2024 14:04:44 +0200 Subject: [PATCH 5/8] maybe fix RocketTest.test_postproc_exception ______________________ RocketTest.test_postproc_exception ______________________ self = def test_postproc_exception(self) -> None: fw = Firework(MalformedAdditionTask()) self.lp.add_wf(fw) launch_rocket(self.lp, self.fworker) fw = self.lp.get_fw_by_id(1) > assert fw.state == "FIZZLED" E AssertionError: assert 'COMPLETED' == 'FIZZLED' E E - FIZZLED E + COMPLETED --- fireworks/core/launchpad.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fireworks/core/launchpad.py b/fireworks/core/launchpad.py index 80415ac70..a44a4370c 100644 --- a/fireworks/core/launchpad.py +++ b/fireworks/core/launchpad.py @@ -49,7 +49,7 @@ def sort_aggregation(sort): Args: sort [(str,int)]: sorting keys and directions as a list of - (str, int) tuples, i.e. [('updated_on', 1)] + (str, int) tuples, i.e. [('updated_on', 1)] """ # Fix for sorting by dates which are actually stored as strings: # Not sure about the underlying issue's source, but apparently some @@ -887,7 +887,7 @@ def future_run_exists(self, fworker=None) -> bool: return True # retrieve all [RUNNING/RESERVED] fireworks q = fworker.query if fworker else {} - q.update(state={"$in": ["RUNNING", "RESERVED"]}) + q.update({"state": {"$in": ["RUNNING", "RESERVED"]}}) active = self.get_fw_ids(q) # then check if they have WAITING children for fw_id in active: From 994af3a70f978798e28ae62efd68af05d78b44c9 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Tue, 9 Apr 2024 14:17:37 +0200 Subject: [PATCH 6/8] breaking: Workflow.from_firework method name capitalization --- fireworks/core/firework.py | 4 ++-- fireworks/core/launchpad.py | 2 +- .../examples/custom_firetasks/hello_world/hello_world_run.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/fireworks/core/firework.py b/fireworks/core/firework.py index 93fd18c41..58ee39624 100644 --- a/fireworks/core/firework.py +++ b/fireworks/core/firework.py @@ -1331,10 +1331,10 @@ def from_dict(cls, m_dict: dict[str, Any]) -> Workflow: created_on, updated_on, ) - return Workflow.from_Firework(Firework.from_dict(m_dict)) + return Workflow.from_firework(Firework.from_dict(m_dict)) @classmethod - def from_Firework(cls, fw: Firework, name: str | None = None, metadata=None) -> Workflow: + def from_firework(cls, fw: Firework, name: str | None = None, metadata=None) -> Workflow: """ Return Workflow from the given Firework. diff --git a/fireworks/core/launchpad.py b/fireworks/core/launchpad.py index a44a4370c..4afc4ec01 100644 --- a/fireworks/core/launchpad.py +++ b/fireworks/core/launchpad.py @@ -373,7 +373,7 @@ def add_wf(self, wf, reassign_all=True): dict: mapping between old and new Firework ids """ if isinstance(wf, Firework): - wf = Workflow.from_Firework(wf) + wf = Workflow.from_firework(wf) # sets the root FWs as READY # prefer to wf.refresh() for speed reasons w/many root FWs for fw_id in wf.root_fw_ids: diff --git a/fireworks/examples/custom_firetasks/hello_world/hello_world_run.py b/fireworks/examples/custom_firetasks/hello_world/hello_world_run.py index 8f059d3cf..ae5b7f124 100644 --- a/fireworks/examples/custom_firetasks/hello_world/hello_world_run.py +++ b/fireworks/examples/custom_firetasks/hello_world/hello_world_run.py @@ -9,7 +9,7 @@ # create the workflow and store it in the database my_fw = Firework([HelloTask()]) - my_wflow = Workflow.from_Firework(my_fw) + my_wflow = Workflow.from_firework(my_fw) lp.add_wf(my_wflow) # run the workflow From 22ab2e5390416df1edf20ef5b003c9e54594146c Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Tue, 9 Apr 2024 14:18:08 +0200 Subject: [PATCH 7/8] test.yml move pip install . into deps install step --- .github/workflows/test.yml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1326ed539..89ba28317 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -28,9 +28,7 @@ jobs: - name: Install dependencies run: | pip install -r requirements.txt -r requirements-ci.txt + pip install '.[workflow-checks,graph-plotting,flask-plotting]' - name: Run fireworks tests - shell: bash -l {0} - run: | - pip install .[workflow-checks,graph-plotting,flask-plotting] - pytest fireworks + run: pytest fireworks From 884c3ecd8755fe5ab83a6c8a29794ae592b2ea4f Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Tue, 9 Apr 2024 15:49:58 +0200 Subject: [PATCH 8/8] Revert "simplify code" This reverts commit 249a543190776f9d2f131e94c3ac2080e185404b. --- fireworks/core/firework.py | 18 +++---- fireworks/core/launchpad.py | 47 ++++++++++++------- fireworks/features/stats.py | 10 ++-- fireworks/queue/queue_launcher.py | 2 +- .../queue_adapters/common_adapter.py | 2 +- 5 files changed, 45 insertions(+), 34 deletions(-) diff --git a/fireworks/core/firework.py b/fireworks/core/firework.py index 58ee39624..b043d0102 100644 --- a/fireworks/core/firework.py +++ b/fireworks/core/firework.py @@ -153,9 +153,9 @@ def __init__( not only to direct children, but to all dependent FireWorks down to the Workflow's leaves. """ - mod_spec = mod_spec or [] - additions = additions or [] - detours = detours or [] + mod_spec = mod_spec if mod_spec is not None else [] + additions = additions if additions is not None else [] + detours = detours if detours is not None else [] self.stored_data = stored_data if stored_data else {} self.exit = exit @@ -267,13 +267,13 @@ def __init__( NEGATIVE_FWID_CTR -= 1 self.fw_id = NEGATIVE_FWID_CTR - self.launches = launches or [] - self.archived_launches = archived_launches or [] + self.launches = launches if launches else [] + self.archived_launches = archived_launches if archived_launches else [] self.created_on = created_on or datetime.utcnow() self.updated_on = updated_on or datetime.utcnow() parents = [parents] if isinstance(parents, Firework) else parents - self.parents = parents or [] + self.parents = parents if parents else [] self._state = state @@ -476,9 +476,9 @@ def __init__( self.fworker = fworker or FWorker() self.host = host or get_my_host() self.ip = ip or get_my_ip() - self.trackers = trackers or [] + self.trackers = trackers if trackers else [] self.action = action if action else None - self.state_history = state_history or [] + self.state_history = state_history if state_history else [] self.state = state self.launch_id = launch_id self.fw_id = fw_id @@ -643,7 +643,7 @@ def _update_state_history(self, state) -> None: now_time = datetime.utcnow() new_history_entry = {"state": state, "created_on": now_time} if state != "COMPLETED" and last_checkpoint: - new_history_entry.update(checkpoint=last_checkpoint) + new_history_entry.update({"checkpoint": last_checkpoint}) self.state_history.append(new_history_entry) if state in ["RUNNING", "RESERVED"]: self.touch_history() # add updated_on key diff --git a/fireworks/core/launchpad.py b/fireworks/core/launchpad.py index 4afc4ec01..c43a343dd 100644 --- a/fireworks/core/launchpad.py +++ b/fireworks/core/launchpad.py @@ -185,15 +185,15 @@ def __init__( self.password = password self.authsource = authsource or self.name self.mongoclient_kwargs = mongoclient_kwargs or {} - self.uri_mode = bool(uri_mode) + self.uri_mode = uri_mode # set up logger self.logdir = logdir self.strm_lvl = strm_lvl if strm_lvl else "INFO" self.m_logger = get_fw_logger("launchpad", l_dir=self.logdir, stream_level=self.strm_lvl) - self.user_indices = user_indices or [] - self.wf_user_indices = wf_user_indices or [] + self.user_indices = user_indices if user_indices else [] + self.wf_user_indices = wf_user_indices if wf_user_indices else [] # get connection if uri_mode: @@ -267,20 +267,31 @@ def update_spec(self, fw_ids, spec_document, mongo=False) -> None: ) @classmethod - def from_dict(cls, dct): + def from_dict(cls, d): + port = d.get("port", None) + name = d.get("name", None) + username = d.get("username", None) + password = d.get("password", None) + logdir = d.get("logdir", None) + strm_lvl = d.get("strm_lvl", None) + user_indices = d.get("user_indices", []) + wf_user_indices = d.get("wf_user_indices", []) + authsource = d.get("authsource", None) + uri_mode = d.get("uri_mode", False) + mongoclient_kwargs = d.get("mongoclient_kwargs", None) return LaunchPad( - dct["host"], - port=dct.get("port"), - name=dct.get("name"), - username=dct.get("username"), - password=dct.get("password"), - logdir=dct.get("logdir"), - strm_lvl=dct.get("strm_lvl"), - user_indices=dct.get("user_indices"), - wf_user_indices=dct.get("wf_user_indices"), - authsource=dct.get("authsource"), - uri_mode=dct.get("uri_mode", False), - mongoclient_kwargs=dct.get("mongoclient_kwargs"), + d["host"], + port, + name, + username, + password, + logdir, + strm_lvl, + user_indices, + wf_user_indices, + authsource, + uri_mode, + mongoclient_kwargs, ) @classmethod @@ -1659,7 +1670,7 @@ def rerun_fw(self, fw_id, rerun_duplicates=True, recover_launch=None, recover_mo # Launch recovery if recover_launch is not None: recovery = self.get_recovery(fw_id, recover_launch) - recovery.update(_mode=recover_mode) + recovery.update({"_mode": recover_mode}) set_spec = recursive_dict({"$set": {"spec._recovery": recovery}}) if recover_mode == "prev_dir": prev_dir = self.get_launch_by_id(recovery.get("_launch_id")).launch_dir @@ -1703,7 +1714,7 @@ def get_recovery(self, fw_id, launch_id="last"): m_fw = self.get_fw_by_id(fw_id) launch = m_fw.launches[-1] if launch_id == "last" else self.get_launch_by_id(launch_id) recovery = launch.state_history[-1].get("checkpoint") - recovery.update(_prev_dir=launch.launch_dir, _launch_id=launch.launch_id) + recovery.update({"_prev_dir": launch.launch_dir, "_launch_id": launch.launch_id}) return recovery def _refresh_wf(self, fw_id) -> None: diff --git a/fireworks/features/stats.py b/fireworks/features/stats.py index 402e5ef6b..c381e64d6 100644 --- a/fireworks/features/stats.py +++ b/fireworks/features/stats.py @@ -197,8 +197,8 @@ def group_fizzled_fireworks( "created_on": self._query_datetime_range(start_time=query_start, end_time=query_end, **args), } if include_ids: - project_query.update(fw_id=1) - group_query.update(fw_id={"$push": "$fw_id"}) + project_query.update({"fw_id": 1}) + group_query.update({"fw_id": {"$push": "$fw_id"}}) if query: match_query.update(query) return self._aggregate( @@ -306,11 +306,11 @@ def _get_summary( } match_query.update(query) if runtime_stats: - project_query.update(runtime_secs=1) + project_query.update({"runtime_secs": 1}) group_query.update(RUNTIME_STATS) if include_ids: project_query.update({id_field: 1}) - group_query.update(ids={"$push": "$" + id_field}) + group_query.update({"ids": {"$push": "$" + id_field}}) return self._aggregate( coll=coll, match=match_query, @@ -357,7 +357,7 @@ def _aggregate( for arg in [match, project, unwind, group_op]: if arg is None: arg = {} - group_op.update(_id=f"${group_by}") + group_op.update({"_id": "$" + group_by}) if sort is None: sort_query = ("_id", 1) query = [{"$match": match}, {"$project": project}, {"$group": group_op}, {"$sort": SON([sort_query])}] diff --git a/fireworks/queue/queue_launcher.py b/fireworks/queue/queue_launcher.py index 781ee59c6..b99827c93 100644 --- a/fireworks/queue/queue_launcher.py +++ b/fireworks/queue/queue_launcher.py @@ -96,7 +96,7 @@ def launch_rocket_to_queue( # update qadapter job_name based on FW name job_name = get_slug(fw.name)[0:QUEUE_JOBNAME_MAXLEN] - qadapter.update(job_name=job_name) + qadapter.update({"job_name": job_name}) if "_queueadapter" in fw.spec: l_logger.debug("updating queue params using Firework spec..") diff --git a/fireworks/user_objects/queue_adapters/common_adapter.py b/fireworks/user_objects/queue_adapters/common_adapter.py index 259d6ada7..6e1380b25 100644 --- a/fireworks/user_objects/queue_adapters/common_adapter.py +++ b/fireworks/user_objects/queue_adapters/common_adapter.py @@ -66,7 +66,7 @@ def __init__(self, q_type, q_name=None, template_file=None, timeout=None, **kwar ) self.q_name = q_name or q_type self.timeout = timeout or 5 - self.update(kwargs) + self.update(dict(kwargs)) self.q_commands = copy.deepcopy(CommonAdapter.default_q_commands) if "_q_commands_override" in self: