From 35c140c7e8a168482e8d120321115bd28245645a Mon Sep 17 00:00:00 2001 From: John Zhou <60048652+johnlyzhou@users.noreply.github.com> Date: Sun, 9 May 2021 01:12:07 -0400 Subject: [PATCH] remove hardcoded n_epochs remove the hardcoded conditional requiring n_epochs==200 --- behavenet/plotting/cond_ae_utils.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/behavenet/plotting/cond_ae_utils.py b/behavenet/plotting/cond_ae_utils.py index 8befa29..bd2e466 100644 --- a/behavenet/plotting/cond_ae_utils.py +++ b/behavenet/plotting/cond_ae_utils.py @@ -1645,8 +1645,7 @@ def despine(ax): ax_pixel_mse_bg = fig.add_subplot(gs[0, 6:9]) data_queried = metrics_df_frame_bg[ (metrics_df_frame_bg.dtype == 'test') & - (metrics_df_frame_bg.loss == 'loss_data_mse') & - (metrics_df_frame_bg.epoch == 200)] + (metrics_df_frame_bg.loss == 'loss_data_mse')] sns.barplot(x='beta', y='val', hue='gamma', data=data_queried, ax=ax_pixel_mse_bg) ax_pixel_mse_bg.legend().set_visible(False) ax_pixel_mse_bg.set_xlabel('Beta') @@ -1673,8 +1672,7 @@ def despine(ax): ax_icmi = fig.add_subplot(gs[1, 0:4]) data_queried = metrics_df_frame_bg[ (metrics_df_frame_bg.dtype == 'test') & - (metrics_df_frame_bg.loss == 'loss_zu_mi') & - (metrics_df_frame_bg.epoch == 200)] + (metrics_df_frame_bg.loss == 'loss_zu_mi')] sns.lineplot( x='beta', y='val', hue='gamma', data=data_queried, ax=ax_icmi, ci=None, palette=gamma_palette) @@ -1690,8 +1688,7 @@ def despine(ax): ax_tc = fig.add_subplot(gs[1, 4:8]) data_queried = metrics_df_frame_bg[ (metrics_df_frame_bg.dtype == 'test') & - (metrics_df_frame_bg.loss == 'loss_zu_tc') & - (metrics_df_frame_bg.epoch == 200)] + (metrics_df_frame_bg.loss == 'loss_zu_tc')] sns.lineplot( x='beta', y='val', hue='gamma', data=data_queried, ax=ax_tc, ci=None, palette=gamma_palette) @@ -1707,8 +1704,7 @@ def despine(ax): ax_dwkl = fig.add_subplot(gs[1, 8:12]) data_queried = metrics_df_frame_bg[ (metrics_df_frame_bg.dtype == 'test') & - (metrics_df_frame_bg.loss == 'loss_zu_dwkl') & - (metrics_df_frame_bg.epoch == 200)] + (metrics_df_frame_bg.loss == 'loss_zu_dwkl')] sns.lineplot( x='beta', y='val', hue='gamma', data=data_queried, ax=ax_dwkl, ci=None, palette=gamma_palette) @@ -1739,7 +1735,6 @@ def despine(ax): data_queried = metrics_df_frame_bg[ (metrics_df_frame_bg.dtype == 'test') & (metrics_df_frame_bg.loss == 'loss_AB_orth') & - (metrics_df_frame_bg.epoch == 200) & ~metrics_df_frame_bg.val.isna()] sns.lineplot( x='gamma', y='val', hue='beta', data=data_queried, ax=ax_orth, ci=None,