From 20fec4676e01934d9b6c58c39899c85db782fa71 Mon Sep 17 00:00:00 2001
From: Tomer Ashauch
Date: Thu, 3 Aug 2023 21:22:17 +0300
Subject: [PATCH] fixing key error bug when using gpt-j-6B with memit

---
 rome/repr_tools.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/rome/repr_tools.py b/rome/repr_tools.py
index 1dc22c8..8ee38fa 100644
--- a/rome/repr_tools.py
+++ b/rome/repr_tools.py
@@ -128,6 +128,8 @@ def _batch(n):
 
     def _process(cur_repr, batch_idxs, key):
         nonlocal to_return
+        if len(cur_repr)==0:
+            return
         cur_repr = cur_repr[0] if type(cur_repr) is tuple else cur_repr
         for i, idx_list in enumerate(batch_idxs):
             to_return[key].append(cur_repr[i][idx_list].mean(0))
@@ -154,6 +156,6 @@ def _process(cur_repr, batch_idxs, key):
     to_return = {k: torch.stack(v, 0) for k, v in to_return.items() if len(v) > 0}
 
     if len(to_return) == 1:
-        return to_return["in"] if tin else to_return["out"]
-    else:
-        return to_return["in"], to_return["out"]
+        single_key = list(to_return.keys())[0]
+        dummy_tensor = torch.zeros_like(to_return[single_key], device="cuda")
+        return (to_return["in"], dummy_tensor) if (single_key=="in" and tin) else (dummy_tensor, to_return["out"])
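
For context, a minimal standalone sketch of the return logic after this patch (not part of the patch itself): when only one of the "in"/"out" activation lists was populated, the missing side is padded with a zeros tensor so callers that unpack two values no longer hit a KeyError. The helper name _finalize and the example shapes are hypothetical; the real code lives inside the traced-representation routine in rome/repr_tools.py, and the patch allocates the dummy on "cuda" rather than CPU.

    import torch

    def _finalize(to_return, tin):
        # Mirrors the patched return path: stack what was collected and, when
        # only one of "in"/"out" is present, pad the missing side with zeros
        # so callers that unpack two values do not raise a KeyError.
        to_return = {k: torch.stack(v, 0) for k, v in to_return.items() if len(v) > 0}
        if len(to_return) == 1:
            single_key = list(to_return.keys())[0]
            # Patch uses device="cuda"; CPU here so the sketch runs anywhere.
            dummy = torch.zeros_like(to_return[single_key])
            return (to_return["in"], dummy) if (single_key == "in" and tin) else (dummy, to_return["out"])
        return to_return["in"], to_return["out"]

    # Hypothetical usage: only "out" activations were captured (the empty-input
    # guard added in the first hunk suggests this can happen for some modules),
    # so the "in" side is padded with zeros of the same shape.
    collected = {"in": [], "out": [torch.randn(4096), torch.randn(4096)]}
    ins, outs = _finalize(collected, tin=True)
    print(ins.shape, outs.shape)  # torch.Size([2, 4096]) torch.Size([2, 4096])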