Skip to content

Commit

Permalink
Merge pull request materialsproject#527 from materialsproject/simplify
Browse files Browse the repository at this point in the history
Simplify code
  • Loading branch information
janosh authored Apr 9, 2024
2 parents 5fd3884 + 884c3ec commit ab775d9
Show file tree
Hide file tree
Showing 26 changed files with 60 additions and 133 deletions.
6 changes: 2 additions & 4 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ jobs:
- name: Install dependencies
run: |
pip install -r requirements.txt -r requirements-ci.txt
pip install '.[workflow-checks,graph-plotting,flask-plotting]'
- name: Run fireworks tests
shell: bash -l {0}
run: |
pip install .[workflow-checks,graph-plotting,flask-plotting]
pytest fireworks
run: pytest fireworks
8 changes: 4 additions & 4 deletions fireworks/core/firework.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def __init__(
fw_id (int): id of the Firework this Launch is running.
"""
if state not in Firework.STATE_RANKS:
raise ValueError(f"Invalid launch state: {state}")
raise ValueError(f"Invalid launch {state=}")
self.launch_dir = launch_dir
self.fworker = fworker or FWorker()
self.host = host or get_my_host()
Expand Down Expand Up @@ -1029,7 +1029,7 @@ def append_wf(self, new_wf, fw_ids, detour=False, pull_spec_mods=False):
ready_run = [(f >= 0 and Firework.STATE_RANKS[self.fw_states[f]] > 1) for f in self.links[fw_id]]
if any(ready_run):
raise ValueError(
f"fw_id: {fw_id}: Detour option only works if all children "
f"{fw_id=}: Detour option only works if all children "
"of detours are not READY to run and have not already run"
)

Expand Down Expand Up @@ -1331,10 +1331,10 @@ def from_dict(cls, m_dict: dict[str, Any]) -> Workflow:
created_on,
updated_on,
)
return Workflow.from_Firework(Firework.from_dict(m_dict))
return Workflow.from_firework(Firework.from_dict(m_dict))

@classmethod
def from_Firework(cls, fw: Firework, name: str | None = None, metadata=None) -> Workflow:
def from_firework(cls, fw: Firework, name: str | None = None, metadata=None) -> Workflow:
"""
Return Workflow from the given Firework.
Expand Down
18 changes: 9 additions & 9 deletions fireworks/core/launchpad.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def sort_aggregation(sort):
Args:
sort [(str,int)]: sorting keys and directions as a list of
(str, int) tuples, i.e. [('updated_on', 1)]
(str, int) tuples, i.e. [('updated_on', 1)]
"""
# Fix for sorting by dates which are actually stored as strings:
# Not sure about the underlying issue's source, but apparently some
Expand Down Expand Up @@ -384,7 +384,7 @@ def add_wf(self, wf, reassign_all=True):
dict: mapping between old and new Firework ids
"""
if isinstance(wf, Firework):
wf = Workflow.from_Firework(wf)
wf = Workflow.from_firework(wf)
# sets the root FWs as READY
# prefer to wf.refresh() for speed reasons w/many root FWs
for fw_id in wf.root_fw_ids:
Expand Down Expand Up @@ -471,7 +471,7 @@ def get_launch_by_id(self, launch_id):
if m_launch:
m_launch["action"] = get_action_from_gridfs(m_launch.get("action"), self.gridfs_fallback)
return Launch.from_dict(m_launch)
raise ValueError(f"No Launch exists with launch_id: {launch_id}")
raise ValueError(f"No Launch exists with {launch_id=}")

def get_fw_dict_by_id(self, fw_id):
"""
Expand Down Expand Up @@ -524,7 +524,7 @@ def get_wf_by_fw_id(self, fw_id):
"""
links_dict = self.workflows.find_one({"nodes": fw_id})
if not links_dict:
raise ValueError(f"Could not find a Workflow with fw_id: {fw_id}")
raise ValueError(f"Could not find a Workflow with {fw_id=}")
fws = map(self.get_fw_by_id, links_dict["nodes"])
return Workflow(
fws,
Expand All @@ -546,7 +546,7 @@ def get_wf_by_fw_id_lzyfw(self, fw_id: int) -> Workflow:
"""
links_dict = self.workflows.find_one({"nodes": fw_id})
if not links_dict:
raise ValueError(f"Could not find a Workflow with fw_id: {fw_id}")
raise ValueError(f"Could not find a Workflow with {fw_id=}")

fws = [
LazyFirework(fw_id, self.fireworks, self.launches, self.gridfs_fallback) for fw_id in links_dict["nodes"]
Expand Down Expand Up @@ -972,7 +972,7 @@ def pause_fw(self, fw_id):
if f:
self._refresh_wf(fw_id)
if not f:
self.m_logger.error(f"No pausable (WAITING,READY,RESERVED) Firework exists with fw_id: {fw_id}")
self.m_logger.error(f"No pausable (WAITING,READY,RESERVED) Firework exists with {fw_id=}")
return f

def defuse_fw(self, fw_id, rerun_duplicates=True):
Expand Down Expand Up @@ -1439,7 +1439,7 @@ def checkout_fw(self, fworker, launch_dir, fw_id=None, host=None, ip=None, state
# insert the launch
self.launches.find_one_and_replace({"launch_id": m_launch.launch_id}, m_launch.to_db_dict(), upsert=True)

self.m_logger.debug(f"Created/updated Launch with launch_id: {launch_id}")
self.m_logger.debug(f"Created/updated Launch with {launch_id=}")

# update the firework's launches
if not reserved_launch:
Expand Down Expand Up @@ -1684,9 +1684,9 @@ def rerun_fw(self, fw_id, rerun_duplicates=True, recover_launch=None, recover_mo

# rerun this FW
if m_fw["state"] in ["ARCHIVED", "DEFUSED"]:
self.m_logger.info(f"Cannot rerun fw_id: {fw_id}: it is {m_fw['state']}.")
self.m_logger.info(f"Cannot rerun {fw_id=}: it is {m_fw['state']}.")
elif m_fw["state"] == "WAITING" and not recover_launch:
self.m_logger.debug(f"Skipping rerun fw_id: {fw_id}: it is already WAITING.")
self.m_logger.debug(f"Skipping rerun {fw_id=}: it is already WAITING.")
else:
with WFLock(self, fw_id):
wf = self.get_wf_by_fw_id_lzyfw(fw_id)
Expand Down
4 changes: 0 additions & 4 deletions fireworks/core/tests/test_firework.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,3 @@ def test_iter_len_index(self) -> None:
assert len(wflow) == len(fws)

assert wflow[0] == self.fw1


if __name__ == "__main__":
unittest.main()
4 changes: 0 additions & 4 deletions fireworks/core/tests/test_launchpad.py
Original file line number Diff line number Diff line change
Expand Up @@ -1390,7 +1390,3 @@ def test_many_detours_offline(self) -> None:

launch_full = self.lp.get_launch_by_id(1)
assert len(launch_full.action.detours) == 2000


if __name__ == "__main__":
unittest.main()
4 changes: 0 additions & 4 deletions fireworks/core/tests/test_rocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,3 @@ def test_postproc_exception(self) -> None:
fw = self.lp.get_fw_by_id(1)

assert fw.state == "FIZZLED"


if __name__ == "__main__":
unittest.main()
4 changes: 0 additions & 4 deletions fireworks/core/tests/test_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,3 @@ def add_wf(j, dest, tracker, name) -> None:
pwd = os.getcwd()
for ldir in glob.glob(os.path.join(pwd, "launcher_*")):
shutil.rmtree(ldir)


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

# create the workflow and store it in the database
my_fw = Firework([HelloTask()])
my_wflow = Workflow.from_Firework(my_fw)
my_wflow = Workflow.from_firework(my_fw)
lp.add_wf(my_wflow)

# run the workflow
Expand Down
37 changes: 18 additions & 19 deletions fireworks/flask_site/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,23 +114,22 @@ def home():

# Newest Workflows table data
wfs_shown = app.lp.workflows.find(_addq_WF({}), limit=PER_PAGE, sort=[("_id", DESCENDING)])
wf_info = []
for item in wfs_shown:
wf_info.append(
{
"id": item["nodes"][0],
"name": item["name"],
"state": item["state"],
"fireworks": list(
app.lp.fireworks.find(
{"fw_id": {"$in": item["nodes"]}},
limit=PER_PAGE,
sort=[("fw_id", DESCENDING)],
projection=["state", "name", "fw_id"],
)
),
}
)
wf_info = [
{
"id": item["nodes"][0],
"name": item["name"],
"state": item["state"],
"fireworks": list(
app.lp.fireworks.find(
{"fw_id": {"$in": item["nodes"]}},
limit=PER_PAGE,
sort=[("fw_id", DESCENDING)],
projection=["state", "name", "fw_id"],
)
),
}
for item in wfs_shown
]

PLOTTING = False
try:
Expand All @@ -144,7 +143,7 @@ def home():
@app.route("/fw/<int:fw_id>/details")
@requires_auth
def get_fw_details(fw_id):
# just fill out whatever attributse you want to see per step, then edit the handlebars template in
# just fill out whatever attributes you want to see per step, then edit the handlebars template in
# wf_details.html
# to control their display
fw = app.lp.get_fw_dict_by_id(fw_id)
Expand All @@ -162,7 +161,7 @@ def fw_details(fw_id):
try:
int(fw_id)
except Exception:
raise ValueError(f"Invalid fw_id: {fw_id}")
raise ValueError(f"Invalid {fw_id=}")
fw = app.lp.get_fw_dict_by_id(fw_id)
fw = json.loads(json.dumps(fw, default=DATETIME_HANDLER)) # formats ObjectIds
return render_template("fw_details.html", **locals())
Expand Down
30 changes: 15 additions & 15 deletions fireworks/scripts/lpad_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def init_yaml(args: Namespace) -> None:
print("Please supply the following configuration values")
print("(press Enter if you want to accept the defaults)\n")
for k, default, helptext in fields:
val = input(f"Enter {k} parameter. (default: {default}). {helptext}: ")
val = input(f"Enter {k} parameter. ({default=}). {helptext}: ")
doc[k] = val or default
if "port" in doc:
doc["port"] = int(doc["port"]) # enforce the port as an int
Expand Down Expand Up @@ -627,7 +627,7 @@ def rerun_fws(args: Namespace) -> None:
launch_ids = [None] * len(fw_ids)
for fw_id, l_id in zip(fw_ids, launch_ids):
lp.rerun_fw(int(fw_id), recover_launch=l_id, recover_mode=args.recover_mode)
lp.m_logger.debug(f"Processed fw_id: {fw_id}")
lp.m_logger.debug(f"Processed {fw_id=}")
lp.m_logger.info(f"Finished setting {len(fw_ids)} FWs to rerun")


Expand All @@ -645,10 +645,10 @@ def refresh(args: Namespace) -> None:
def unlock(args: Namespace) -> None:
lp = get_lp(args)
fw_ids = parse_helper(lp, args, wf_mode=True)
for f in fw_ids:
with WFLock(lp, f, expire_secs=0, kill=True):
lp.m_logger.warning(f"FORCIBLY RELEASING LOCK DUE TO USER COMMAND, WF: {f}")
lp.m_logger.debug(f"Processed Workflow with fw_id: {f}")
for fw_id in fw_ids:
with WFLock(lp, fw_id, expire_secs=0, kill=True):
lp.m_logger.warning(f"FORCIBLY RELEASING LOCK DUE TO USER COMMAND, WF: {fw_id}")
lp.m_logger.debug(f"Processed Workflow with {fw_id=}")
lp.m_logger.info(f"Finished unlocking {len(fw_ids)} Workflows")


Expand Down Expand Up @@ -771,7 +771,7 @@ def forget_offline(args: Namespace) -> None:
for f in fw_ids:
lp.forget_offline(f, launch_mode=False)
lp.m_logger.debug(f"Processed fw_id: {f}")
lp.m_logger.info(f"Finished forget_offine, processed {len(fw_ids)} FWs")
lp.m_logger.info(f"Finished forget_offline, processed {len(fw_ids)} FWs")


def report(args: Namespace) -> None:
Expand Down Expand Up @@ -813,16 +813,16 @@ def track_fws(args: Namespace) -> None:
include = args.include
exclude = args.exclude
first_print = True # used to control newline
for f in fw_ids:
data = lp.get_tracker_data(f)
for fw_id in fw_ids:
data = lp.get_tracker_data(fw_id)
output = []
for d in data:
for t in d["trackers"]:
if (not include or t.filename in include) and (not exclude or t.filename not in exclude):
output.extend((f"## Launch id: {d['launch_id']}", str(t)))
for dct in data:
for tracker in dct["trackers"]:
if (not include or tracker.filename in include) and (not exclude or tracker.filename not in exclude):
output.extend((f"## Launch id: {dct['launch_id']}", str(tracker)))
if output:
name = lp.fireworks.find_one({"fw_id": f}, {"name": 1})["name"]
output.insert(0, f"# FW id: {f}, FW name: {name}")
name = lp.fireworks.find_one({"fw_id": fw_id}, {"name": 1})["name"]
output.insert(0, f"# FW id: {fw_id}, FW {name=}")
if first_print:
first_print = False
else:
Expand Down
4 changes: 0 additions & 4 deletions fireworks/tests/master_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,3 @@ def test_recursive_deserialize(self) -> None:
"defuse_children": False,
}
FWAction.from_dict(my_dict)


if __name__ == "__main__":
unittest.main()
6 changes: 1 addition & 5 deletions fireworks/tests/mongo_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def test_append_wf(self) -> None:
assert new_fw.spec["dummy2"] == [True]

new_wf = Workflow([Firework([ModSpecTask()])])
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="Cannot append to a FW that is not in the original Workflow"):
self.lp.append_wf(new_wf, [4], detour=True)

def test_append_wf_detour(self) -> None:
Expand Down Expand Up @@ -608,7 +608,3 @@ def test_stats(self) -> None:
launch_rocket(self.lp, self.fworker)
workflow_results = s.get_workflow_summary(time_field="updated_on")
assert (workflow_results[0]["_id"], workflow_results[0]["count"]) == ("COMPLETED", 3)


if __name__ == "__main__":
unittest.main()
4 changes: 0 additions & 4 deletions fireworks/tests/multiprocessing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,3 @@ def test_early_exit(self) -> None:
with open(os.path.join(fw3.launches[0].launch_dir, "task.out")) as f:
fw3_text = f.read()
assert fw2_text != fw3_text


if __name__ == "__main__":
unittest.main()
4 changes: 0 additions & 4 deletions fireworks/tests/test_fw_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,3 @@ class ConfigTest(unittest.TestCase):
def test_config(self) -> None:
d = config_to_dict()
assert "NEGATIVE_FWID_CTR" not in d


if __name__ == "__main__":
unittest.main()
4 changes: 0 additions & 4 deletions fireworks/user_objects/firetasks/tests/test_dataflow_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,3 @@ def test_import_data_task(self) -> None:
assert "value" in root["temperature"]
assert root["temperature"]["units"] == temperature["units"]
os.remove(filename)


if __name__ == "__main__":
unittest.main()
4 changes: 0 additions & 4 deletions fireworks/user_objects/firetasks/tests/test_fileio_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,3 @@ def test_archive_dir(self) -> None:

def tearDown(self) -> None:
os.chdir(self.cwd)


if __name__ == "__main__":
unittest.main()
4 changes: 0 additions & 4 deletions fireworks/user_objects/firetasks/tests/test_filepad_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,3 @@ def test_addfilesfrompatterntask_run(self) -> None:

def tearDown(self) -> None:
self.fp.reset()


if __name__ == "__main__":
unittest.main()
4 changes: 0 additions & 4 deletions fireworks/user_objects/firetasks/tests/test_script_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,3 @@ def test_task_data_flow(self) -> None:
action = PyTask(**params).run_task(spec)
assert action.update_spec["first"] == 1
assert action.update_spec["second"] == 2


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,3 @@ def test_task(self) -> None:
os.remove("out_template.txt")
if os.path.exists("test_template.txt"):
os.remove("test_template.txt")


if __name__ == "__main__":
unittest.main()
6 changes: 3 additions & 3 deletions fireworks/user_objects/firetasks/unittest_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
__date__ = "Jan 21, 2014"


class TestSerializer(FWSerializable):
class UnitTestSerializer(FWSerializable):
_fw_name = "TestSerializer Name"

def __init__(self, a, m_date) -> None:
Expand All @@ -28,7 +28,7 @@ def to_dict(self):

@classmethod
def from_dict(cls, m_dict):
return TestSerializer(m_dict["a"], m_dict["m_date"])
return cls(m_dict["a"], m_dict["m_date"])


class ExportTestSerializer(FWSerializable):
Expand All @@ -45,4 +45,4 @@ def to_dict(self):

@classmethod
def from_dict(cls, m_dict):
return ExportTestSerializer(m_dict["a"])
return cls(m_dict["a"])
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,3 @@ def test_override(self) -> None:

assert p._get_status_cmd("my_name") == ["my_qstatus", "-u", "my_name"]
assert p.q_commands["PBS"]["submit_cmd"] == "my_qsubmit"


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit ab775d9

Please sign in to comment.