Skip to content

Commit

Permalink
use separate save and param dict
Browse files Browse the repository at this point in the history
  • Loading branch information
jykr committed Mar 30, 2024
1 parent 65ad2cb commit 8314104
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions bean/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def main(args):
with open(f"{prefix}/{model_label}.result.pkl", "rb") as handle:
param_history_dict = pkl.load(handle)
else:
param_history_dict = deepcopy(
param_history_dict, save_dict = deepcopy(
run_inference(model, guide, ndata, num_steps=args.n_iter)
)
if args.fit_negctrl:
Expand All @@ -167,7 +167,7 @@ def main(args):
f"Using {len(negctrl_idx)} negative control elements to adjust phenotypic effect sizes..."
)
ndata_negctrl = ndata[negctrl_idx]
param_history_dict["negctrl"] = deepcopy(
param_history_dict["negctrl"], save_dict["negctrl"] = deepcopy(
run_inference(
negctrl_model, negctrl_guide, ndata_negctrl, num_steps=args.n_iter
)
Expand All @@ -181,7 +181,7 @@ def main(args):
os.makedirs(prefix)
with open(f"{prefix}/{model_label}.result{args.result_suffix}.pkl", "wb") as handle:
# try:
pkl.dump(param_history_dict, handle)
pkl.dump(save_dict, handle)
# except TypeError as exc:
# print(exc.message)
# print(param_history_dict)
Expand Down
2 changes: 1 addition & 1 deletion bean/model/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def run_inference(
raise ValueError(
f"Fitting halted for command: {' '.join(sys.argv)} with following error: \n {exc}"
)
return {
return pyro.get_param_store(), {
"loss": losses,
"params": pyro.get_param_store().get_state(),
}
Expand Down

0 comments on commit 8314104

Please sign in to comment.