Skip to content

Commit

Permalink
Franzi Review MMD Notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
franzigrkn committed Mar 8, 2024
1 parent 10e4270 commit 94e4351
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 66 deletions.
2 changes: 1 addition & 1 deletion configs/conf_main_scaling_experiment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ augmentation: ['gauss', 'one_dim_shift', 'one_dim_shift',]
n: [10000,10000,10000] #samples
d: [2, 10, 1000] # dimensions

# MMD bandwith parameter
mmd_bandwidth: [[1,1,1], [5,5,5], [10,10,10]]

# sample size experiments
Expand All @@ -23,7 +24,6 @@ runs_dim: 5 # number of sample selection for errorbars
# seed for reproducibility
seed: 0


# for the reduced sample size experiments
#sample_size: [8, 10, 20, 50, 80]
#n: [500, 500, 500]
4 changes: 2 additions & 2 deletions docs/notebooks/comparison/main_scaling_experiment.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Perform main sample size and dimensionality scaling experiment"
"# Perform main sample size and dimensionality scaling experiment"
]
},
{
Expand All @@ -21,7 +21,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load the configs and set up the plotting "
"## Load the configs and set up the plotting "
]
},
{
Expand Down
108 changes: 53 additions & 55 deletions docs/notebooks/mmd/MMD_MainFig.ipynb

Large diffs are not rendered by default.

6 changes: 1 addition & 5 deletions docs/notebooks/mmd/MMD_bandwidth_experiment.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@
"# inline plotting\n",
"%matplotlib inline\n",
"\n",
"\n",
"\n",
"print(\"Running experiments...\")\n",
"# load the config file\n",
"cfg = get_cfg_from_file(\"conf_mmd_bandwidth_experiment\")\n",
Expand All @@ -59,14 +57,12 @@
"assert len(cfg.data) == len(cfg.n) == len(cfg.d), \"Data, n and d must be lists of the same length\"\n",
" \n",
"# setup colors and labels for plotting\n",
"\n",
"color_dict = {\"wasserstein\": \"#cc241d\",\n",
" \"mmd\": \"#eebd35\",\n",
" \"c2st\": \"#458588\",\n",
" \"fid\": \"#8ec07c\", \n",
" \"kl\": \"#8ec07c\"}\n",
"\n",
"\n",
"col_map = {'ScaleSampleSizeKL':'kl', 'ScaleSampleSizeSW':'wasserstein',\n",
" 'ScaleSampleSizeMMD':'mmd', 'ScaleSampleSizeC2ST':'c2st',\n",
" 'ScaleSampleSizeFID':'fid', 'ScaleDimKL':'kl', 'ScaleDimSW':'wasserstein',\n",
Expand Down Expand Up @@ -96,7 +92,7 @@
" \n",
"label_list = [label_true, label_shift]\n",
"label_list[1]['toy_2d'] = 'approx.'\n",
"label_list[1]['random'] = 'shifted'\n"
"label_list[1]['random'] = 'shifted'"
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test MMD bandwidth sensitivity in all scaling experiments"
"# Test MMD bandwidth sensitivity in all scaling experiments"
]
},
{
Expand All @@ -21,7 +21,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load the configs and set up the plotting "
"## Load the configs and set up the plotting "
]
},
{
Expand Down Expand Up @@ -99,7 +99,7 @@
" \n",
"label_list = [label_true, label_shift]\n",
"label_list[1]['toy_2d'] = 'approx.'\n",
"label_list[1]['random'] = 'shifted'\n"
"label_list[1]['random'] = 'shifted'"
]
},
{
Expand Down

0 comments on commit 94e4351

Please sign in to comment.