From 66f77fdf76a17005bfd342aa62a0cb474465ed54 Mon Sep 17 00:00:00 2001 From: Weronika Hryniewska Date: Wed, 30 Mar 2022 14:07:08 +0200 Subject: [PATCH] Add new superpixel functionalites, improve display, add more hints --- README.md | 4 +- code/dashboard_LIMEcraft.ipynb | 156 ++++++++++++++++++++++----------- code/explanations.py | 4 +- code/utils.py | 36 +++++++- 4 files changed, 144 insertions(+), 56 deletions(-) diff --git a/README.md b/README.md index 27a9152..3b6a99b 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,8 @@ go to code/dashboard_LIMEcraft.ipynb\ choose venv in Kernel -> Change kernel run the whole notebook\ type http://127.0.0.1:8001/ in the web browser +\ +If you have problems with running code on GPU, put *os.environ['CUDA_VISIBLE_DEVICES'] = '-1'* after *import os* in code/dashboard_LIMEcraft.ipynb. ### How to test own model? @@ -68,4 +70,4 @@ If you find our work useful, please cite our paper: keywords = {Explainable AI, superpixels, LIME, image features, interactive User Interface}, howpublished = {\url{https://arxiv.org/abs/2111.08094}}, } -``` \ No newline at end of file +``` diff --git a/code/dashboard_LIMEcraft.ipynb b/code/dashboard_LIMEcraft.ipynb index 2f09363..d6c00ba 100644 --- a/code/dashboard_LIMEcraft.ipynb +++ b/code/dashboard_LIMEcraft.ipynb @@ -26,6 +26,8 @@ "import utils # python script\n", "\n", "import os\n", + "os.environ['CUDA_VISIBLE_DEVICES'] = '-1'\n", + "\n", "import plotly.express as px\n", "import dash\n", "from dash import dcc, html, dash_table\n", @@ -145,7 +147,8 @@ "fig.update_layout(dragmode=\"drawclosedpath\",\n", " newshape=dict(fillcolor=\"cyan\", opacity=0.3, line=dict(color=\"darkblue\", width=2)),\n", " yaxis_visible=False, yaxis_showticklabels=False,\n", - " xaxis_visible=False, xaxis_showticklabels=False\n", + " xaxis_visible=False, xaxis_showticklabels=False,\n", + " margin=dict(l=0, r=0, t=30, b=10)\n", " )\n", "\n", "mask = np.zeros(mask_shape)\n", @@ -191,6 +194,7 @@ " html.Div([\n", " html.Div([\n", " html.H4(\"Mark areas in the image\", className=\"my-0 font-weight-normal\"),\n", + " html.Div('Mark areas which you want to maintain the semantic meaning of the explained objects.'),\n", " ], className=\"card-header\"),\n", " html.Div([\n", " dcc.Graph(id=\"graph-camera\", figure=fig, config=config, ),\n", @@ -201,12 +205,13 @@ " ),\n", " \n", " ], className=\"card-body\"),\n", - " ], className=\"card mb-4 box-shadow\"),\n", + " ], className=\"card mb-4 box-shadow col-sm\"),\n", " \n", " \n", " html.Div([\n", " html.Div([\n", " html.H4(\"Options\", className=\"my-0 font-weight-normal\"),\n", + " html.Div('Upload image or mask.\\n Adjust the number of superpixels inner and outer the selection (made in step \"Mark areas in the image\" or by uploaded mask).\\n Then, generate results.'),\n", " ], className=\"card-header\"),\n", " \n", " html.Div([\n", @@ -224,20 +229,20 @@ " html.Button(\"submit uploading\", type=\"button\", className=\"btn btn-lg btn-block btn-outline-primary\", id=\"submit-uploading\", n_clicks=0, style={\"display\": \"none\"}), \n", " \n", " \n", - " html.Label('Number of inner segments', className=\"form-label mt-4\"),\n", + " html.Label('Number of inner segments', className=\"form-label mt-4\", id='number_od_inner_seg'),\n", " dcc.Input(id=\"inner-segments\", placeholder='Number of inner segments...',\n", " type='number', value='10', min=0, max=500, className=\"form-control\"),\n", - " html.Label('Number of outer segments', className=\"form-label mt-4\"),\n", + " html.Label('Number of outer segments', className=\"form-label mt-4\", id='number_od_outer_seg'),\n", " dcc.Input(id=\"outer-segments\", placeholder='Number of outer segments...',\n", " type='number', value='50', min=0, max=500, className=\"form-control\"),\n", " \n", " html.Button(\"Generate results\", type=\"button\", className=\"btn btn-lg btn-block btn-primary\", id=\"submit-val\", n_clicks=0, disabled=False),\n", - " html.Button(\"Enable button\", n_clicks=0, type=\"button\", className=\"btn btn-lg btn-block btn-outline-primary\", id=\"enable-button\"),\n", + " html.Button(\"Enable generation\", n_clicks=0, type=\"button\", className=\"btn btn-lg btn-block btn-outline-primary\", id=\"enable-button\"),\n", " \n", " ], className=\"card-body\"),\n", - " ], className=\"card mb-4 box-shadow\"),\n", + " ], className=\"card mb-4 box-shadow col-sm\"),\n", " \n", - " ], className=\"card-deck mb-3 text-center\"),\n", + " ], className=\"card-deck mb-3 text-center row\"),\n", "], className=\"container\")\n", "\n", "\n", @@ -247,6 +252,7 @@ " html.Div([\n", " html.Div([\n", " html.H4(\"LIMEcraft results\", className=\"my-0 font-weight-normal\"),\n", + " html.P(id=\"proportions_original\"),\n", " ], className=\"card-header\"),\n", " html.Div(id=\"progress-bar1\", children=\"\"),\n", " \n", @@ -259,20 +265,22 @@ "\n", " \n", " ], className=\"card-body\"),\n", - " ], className=\"card mb-4 box-shadow\"),\n", + " ], className=\"card mb-4 box-shadow col-sm\"),\n", " \n", " \n", " html.Div([\n", " html.Div([\n", " html.H4(\"LIME results\", className=\"my-0 font-weight-normal\"),\n", + " html.P(id=\"proportions_original_lime\"),\n", " ], className=\"card-header\"),\n", " html.Div([\n", - " \n", + " html.P(id=\"result_1\"),\n", + "\n", " ## LIME results\n", " dbc.Spinner(children=[dcc.Graph(id=\"original-lime\", figure=fig)],),\n", " ], className=\"card-body\"),\n", - " ], className=\"card mb-4 box-shadow\"),\n", - " ], className=\"card-deck mb-3 text-center\"),\n", + " ], className=\"card mb-4 box-shadow col-sm\"),\n", + " ], className=\"card-deck mb-3 text-center row\"),\n", "], className=\"container\")\n", " " ] @@ -302,13 +310,14 @@ " html.Div([\n", " html.Div([\n", " html.H4(\"Edit image color\", className=\"my-0 font-weight-normal\"),\n", + " html.Div('Mark areas that you want to change color.'),\n", " ], className=\"card-header\"),\n", " html.Div([\n", " ##window 1\n", " dcc.Graph(id=\"graph-camera-to-edition\", figure=fig, config=config, ),\n", " \n", " ], className=\"card-body\", style={\"margin\": 0, \"display\": \"inline-block\", \"padding\": \"0 0\"}),\n", - " ], className=\"card mb-4 box-shadow\"),\n", + " ], className=\"card mb-4 box-shadow col-sm\"),\n", " \n", " html.Div([\n", " html.Div([\n", @@ -365,7 +374,7 @@ " html.Button(\"Generate results\", type=\"button\", className=\"btn btn-lg btn-block btn-primary\", id=\"submit-edition\"),\n", " \n", " ], className=\"card-body\"),\n", - " ], className=\"card mb-4 box-shadow\", style={\"max-width\": 250}),\n", + " ], className=\"card mb-4 box-shadow col-sm\", style={\"max-width\": 250}),\n", " \n", " \n", " html.Div([\n", @@ -383,9 +392,9 @@ " dcc.Graph(id=\"graph-camera-edited\", figure=fig,),\n", " \n", " ], className=\"card-body\", style={\"margin\": 0, \"display\": \"inline-block\", \"padding\": \"0 0\"}),\n", - " ], className=\"card mb-4 box-shadow\"),\n", + " ], className=\"card mb-4 box-shadow col-sm\"),\n", " \n", - " ], className=\"card-deck mb-3 text-center\"),\n", + " ], className=\"card-deck mb-3 text-center row\"),\n", "], className=\"container\")\n", "\n", "\n", @@ -394,6 +403,7 @@ " html.Div([\n", " html.Div([\n", " html.H4(\"LIMEcraft results\", className=\"my-0 font-weight-normal\"),\n", + " html.P(id=\"proportions_color_edited\"),\n", " ], className=\"card-header\"),\n", " \n", " html.Div([\n", @@ -411,7 +421,7 @@ " dbc.Spinner(children=[dcc.Graph(id=\"graph-lime-edited\", figure=fig),]),\n", " \n", " ], className=\"card-body\"),\n", - " ], className=\"card mb-4 box-shadow\"),\n", + " ], className=\"card mb-4 box-shadow col-sm\"),\n", " \n", " \n", " html.Div([\n", @@ -432,8 +442,8 @@ "\n", "\n", " ], className=\"card-body\"),\n", - " ], className=\"card mb-4 box-shadow\"),\n", - " ], className=\"card-deck mb-3 text-center\"),\n", + " ], className=\"card mb-4 box-shadow col-sm\"),\n", + " ], className=\"card-deck mb-3 text-center row\"),\n", "], className=\"container\")\n" ] }, @@ -454,7 +464,8 @@ " html.Div([\n", " html.Div([\n", " html.Div([\n", - " html.H4(\"Move and rotate parts of the image\", className=\"my-0 font-weight-normal\"),\n", + " html.H4(\"Move, rotate or remove parts of the image\", className=\"my-0 font-weight-normal\"),\n", + " html.Div('Mark areas that you want to edit.'),\n", " ], className=\"card-header\"),\n", " html.Div([\n", " ##window 1\n", @@ -467,7 +478,7 @@ "\n", " \n", " ], className=\"card-body\", style={\"margin\": 0, \"display\": \"inline-block\", \"padding\": \"0 0\"}),\n", - " ], className=\"card mb-4 box-shadow\"),\n", + " ], className=\"card mb-4 box-shadow col-sm\"),\n", " \n", " html.Div([\n", " html.Div([\n", @@ -495,7 +506,7 @@ " html.Button(\"Generate results\", type=\"button\", className=\"btn btn-lg btn-block btn-primary\", id=\"submit-edition-rotation\"),\n", " \n", " ], className=\"card-body\"),\n", - " ], className=\"card mb-4 box-shadow\", style={\"max-width\": 250}),\n", + " ], className=\"card mb-4 box-shadow col-sm\", style={\"max-width\": 250}),\n", " \n", " \n", " html.Div([\n", @@ -510,9 +521,9 @@ "\n", " \n", " ], className=\"card-body\", style={\"margin\": 0, \"display\": \"inline-block\", \"padding\": \"0 0\"}),\n", - " ], className=\"card mb-4 box-shadow\"),\n", + " ], className=\"card mb-4 box-shadow col-sm\"),\n", " \n", - " ], className=\"card-deck mb-3 text-center\"),\n", + " ], className=\"card-deck mb-3 text-center row\"),\n", "], className=\"container\")\n", "\n", "\n", @@ -521,6 +532,7 @@ " html.Div([\n", " html.Div([\n", " html.H4(\"LIMEcraft results\", className=\"my-0 font-weight-normal\"),\n", + " html.P(id=\"proportions_rotation_edited\"),\n", " ], className=\"card-header\"),\n", " \n", " html.Div([\n", @@ -537,7 +549,7 @@ "\n", " \n", " ], className=\"card-body\"),\n", - " ], className=\"card mb-4 box-shadow\"),\n", + " ], className=\"card mb-4 box-shadow col-sm\"),\n", " \n", " \n", " html.Div([\n", @@ -559,8 +571,8 @@ " \n", "\n", " ], className=\"card-body\"),\n", - " ], className=\"card mb-4 box-shadow\"),\n", - " ], className=\"card-deck mb-3 text-center\"),\n", + " ], className=\"card mb-4 box-shadow col-sm\"),\n", + " ], className=\"card-deck mb-3 text-center row\"),\n", "], className=\"container\")" ] }, @@ -582,6 +594,7 @@ " html.Div([\n", " html.Div([\n", " html.H4(\"Edit image shape\", className=\"my-0 font-weight-normal\"),\n", + " html.Div('Mark areas that you want to edit.'),\n", " ], className=\"card-header\"),\n", " html.Div([\n", " ##window 1\n", @@ -593,7 +606,7 @@ " style={\"display\": \"none\",}),\n", " \n", " ], className=\"card-body\", style={\"margin\": 0, \"display\": \"inline-block\", \"padding\": \"0 0\"}),\n", - " ], className=\"card mb-4 box-shadow\"),\n", + " ], className=\"card mb-4 box-shadow col-sm\"),\n", " \n", " html.Div([\n", " html.Div([\n", @@ -603,14 +616,14 @@ " html.Div([\n", " ##window 2\n", " html.Label('Power of expansion', className=\"form-label mt-4\"),\n", - " dcc.Input(id=\"power-input\", placeholder='Enter a value of power...', type='number', value='1.4',\n", + " dcc.Input(id=\"power-input\", placeholder='Enter a value of power...', type='number', value='1.4', step='.1',\n", " min=0, max=10, className=\"form-control\"),\n", " \n", " html.Button(\"Change another element\", type=\"button\", className=\"btn btn-lg btn-block btn-outline-primary\", id=\"submit-step-shape\"),\n", " html.Button(\"Revert edition\", type=\"button\", className=\"btn btn-lg btn-block btn-outline-primary\", id=\"clear-image-shape\"),\n", " html.Button(\"Generate results\", type=\"button\", className=\"btn btn-lg btn-block btn-primary\", id=\"submit-edition-shape\"), \n", " ], className=\"card-body\"),\n", - " ], className=\"card mb-4 box-shadow\", style={\"max-width\": 250}),\n", + " ], className=\"card mb-4 box-shadow col-sm\", style={\"max-width\": 250}),\n", " \n", " \n", " html.Div([\n", @@ -624,9 +637,9 @@ " dcc.Graph(id=\"graph-camera-edited-shape\", figure=fig,),\n", " \n", " ], className=\"card-body\", style={\"margin\": 0, \"display\": \"inline-block\", \"padding\": \"0 0\"}),\n", - " ], className=\"card mb-4 box-shadow\"),\n", + " ], className=\"card mb-4 box-shadow col-sm\"),\n", " \n", - " ], className=\"card-deck mb-3 text-center\"),\n", + " ], className=\"card-deck mb-3 text-center row\"),\n", "], className=\"container\")\n", "\n", "\n", @@ -635,6 +648,7 @@ " html.Div([\n", " html.Div([\n", " html.H4(\"LIMEcraft results\", className=\"my-0 font-weight-normal\"),\n", + " html.P(id=\"proportions_shape_edited\"),\n", " ], className=\"card-header\"),\n", " \n", " html.Div([\n", @@ -651,7 +665,7 @@ " \n", " \n", " ], className=\"card-body\"),\n", - " ], className=\"card mb-4 box-shadow\"),\n", + " ], className=\"card mb-4 box-shadow col-sm\"),\n", " \n", " \n", " html.Div([\n", @@ -672,8 +686,8 @@ " ),\n", "\n", " ], className=\"card-body\"),\n", - " ], className=\"card mb-4 box-shadow\"),\n", - " ], className=\"card-deck mb-3 text-center\"),\n", + " ], className=\"card mb-4 box-shadow col-sm\"),\n", + " ], className=\"card-deck mb-3 text-center row\"),\n", "], className=\"container\")" ] }, @@ -751,6 +765,33 @@ "# Functionalities" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Optimize numer of superpixels" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@app.callback(\n", + " Output('number_od_inner_seg', 'children'),\n", + " Output('number_od_outer_seg', 'children'),\n", + " Input('graph-mask', 'figure')\n", + ")\n", + "def optimize_superpix_count(figure):\n", + " mask = np.array(figure[\"data\"][0][\"z\"])\n", + " counts = utils.count_superpix(mask,100)\n", + " return [\n", + " ''.join(str(e) for e in ['Number of inner segments (suggested: ', counts[0], ')']),\n", + " ''.join(str(e) for e in ['Number of outner segments (suggested: ', counts[1], ')']),\n", + " ]" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -865,6 +906,10 @@ " Output(\"graph-mask-edited-shape\", \"figure\"),\n", " \n", " Output(\"original-lime\", \"figure\"),\n", + " Output(\"result_1\", \"children\"),\n", + " \n", + " Output(\"proportions_original\", \"children\"),\n", + " Output(\"proportions_original_lime\", \"children\"),\n", "\n", " Input('submit-val', 'n_clicks'),\n", " Input(\"graph-mask\", \"figure\"),\n", @@ -882,7 +927,7 @@ " changed_id = [p['prop_id'] for p in dash.callback_context.triggered][0]\n", "\n", " if 'submit-uploading' in changed_id:\n", - " return [fig_base, _, _, _, _, _, fig_base]\n", + " return [fig_base, _, _, _, _, _, fig_base, _, _, _]\n", " \n", " if 'submit-val' in changed_id:\n", " \n", @@ -903,8 +948,8 @@ " img_ = np.array(img_)\n", " \n", " \n", - " [img_with_LIMEcraft, preds] = explanations.test_limecraft(model, img_, mask, int(inner_n_segments), int(outer_n_segments))\n", - " img_with_lime = explanations.test_lime(model, img_)\n", + " [img_with_LIMEcraft, preds, pos_neg_mask] = explanations.test_limecraft(model, img_, mask, int(inner_n_segments), int(outer_n_segments))\n", + " [img_with_lime, pos_neg_mask_lime] = explanations.test_lime(model, img_)\n", "\n", " \n", " \n", @@ -926,8 +971,11 @@ " str(preds),\n", " px.imshow(mask),\n", " px.imshow(mask),\n", - " px.imshow(mask),\n", - " img_with_lime_fig]\n", + " px.imshow(mask),\n", + " img_with_lime_fig,\n", + " 'This is a/an \"{}\" with probability {}%'.format(preds[0][1], np.round(np.double(preds[0][2])*100,2)),\n", + " utils.count_pix_proportions(pos_neg_mask),\n", + " utils.count_pix_proportions(pos_neg_mask_lime)]\n", " \n", " else:\n", " return dash.no_update" @@ -1364,6 +1412,7 @@ " Output(\"result_0-edited\", \"children\"),\n", " Output(\"details-raport-edited\", \"children\"),\n", "\n", + " Output(\"proportions_color_edited\", \"children\"),\n", " \n", " Input('submit-edition', 'n_clicks'),\n", " Input(\"graph-mask-edited\", \"figure\"),\n", @@ -1382,7 +1431,7 @@ " changed_id = [p['prop_id'] for p in dash.callback_context.triggered][0]\n", "\n", " if 'submit-uploading' in changed_id:\n", - " return [fig_base, _, _]\n", + " return [fig_base, _, _, _]\n", " if 'submit-edition' in changed_id:\n", " img_str = img_fig_[\"data\"][0][\"source\"]\n", " img_str = img_str.split(\",\")\n", @@ -1391,7 +1440,7 @@ " \n", " mask = np.array(mask_fig_[\"data\"][0][\"z\"]) # extract array from figure\n", " \n", - " [img_with_LIMEcraft, preds] = explanations.test_limecraft(model, img_, mask, int(inner_n_segments), int(outer_n_segments))\n", + " [img_with_LIMEcraft, preds, pos_neg_mask] = explanations.test_limecraft(model, img_, mask, int(inner_n_segments), int(outer_n_segments))\n", " \n", " lime_fig = px.imshow(img_with_LIMEcraft)\n", " lime_fig.update_layout(\n", @@ -1400,9 +1449,10 @@ " margin=dict(l=0, r=0, t=30, b=10)\n", " )\n", " \n", - " return [lime_fig, \n", + " return [lime_fig,\n", " 'This is a/an \"{}\" with probability {}%'.format(preds[0][1], np.round(np.double(preds[0][2])*100,2)),\n", - " str(preds)] \n", + " str(preds),\n", + " utils.count_pix_proportions(pos_neg_mask)] \n", " else:\n", " return dash.no_update\n", " \n", @@ -1411,6 +1461,8 @@ " Output(\"graph-lime-edited-rotation\", \"figure\"),\n", " Output(\"result_0-edited-rotation\", \"children\"),\n", " Output(\"details-raport-edited-rotation\", \"children\"),\n", + " \n", + " Output(\"proportions_rotation_edited\", \"children\"),\n", " \n", " Input('submit-edition-rotation', 'n_clicks'),\n", " Input(\"graph-mask-edited-rotation\", \"figure\"),\n", @@ -1429,7 +1481,7 @@ " changed_id = [p['prop_id'] for p in dash.callback_context.triggered][0]\n", "\n", " if 'submit-uploading' in changed_id:\n", - " return [fig_base, _, _]\n", + " return [fig_base, _, _, _]\n", " if 'submit-edition-rotation' in changed_id:\n", " img_str = img_fig_[\"data\"][0][\"source\"]\n", " img_str = img_str.split(\",\")\n", @@ -1438,7 +1490,7 @@ " \n", " mask = np.array(mask_fig_[\"data\"][0][\"z\"]) # extract array from figure\n", " \n", - " [img_with_LIMEcraft, preds] = explanations.test_limecraft(model, img, mask, int(inner_n_segments), int(outer_n_segments))\n", + " [img_with_LIMEcraft, preds, pos_neg_mask] = explanations.test_limecraft(model, img, mask, int(inner_n_segments), int(outer_n_segments))\n", " \n", " lime_fig = px.imshow(img_with_LIMEcraft)\n", " lime_fig.update_layout(\n", @@ -1448,7 +1500,8 @@ " )\n", " return [lime_fig, \n", " 'This is a/an \"{}\" with probability {}%'.format(preds[0][1], np.round(np.double(preds[0][2])*100,2)),\n", - " str(preds)] \n", + " str(preds),\n", + " utils.count_pix_proportions(pos_neg_mask)]\n", " else:\n", " return dash.no_update\n", " \n", @@ -1457,6 +1510,8 @@ " Output(\"graph-lime-edited-shape\", \"figure\"),\n", " Output(\"result_0-edited-shape\", \"children\"),\n", " Output(\"details-raport-edited-shape\", \"children\"),\n", + " \n", + " Output(\"proportions_shape_edited\", \"children\"),\n", " \n", " Input('submit-edition-shape', 'n_clicks'),\n", " Input(\"graph-mask-edited-shape\", \"figure\"),\n", @@ -1475,7 +1530,7 @@ " changed_id = [p['prop_id'] for p in dash.callback_context.triggered][0]\n", "\n", " if 'submit-uploading' in changed_id:\n", - " return [fig_base, _, _]\n", + " return [fig_base, _, _, _]\n", " if 'submit-edition-shape' in changed_id:\n", " img_str = img_fig_[\"data\"][0][\"source\"]\n", " img_str = img_str.split(\",\")\n", @@ -1484,7 +1539,7 @@ " \n", " mask = np.array(mask_fig_[\"data\"][0][\"z\"]) # extract array from figure\n", " \n", - " [img_with_LIMEcraft, preds] = explanations.test_limecraft(model, img, mask, int(inner_n_segments), int(outer_n_segments))\n", + " [img_with_LIMEcraft, preds, pos_neg_mask] = explanations.test_limecraft(model, img, mask, int(inner_n_segments), int(outer_n_segments))\n", " \n", " lime_fig = px.imshow(img_with_LIMEcraft)\n", " lime_fig.update_layout(\n", @@ -1494,7 +1549,8 @@ " )\n", " return [lime_fig, \n", " 'This is a/an \"{}\" with probability {}%'.format(preds[0][1], np.round(np.double(preds[0][2])*100,2)),\n", - " str(preds)] \n", + " str(preds),\n", + " utils.count_pix_proportions(pos_neg_mask)]\n", " else:\n", " return dash.no_update" ] @@ -1695,7 +1751,7 @@ " pdf.ln(line_height)\n", " \n", "\n", - " pdf.output('report.pdf', 'F')\n", + " pdf.output('report.pdf', 'D')\n", " \n", " return\n", " else:\n", diff --git a/code/explanations.py b/code/explanations.py index c73777b..d73353c 100644 --- a/code/explanations.py +++ b/code/explanations.py @@ -280,7 +280,7 @@ def test_lime(model, img): temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=False, num_features=10, hide_rest=False) out2 = mark_boundaries(mark_boundaries(temp / 2 + 0.5, mask), segments) - return out2 + return [out2, mask] @@ -316,4 +316,4 @@ def test_limecraft(model, img, mask, inner_n_segments, outer_n_segments): segments = explanation.segments out1 = mark_boundaries(mark_boundaries(temp, mask2), segments) - return [out1, decoded_predictions] + return [out1, decoded_predictions, mask2] diff --git a/code/utils.py b/code/utils.py index 1e51fbe..ff16407 100644 --- a/code/utils.py +++ b/code/utils.py @@ -3,6 +3,7 @@ import cv2 import pandas as pd from PIL import Image +import tensorflow from tensorflow.keras.preprocessing import image import matplotlib.pyplot as plt @@ -86,13 +87,13 @@ def find_centrum(path): -def circle_rotate(im, x, y, radius, degree, mask, sub_mask=None, left_right=0, up_down=0): +def circle_rotate(im, x, y, radius, degree, img_shape, mask, sub_mask=None, left_right=0, up_down=0): """ Rotates and shifts the selected part of the picture """ img_arr = np.array(im) - box = (max(x-radius,0), max(y-radius,0), min(x+radius+1,224), min(y+radius+1,224)) + box = (max(x-radius,0), max(y-radius,0), min(x+radius+1,img_shape[1]), min(y+radius+1,img_shape[0])) crop = im.crop(box=box) crop_arr = np.asarray(crop) # build the circle mask @@ -200,4 +201,33 @@ def merge_all_preds(str_original, str_color, str_rotation, str_shape): df=pd.merge(df, df3, on="class", how='outer').fillna('-') df=pd.merge(df, df4, on="class", how='outer').fillna('-') - return df \ No newline at end of file + return df + +def count_superpix(mask, superpix_count): + """ + superpix_count (int) - summarized superpixels count + return optimal inside and outside of mask superpixels count + """ + + number_of_inside_pix = np.sum(mask != 0) + number_of_outside_pix = np.sum(mask == 0) + count_inside_superpix = np.int32(np.round(superpix_count * number_of_inside_pix / (number_of_inside_pix+number_of_outside_pix))) + count_outside_superpix = superpix_count - count_inside_superpix + return [count_inside_superpix, count_outside_superpix] + + +def count_pix_proportions(mask): + """ + Mask with negative (-1), neutral (0) and positive (1) meaning of values of superpixels. + It counts proportions of each of their's pixels number to all pixels in the mask. + """ + neg_pix_n = np.sum(mask==-1) + pos_pix_n = np.sum(mask==1) + neu_pix_n = mask.shape[0]*mask.shape[1]-neg_pix_n-pos_pix_n + + all_pix_n = mask.shape[0]*mask.shape[1] + result = ['Superpixels that contribute:\n', + 'positively (green color): ', str(np.round(pos_pix_n/all_pix_n*100,2)), '%,\n', + 'negatively (red color): ', str(np.round(neg_pix_n/all_pix_n*100,2)), '%,\n', + 'neutrally: ', str(np.round(neu_pix_n/all_pix_n*100,2)), '%'] + return ''.join(result)