From 82a7228faac9f2e39bd138a61845ba90d5efb2e3 Mon Sep 17 00:00:00 2001
From: Philipp Gadow
Date: Wed, 8 Nov 2023 09:38:01 -0600
Subject: [PATCH] modify top label in Xbb and add catch for legacy tagger plotting

---
 puma/hlplots/results.py            | 16 +++++++++++++---
 puma/tests/hlplots/test_results.py | 21 +++++++++++++++++++++
 puma/tests/hlplots/test_tagger.py  |  4 ++--
 requirements.txt                   |  2 +-
 4 files changed, 37 insertions(+), 6 deletions(-)

diff --git a/puma/hlplots/results.py b/puma/hlplots/results.py
index fc6a9434..4bfc1450 100644
--- a/puma/hlplots/results.py
+++ b/puma/hlplots/results.py
@@ -48,9 +48,9 @@ def __post_init__(self):
         elif self.signal == Flavours.cjets:
             self.backgrounds = [Flavours.bjets, Flavours.ujets]
         elif self.signal == Flavours.hbb:
-            self.backgrounds = [Flavours.hcc, Flavours.top, Flavours.qcd]
+            self.backgrounds = [Flavours.hcc, Flavours.tqqb, Flavours.qcd]
         elif self.signal == Flavours.hcc:
-            self.backgrounds = [Flavours.hbb, Flavours.top, Flavours.qcd]
+            self.backgrounds = [Flavours.hbb, Flavours.tqqb, Flavours.qcd]
         else:
             raise ValueError(f"Unsupported signal class {self.signal}.")
 
@@ -153,7 +153,17 @@ def check_nan(data: np.ndarray) -> np.ndarray:
 
         # load data
         reader = H5Reader(file_path, precision="full")
-        data = reader.load({key: var_list}, num_jets)[key]
+        try:
+            data = reader.load({key: var_list}, num_jets)[key]
+        except Exception:
+            # legacy Xbb taggers call the fully contained top class "top", not "tqqb"
+            if "tqqb" not in var_list:
+                raise
+            var_list = ["top" if v == "tqqb" else v for v in var_list]
+            self.backgrounds = [
+                Flavours.top if b == Flavours.tqqb else b for b in self.backgrounds
+            ]
+            data = reader.load({key: var_list}, num_jets)[key]
 
         # check for nan values
         data = check_nan(data)
diff --git a/puma/tests/hlplots/test_results.py b/puma/tests/hlplots/test_results.py
index 5da4a1a6..be1a31d8 100644
--- a/puma/tests/hlplots/test_results.py
+++ b/puma/tests/hlplots/test_results.py
@@ -74,6 +74,27 @@ def test_add_taggers_with_cuts(self):
         self.assertEqual(list(results.taggers.values()), taggers)
 
     def test_add_taggers_hbb(self):
+        # get mock file and rename variables to match hbb
+        f = get_mock_file()[1]
+        d = {}
+        d["R10TruthLabel"] = f["jets"]["HadronConeExclTruthLabelID"]
+        d["MockTagger_phbb"] = f["jets"]["MockTagger_pb"]
+        d["MockTagger_phcc"] = f["jets"]["MockTagger_pc"]
+        d["MockTagger_ptqqb"] = f["jets"]["MockTagger_pu"]
+        d["MockTagger_pqcd"] = f["jets"]["MockTagger_pu"]
+        d["pt"] = f["jets"]["pt"]
+        array = structured_from_dict(d)
+        with tempfile.TemporaryDirectory() as tmp_file:
+            fname = Path(tmp_file) / "test.h5"
+            with h5py.File(fname, "w") as f:
+                f.create_dataset("jets", data=array)
+
+            results = Results(signal="hbb", sample="test")
+            results.add_taggers_from_file(
+                [Tagger("MockTagger")], fname, label_var="R10TruthLabel"
+            )
+
+    def test_add_taggers_hbb_legacy(self):
         # get mock file and rename variables match hbb
         f = get_mock_file()[1]
         d = {}
diff --git a/puma/tests/hlplots/test_tagger.py b/puma/tests/hlplots/test_tagger.py
index 08abcbe5..b15eaa1b 100644
--- a/puma/tests/hlplots/test_tagger.py
+++ b/puma/tests/hlplots/test_tagger.py
@@ -175,14 +175,14 @@ def test_disc_hbb_calc(self):
         from ftag import Flavours as F
 
         tagger = Tagger(
-            "dummy", output_flavours=[F["hbb"], F["hcc"], F["top"], F["qcd"]]
+            "dummy", output_flavours=[F["hbb"], F["hcc"], F["tqqb"], F["qcd"]]
         )
         tagger.scores = u2s(
             np.column_stack((np.ones(10), np.ones(10), np.ones(10), np.ones(10))),
             dtype=[
                 ("dummy_phbb", "f4"),
                 ("dummy_phcc", "f4"),
-                ("dummy_ptop", "f4"),
+                ("dummy_ptqqb", "f4"),
                 ("dummy_pqcd", "f4"),
             ],
         )
diff --git a/requirements.txt b/requirements.txt
index 10238dd6..d9efb729 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,4 +16,4 @@ scipy==1.10.1
 tables==3.7.0
 testfixtures==7.0.0
 palettable==3.3.0
-atlas-ftag-tools==0.1.11
+atlas-ftag-tools==0.1.12