diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..57867dc --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2019 BMW Group + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..a75e0bd
--- /dev/null
+++ b/README.md
@@ -0,0 +1,238 @@
+# Tensorflow CPU Inference API For Windows and Linux
+
+This is a repository for an object detection inference API using the Tensorflow framework.
+
+This repo is based on the [Tensorflow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection).
+
+The Tensorflow version used is 1.13.1. The inference REST API works on CPU and doesn't require any GPU usage. It is supported on both Windows and Linux operating systems.
+
+Models trained using our Tensorflow training repository can be deployed in this API. Several object detection models can be loaded and used at the same time.
+
+![predict image](./docs/4.gif)
+
+## Prerequisites
+
+- OS:
+  - Ubuntu 16.04/18.04
+  - Windows 10 pro/enterprise
+- Docker
+
+### Check for prerequisites
+
+To check if you have docker-ce installed:
+
+```sh
+docker --version
+```
+
+### Install prerequisites
+
+#### Ubuntu
+
+Use the following command to install docker on Ubuntu:
+
+```sh
+chmod +x install_prerequisites.sh && source install_prerequisites.sh
+```
+
+#### Windows 10
+
+To install Docker on Windows, please follow the [official instructions](https://docs.docker.com/docker-for-windows/install/).
+
+**P.S: For Windows users, open the Docker Desktop menu by clicking the Docker icon in the notifications area. Select Settings, then the Advanced tab, to adjust the resources available to Docker Engine.**
+
+## Build The Docker Image
+
+In order to build the project, run the following command from the project's root directory:
+
+```sh
+sudo docker build -t tensorflow_inference_api_cpu -f docker/dockerfile .
+```
+
+### Behind a proxy
+
+```sh
+sudo docker build --build-arg http_proxy='' --build-arg https_proxy='' -t tensorflow_inference_api_cpu -f ./docker/dockerfile .
+```
+
+## Run the docker container
+
+To run the API, go to the project's root directory and run the following:
+
+#### Using Linux based docker:
+
+```sh
+sudo docker run -itv $(pwd)/models:/models -p <docker_host_port>:4343 tensorflow_inference_api_cpu
+```
+
+#### Using Windows based docker:
+
+```sh
+docker run -itv ${PWD}/models:/models -p <docker_host_port>:4343 tensorflow_inference_api_cpu
+```
+
+The `<docker_host_port>` can be any unique port of your choice.
+
+The API file will be run automatically, and the service will listen for HTTP requests on the chosen port.
+
+## API Endpoints
+
+To see all available endpoints, open your favorite browser and navigate to:
+
+```
+http://<machine_IP>:<docker_host_port>/docs
+```
+
+The 'predict_batch' endpoint is not shown on Swagger, since the list-of-files input is not yet supported there.
+
+**P.S: If you are using custom endpoints like /load, /detect, and /get_labels, you should always use the /load endpoint first, and only then use /detect or /get_labels.**
+
+### Endpoints summary
+
+#### /load (GET)
+
+Loads all available models and returns every model with its hashed value. Loaded models are stored and aren't loaded again.
+
+![load model](./docs/1.gif)
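+
+For example, once the container is up, you can exercise this endpoint from the command line (a minimal sketch, assuming the container's port was mapped to 4343 on localhost):
+
+```sh
+# loads every model found under ./models and prints the model-name -> hash mapping
+curl -X GET http://localhost:4343/load
+```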
+
+#### /detect (POST)
+
+Performs inference on the specified model and image, and returns the bounding boxes.
+
+![detect image](./docs/3.gif)
+
+#### /get_labels (POST)
+
+Returns all of the specified model's labels with their hashed values.
+
+![get model labels](./docs/2.gif)
+
+#### /models/{model_name}/predict_image (POST)
+
+Performs inference on the specified model and image, draws the bounding boxes on the image, and returns the resulting image as the response.
+
+![predict image](./docs/4.gif)
+
+#### /models (GET)
+
+Lists all available models.
+
+#### /models/{model_name}/load (GET)
+
+Loads the specified model. Loaded models are stored and aren't loaded again.
+
+#### /models/{model_name}/predict (POST)
+
+Performs inference on the specified model and image, and returns the bounding boxes.
+
+#### /models/{model_name}/labels (GET)
+
+Returns all of the specified model's labels.
+
+#### /models/{model_name}/config (GET)
+
+Returns the specified model's configuration.
+
+#### /models/{model_name}/predict_batch (POST)
+
+Performs inference on the specified model and a list of images, and returns the bounding boxes.
+
+## Model structure
+
+The folder "models" contains a subfolder for every model to be loaded.
+Inside each subfolder there should be a:
+
+- pb file (frozen_inference_graph.pb): contains the model weights
+- pbtxt file (object-detection.pbtxt): contains the model classes
+- config.json (a json file containing information about the model)
+
+  ```json
+  {
+    "inference_engine_name": "tensorflow_detection",
+    "confidence": 60,
+    "predictions": 15,
+    "number_of_classes": 2,
+    "framework": "tensorflow",
+    "type": "detection",
+    "network": "inception"
+  }
+  ```
+  P.S:
+  - The "number_of_classes" value should be equal to your model's number of classes
+  - You can change the "confidence" and "predictions" values while the API is running
+  - The API only returns bounding boxes with a confidence higher than the "confidence" value, so a higher "confidence" restricts the response to more certain predictions. The "confidence" value should be between 0 and 100
+  - The "predictions" value specifies the maximum number of bounding boxes in the API response. It should be positive
+
+## Benchmarking
+
+| Network \ Hardware | Intel Xeon CPU 2.3 GHz (Windows) | Intel Xeon CPU 2.3 GHz (Ubuntu) | Intel Xeon CPU 3.60 GHz (Ubuntu) | GeForce GTX 1080 (Ubuntu) |
+| ------------------ | -------------------------------- | ------------------------------- | -------------------------------- | ------------------------- |
+| ssd_fpn            | 0.867 seconds/image              | 1.016 seconds/image             | 0.434 seconds/image              | 0.0658 seconds/image      |
+| frcnn_resnet_50    | 4.029 seconds/image              | 4.219 seconds/image             | 1.994 seconds/image              | 0.148 seconds/image       |
+| ssd_mobilenet      | 0.055 seconds/image              | 0.106 seconds/image             | 0.051 seconds/image              | 0.052 seconds/image       |
+| frcnn_resnet_101   | 4.469 seconds/image              | 4.985 seconds/image             | 2.254 seconds/image              | 0.364 seconds/image       |
+| ssd_resnet_50      | 1.34 seconds/image               | 1.462 seconds/image             | 0.668 seconds/image              | 0.091 seconds/image       |
+| ssd_inception      | 0.094 seconds/image              | 0.15 seconds/image              | 0.074 seconds/image              | 0.0513 seconds/image      |
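+
+To get a rough per-image timing on your own setup, you can time a single call to the predict endpoint. A minimal sketch, assuming a model folder named `ssd_inception` under `models/`, the API mapped to port 4343 on localhost, and a local `test.jpg`; the multipart field name below is an assumption, so verify the exact parameter name on the `/docs` page:
+
+```sh
+# time one inference request end-to-end (includes HTTP and image-decoding overhead)
+time curl -X POST "http://localhost:4343/models/ssd_inception/predict" \
+     -F "input_data=@test.jpg"
+```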
+ +## Acknowledgment + +[inmind.ai](https://inmind.ai) + +[robotron.de](https://robotron.de) + +Joe Sleiman, inmind.ai , Beirut, Lebanon + +Antoine Charbel, inmind.ai, Beirut, Lebanon diff --git a/docker/dockerfile b/docker/dockerfile new file mode 100644 index 0000000..7e76d38 --- /dev/null +++ b/docker/dockerfile @@ -0,0 +1,12 @@ +FROM python:3.6 + +LABEL maintainer="antoine.charbel@inmind.ai" + +COPY docker/requirements.txt . +COPY src/main /main + +RUN pip install -r requirements.txt + +WORKDIR /main + +CMD ["uvicorn", "start:app", "--host", "0.0.0.0", "--port", "4343"] diff --git a/docker/requirements.txt b/docker/requirements.txt new file mode 100644 index 0000000..1ea75f1 --- /dev/null +++ b/docker/requirements.txt @@ -0,0 +1,24 @@ +aiofiles +celery +fastapi +h5py +matplotlib +numpy +opencv-python +python-multipart +pandas +Pillow +python-socketio +requests +scipy +sklearn +socketIO-client-nexus +tensorflow==1.13.1 +uvicorn +jsonschema + + + + + + diff --git a/docs/1.gif b/docs/1.gif new file mode 100644 index 0000000..5c72af7 Binary files /dev/null and b/docs/1.gif differ diff --git a/docs/2.gif b/docs/2.gif new file mode 100644 index 0000000..cd18da6 Binary files /dev/null and b/docs/2.gif differ diff --git a/docs/3.gif b/docs/3.gif new file mode 100644 index 0000000..eea6990 Binary files /dev/null and b/docs/3.gif differ diff --git a/docs/4.gif b/docs/4.gif new file mode 100644 index 0000000..30b2750 Binary files /dev/null and b/docs/4.gif differ diff --git a/docs/swagger_endpoints.png b/docs/swagger_endpoints.png new file mode 100644 index 0000000..af07981 Binary files /dev/null and b/docs/swagger_endpoints.png differ diff --git a/docs/uml/InferenceClassDiagram.drawio b/docs/uml/InferenceClassDiagram.drawio new file mode 100644 index 0000000..ebe1a45 --- /dev/null +++ b/docs/uml/InferenceClassDiagram.drawio @@ -0,0 +1 @@ 
+7V1bk6O2Ev41rjp58BZ348fxXHZzMptsZU7tyebFJYNsk8WICDyeya+PBBIGJF8mHiObaGtrFzWSkLo/Wt2tFh7Yt6uXjxiky88ohPHAMsKXgX03sMgfwyX/UcprSTF9xyspCxyFjLYlPEV/QUY0GHUdhTBrVMwRivMobRIDlCQwyBs0gDHaNKvNUdx8agoWUCA8BSAWqf+PwnxZUn3X2NI/wWix5E82DXZnBoLvC4zWCXvewLLnxZ/y9grwvlj9bAlCtKmR7PuBfYsRysur1cstjClzOdvKdg877lbjxjDJj2owAsEMuCPHh6499mZD1sMziNeMFz8mc0i6CyAbcf7KuUQGn9LL9Sp+wGBFLiebZZTDpxQElL4h6CC0Zb6KSckklwVnIH22QUrV5GkhQKsoYNcxmMF4UrHyFsUIk1sJSugzshyj75VcaLdzlOQPYBXFFG9fIQ5BAhiZYcuk3YI4WiSkEBDmQNLhROQWnz7EOXypkRj3PkK0gjl+JVXY3aHJMcuwPnTHjLDZIsesBL6swWbMmwIG10XV/VZk5IJJ7UgJ8ifVRHhLGIFRHJNJCzLcRKsYFHytyangHKtUCGcZxeEjeEVrOvQsJ6LhpckS4egvUh/wxuQ25ny3PNpbFMctGW4bPdHO2GMwzEizL1weZkV6BFle4SSOQZpFs2JwtMoK4EWUTFCeoxVHFpvVQ+3J21exjzh0Wyi0JCi0DAkILecsIBT1SBhPM4ifI6oabsitOwjTRwhwEiWLJ3ajjU4y+byJzJLpLUlIAMZ5HMM57YEyMiL6/YaRc0Q1U0YUFXn6Y1HnztlSfmX8cXaoNET6m8cFZpZRGMKkwFQOcjCr3pkURUle8NSdkL+Ey7fGB3fgknndkrK5LZO/tDrOyXtK5geiAgWQYH4DKe5l+Nj51h9GDIOI5R0JEO8c+HAEfECybuNpjBbkZV6UELmnJA2KDkHhWt2B4jl5fUx+h/Dr0Ln5/BO+T3/xbiX2x8CaxAiE//mhxMRNGv0Ks5QwRauLLpEx8hUjw5IiI8ry6Yo6IJkGiFqAmIajGCG2DCF4nZQA0fhQjA/7WIv0XPgQbY46PqYzkAdLjRLFKHE7NEulKHH3rzPTwinUq41qnPiqLVXvAE4ClMyjhcaJYi/XUG23jgScHBP50HG5fsTlbIvpqbcH5t4jOixFpC8gcli6UNMwopspRXSOXmlFdaSi2vviX2RETjrisQiMGcggW9DCCJfYeMoxDdFpdHSKDvWhOXFjicXmmg72BKEYgkTjo1t8KA/QmfLYrRiheyQkjY5u0aE+OmdKA7jMpyZKBFYhfr2+KEGI8vicKQ3gLmDeirlo/aEAHcrjcqY0fEvR0Yy0aM9FATqUR+NMMRwnzx7QgZV+BFbcMdMHbw6sVAvd+4NQjPUNQ5BDppjolVZMpyim6i2/npiKKYm20SQnyIIpjyzhSQOjW2CoD6dIom1LkISxRoZaZCgPpEgyuImlGyxh8H1KlxNu6X5FpBuNjm7RoT6QYssCKV5MZ502sOD9uabnSJi9WGLGSF/IvwWDjJI+LARL7zm1exRGQ4YIeo+ZblWf5GrB/i+eHNUIYEUFHoulH2kfcwoLTm42afRIeBO1nzLDAoUTJiCDzOKv2s/atQktbdOWmDKNH0Hi8zN3T/UfMbl8qSout/sekhtSq2DXvGUzOcvA5EtSt4Mg2o++FMkcNVWfyBxab1PurIsVux1ukc8uG29zGC31Lde6LdVc06VUj7c8LqvtmH2C8TOkvQ7OfxjrdMfO9444ymKOZY6ddS7HzvcFQcFwAXnEgKxDS7RACYjvt9QWf7d1HhGVaSGCP2CevzK5gXWOmgKCL1H+G21O1ryy9K125+6F9VwUXnkhIfP9rV74tu2BFrfNitK2XXhDD05uJUsoDwW47gzpesoEm6E1DuAe3tksoS4n7w7ctxZ7zHumnN0LFAxjkEfPsDEOmcyLpmRa4LVWgRkZ256/UMIWf2brKBULbT3sqD50Tqo+dEZuC57lgLdgrWZ+gskgJjVWxzrvE6LS4AMIcsTeFB0u62G4rAm7I51f71wK1ZZmT9JgPizgyNfNFkq1t3OSt1OpgesJnznirs8hTGiV1Q+VZR+rpfxzaSlXuuWYYhTALCNaQG85vptqqt7z61FNrvSkSIjBZjqj7xaRBbl4od8x0dE6RSBRHuX3XSWeI/cCq4JqL9C1r8wLNPymxTxufUjnUP1uHDtXljKlY8Egg4KVeK1R4SL1rZ74pjwkXOZbrTF5/VBSt4A6HkeZO5oW39Sq542qiPbWP1GxLziN10nTaux6sGXfxAxIhIGVJ6BrQ2rW6pytRKlFdL9z2gRc9wPJqiRDdWOomPFHhpJjRqM3FoSMsbFgDMr3Ec7nUIphr30Hg3Ukox+RjFbwtdqnOpipeLYjoK48UxHoeMY7uaru9WUqupJMxWITW5/KUQgL5REMV8xTpCbROqBBUH3iUyU2lGcqemII9CZNYyIqapfqAxh9NmqEL1u4jnKzxhPt62EI52Ad59MVUVfFZ7/1YvZOCsu7vu1kT2L3gjCMqLoCscbIBWBEucHjiXYwy01h8NDn1JUCRLnVM7IEmXeYDcqvazt615QNOjhuG5B/LeL9tgFPE7mjUuTGVYvcOlLkPP6gfOe3adaaI7sOn0P1Tb+Dbd+RmqyCXmggfu7yIB7tC1NBo+uT+QfDbGWgeIcFT0pfII4I16ireu68FOdY7WRfFhrGekH6xxrgWCOEZ6krXpGqA1DVksQesWtJEhrY+5OX7JOqO43BnGfB4zuy16T8LgXuHMWHVZx7USqOf1NIy/x8Njc3hi5Owx0wukUNd0Bnue1TpW9uYLL0kfPqOVtj/gjM7/w9sWtJL7ZG7UPOlrsXjkKDTtKLfTVhjzNlr4vbZyfYkMZ1Ac429quzww1s/4DCdKwTG3QDaa9PkN4J4GvBpWM1QWAdWJaFcxZvbrAfZgcHyHcf+ARLTcEatSzet8H1/v5/P8P16M+R735bfJpN/vtsBZLfQ+Qaat+Gk7Df395KWkVhWCKbZjvs3ACiezy0O4JndhZir9Y8PnlA+L1ak+OxsatTkhqbOu+QOiBltHiQpheMtg3zwhgtnqntBaNFRPMPlKpitJib1QtGu8w13JrO3IlVxWgx6cj88OGYbMiL57XZVB6mJ0La6ZLTYvZOLyDdjr+aY1stosUMmF7wWdDRlqNYR4u51UZPVIfltwwPZ+yqVR6S3y/qC7Ndy24zW8y07JbZPXVcbKMVtLBMWTpal5y2+srp9qqonNO99REbjJaeTuySzf8SD1E5n3vqIJotPHuWYj6L/mFf7A6LfxiXp4Io9g8lP8DSF1a7fG67tUe3rO6pj+i2v/skO7LUJZ9FF7EXfB754wafR4pXQ8mPbvSCzw5LLGh9n1odn3tqdYyE3yNTHZa2JHHpPnDab39CvTKs35/TpIgR/RRTde8jBunyM/3yGCH+DQ== \ No newline at end of file diff --git 
a/docs/uml/InferenceClassDiagram.png b/docs/uml/InferenceClassDiagram.png new file mode 100644 index 0000000..6be3c26 Binary files /dev/null and b/docs/uml/InferenceClassDiagram.png differ diff --git a/docs/uml/InferenceSequenceDiagram.png b/docs/uml/InferenceSequenceDiagram.png new file mode 100644 index 0000000..63546a3 Binary files /dev/null and b/docs/uml/InferenceSequenceDiagram.png differ diff --git a/docs/uml/InferenceSequenceDiagram.xml b/docs/uml/InferenceSequenceDiagram.xml new file mode 100644 index 0000000..0170352 --- /dev/null +++ b/docs/uml/InferenceSequenceDiagram.xml @@ -0,0 +1 @@ +7VxZd6M2FP41fkyOxM5jlknb02nPdDKnnT5ikG06GLkCZ+mvrwRik4SNEwROTvIwAxchgfTdT3fDC/Nm+/QTCXab33CEkoUBoqeFebswDAgNn/7HJM+lxPZgKViTOOKNGsF9/B/iQsCl+zhCWadhjnGSx7uuMMRpisK8IwsIwY/dZiucdEfdBWskCe7DIJGlf8VRvqneC4Dmws8oXm/40J7NLyyD8Mea4H3Kx1sY5qr4Ky9vg6ov3j7bBBF+bInMTwvzhmCcl0fbpxuUsLmtpq28767nav3cBKX5kBvcZbAMYGgtQ8NdWUtwYZQ9PATJns/Fwry6RWj3GQUkjdP1PSIPcYj44+fP1ZTRN9mxw/02+RyvUBKn9Ox6h0i8RTki9ErCxV8a2fXjJs7R/S4I2a2PFExUtsm3CT2D9JCubx7QW0h9niTBLouXxaiASggK9ySLH9BXlJUwYlK8z9lINzU8iqZsWVDEu6pnHhT9buOQHyfBEiXX9Tre4ASz4VNcvFCWE/wDVUK6vKD4q69UcGFDrOIkabW8K/6YnL7VXbCNE6YdfyISBWnAxVwVoMHPVQMFSbxOqSyka1xMorzoHAcPiOToqSXiIPgJYboA5Jk24VcthwOSK6zHTx9b6K8wvmkB37G4MOAat667blBHDzjwBoKwooE2CkW8tVCyw3GaF+Pb1wv7VoAdJvkGr3EaJG3gNWAA7x0MvTo+GB1uFxw1DbbRoQCHMQY4vv/x5Sr+zc5urn7/+rP/a/Trt1+eL0xbO0AmmEfLswbNIwSOBiUzVUz/S7pC9I1D9CldUwK9C8Ic83E+yP5M9ftksnfBcbJ3FDg0XS1krwDiB9mPRvbmyeiwu2QP7WFkX7UbneyhJeEBRdRc56fNkn9qpMKSN20+Y7zji/QPyvNnPu3BPsddNho+tRnekxAdQXcekDXq66hqxF7r4BIRlAQ55b+O4qsmvLj1ipDgudWAq07T8xcmaFYeAmHpXcF5ENr7XRoRfA16UD5As/L1m7wCDKqd30nozF5H8QM9XLPDryiIKjEdpnVF0ZjNUHpS83CDcYbYqm7Yv1WjJRFvO9pZiAlhmxZd+2LD7btRwD9V5bwLV8J2xKDZKmWalMiDUUJMfd4rfmEbR1GhOz082PiyKu04qLoS+9SuO3/kRdv9VbESuORdnaYCEmatLsQvTLPbA16tMqqnImedhlxJwT2VrUUNFrpxJAldjA/76oy31JPtK8P3jtpXYDJn2v8wrzSaV96p4BAIyBtoXdm6rCvT0o6PCaaxjl8dc6X1xKugNId3mDwGJKptBIL+3aMsX9SBX+XsilvyEuc53tILKI2uWFybyRIc/jiqTgNVoW1Io2SJH9s2dCGgF6qnKp6DztV3rtDFyd/s5BLQfZQLbp/al2+f22ctVBTCXrI9bC4DLeayZCuYttm1hz330vF8x7aha1jQrMI3VY+lI8A7OWA324YQT/O8bkfl20sdjWGJQFndb4LilhKnlRkKqD/EJpDZ0eCfDKelybaK12PhtwdKcyG7F4qH3TswBK963DsJVy4UHDIABgF0NN9M3jlmcdQVC6d8XnciGrFEdReWpUfdT3XfTYlWDrvvR9pr8t/laN89Yt43SHDBNfVG+VZd3r6swMku7wW4hIZrdNbognf1Skgal4IbDIUMwyhu8CGlO2Ru9ni/AVu7eBcU00+vXbHcxHxusYilnPHVdUbHidP1t4K8TGcYtJQedi+vvTjf5Mo2sqlyQ+0RbGTl2ivDIELKaSgaPmIhbyIW4nldDKpiIYauWIgShB/xkBHxcFDNB4PE7u5Gdd3AEWdetKPGi4koiEp7TETLVFr+zFMJfWnmJvAJat8SXgKjjkyUoQrXMsYJVRxasKO+B58WwffoOpL9N47vOnZZ2hQ9x5FcFE9Q9CMJRrG57wt41OCgWHBmwEIPioD1zwCwFSeejtjyTt2QtUWrVQ9kbfMkyFZPpdendmaFLPVTgWWL8WB31Hjw2fCksL6m4Xa7GBoD9oXgi2kIieiR4Cs/sDkBIGUnnweZeZBntU/DPGZB5bca5XF6DLMXRXl8o5touBin1qF6yMr16t6vL8RTQbu1/J/LdS9zDNvye5Exlp6yEvfOYQsJCVrlwyI0dKZHcXN9wc01lflIVWJXl+FtzG3HAMfr7gm+PZLhPTS+78tbxbmW3nliaTiYwG4wVG6uopYu35N0ISUI4zTLgzREr6h+K/SDqscJ7cuHYNsIihoqeVdFc73s84KiOWjCbi1S9dHCK7cW6te7Qr/TbS/zmruU2aDgofljeWiDUs6DKor1lEicSmuO3UWfBbwJaM2dGx+2K+DDnxIfQ0oS1OGnqeFhCR+iWIcqyOX8taO+W0MZjRpnciah3ivv92GIsux9GbkWgELe2FSGl6c0c6sE5hnWnPQp3dHQhZ66lJO/BrEPKuep7fVwvSlXXtaqB4pz8I3s5aTym9ZDKEbrVVqorCLWpoXGrFsui6GYVmfPnTn6WFZVjx5+9CyBgG0xTT40/igWpUkdjRR/NDz1OH3PdaS9JhbpLUrbERTFxfdZb70urVLRj7q0nvmZxascZEgc/G5Dd/VqlfCqFVL8CHwkougbR6/iy4mK3g833pMFYXpO30ZyrDBLnw0xb6VIwWpO129vCaaJWFe/a3I2SXZf1EpNteuOaI5MUYtejSmnKZttv8lUyo7mm9n6q0DFGFs/cJxu5HecXKUQpa6q6PXv+5YhgeBLufpjrvq5EL+Yp3SgTPvulKxvzfLjHu1yK1sI5nvQmpj0vaEcr94dpg4M+QJXO4dqnGQ7zxXvFp5Pc9TWkn9ApI7a7tqKT7V6h9PsncWOHHH6TYXdp/rxJ30MMHe6BroiA9gvStegpzjv1sNRwd/FEACY/LzpkJ08t07GtyIHRqlmJxTR9XOcwzEiq9qz1e2PpY18cbSJCUiuu
qgJaFzOmcHUrJR5jNoFAzhijcEopiYUY0x1VmsCa3MWJ/cVVGINpJLzyFmJ0WPnyAcGR9rrcTtt2e2Ug07v0faAPYZjx/ZQfZH6AtuDnja/ol2uXPNT5ean/wE= \ No newline at end of file diff --git a/install_prerequisites.sh b/install_prerequisites.sh new file mode 100644 index 0000000..fc7493d --- /dev/null +++ b/install_prerequisites.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +# This will install docker following [https://docs.docker.com/install/linux/docker-ce/ubuntu/] +sudo apt-get remove docker docker-engine docker.io +sudo apt-get update + +sudo apt-get install \ + apt-transport-https \ + ca-certificates \ + curl \ + software-properties-common + +curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add - +sudo apt-key fingerprint 0EBFCD88 + +sudo add-apt-repository \ + "deb [arch=amd64] https://download.docker.com/linux/ubuntu \ + $(lsb_release -cs) \ + stable" + +sudo apt-get update +sudo apt-get install -y docker-ce +sudo groupadd docker +sudo usermod -aG docker ${USER} +docker run hello-world + diff --git a/models/.gitignore b/models/.gitignore new file mode 100644 index 0000000..76bedae --- /dev/null +++ b/models/.gitignore @@ -0,0 +1,5 @@ +# Ignore everything in this directory +* +# Except this file +!.gitignore + diff --git a/src/main/.gitignore b/src/main/.gitignore new file mode 100644 index 0000000..5b478e6 --- /dev/null +++ b/src/main/.gitignore @@ -0,0 +1,2 @@ +__pycache__/ +*.log diff --git a/src/main/deep_learning_service.py b/src/main/deep_learning_service.py new file mode 100644 index 0000000..0459d56 --- /dev/null +++ b/src/main/deep_learning_service.py @@ -0,0 +1,167 @@ +import os +import json +import uuid +import re +from inference.inference_engines_factory import InferenceEngineFactory +from inference.exceptions import ModelNotFound, InvalidModelConfiguration, ModelNotLoaded, InferenceEngineNotFound, \ + InvalidInputData, ApplicationError + + +class DeepLearningService: + + def __init__(self): + """ + Sets the models base directory, and initializes some dictionaries. + Saves the loaded model's hashes to a json file, so the values are saved even though the API went down. + """ + # dictionary to hold the model instances (model_name: string -> model_instance: AbstractInferenceEngine) + self.models_dict = {} + # read from json file and append to dict + file_name = 'model_hash.json' + file_exists = os.path.exists(file_name) + if file_exists: + try: + with open(file_name) as json_file: + self.models_hash_dict = json.load(json_file) + except: + self.models_hash_dict = {} + else: + with open('model_hash.json', 'w'): + self.models_hash_dict = {} + self.labels_hash_dict = {} + self.base_models_dir = '/models' + + def load_model(self, model_name, force_reload=False): + """ + Loads a model by passing the model path to the factory. + The factory will return a loaded model instance that will be stored in a dictionary. + :param model_name: Model name + :param force_reload: Boolean to specify if we need to reload a model on each call + :return: Boolean + """ + if not force_reload and self.model_loaded(model_name): + return True + model_path = os.path.join(self.base_models_dir, model_name) + try: + self.models_dict[model_name] = InferenceEngineFactory.get_engine(model_path) + return True + except ApplicationError as e: + raise e + + def load_all_models(self): + """ + Loads all the available models. 
+        :return: A dictionary mapping every available model name to its hashed value
+        """
+        self.load_models(self.list_models())
+        models = self.list_models()
+        for model in models:
+            if model not in self.models_hash_dict:
+                self.models_hash_dict[model] = str(uuid.uuid4())
+        # drop hashes of models that no longer exist on disk
+        for key in list(self.models_hash_dict):
+            if key not in models:
+                del self.models_hash_dict[key]
+        # persist the hashes to the json file
+        with open('model_hash.json', "w") as fp:
+            json.dump(self.models_hash_dict, fp)
+        return self.models_hash_dict
+
+    def load_models(self, model_names):
+        """
+        Loads a set of available models.
+        :param model_names: List of available models
+        :return:
+        """
+        for model in model_names:
+            self.load_model(model)
+
+    async def run_model(self, model_name, input_data, draw_boxes, predict_batch):
+        """
+        Loads the model in case it was never loaded and calls the inference engine class to get a prediction.
+        :param model_name: Model name
+        :param input_data: Batch of images or a single image
+        :param draw_boxes: Boolean to specify if we need to draw the response on the input image
+        :param predict_batch: Boolean to specify if there is a batch of images in a request or not
+        :return: Model response in case draw_boxes was set to False, else an actual image
+        """
+        # if the caller passed a model hash (UUID) instead of a name, resolve it back to the name
+        if re.match(r'[0-9a-fA-F]{8}\-[0-9a-fA-F]{4}\-[0-9a-fA-F]{4}\-[0-9a-fA-F]{4}\-[0-9a-fA-F]{12}', model_name,
+                    flags=0):
+            for key, value in self.models_hash_dict.items():
+                if value == model_name:
+                    model_name = key
+        if self.model_loaded(model_name):
+            try:
+                if predict_batch:
+                    return await self.models_dict[model_name].run_batch(input_data, draw_boxes, predict_batch)
+                else:
+                    if not draw_boxes:
+                        return await self.models_dict[model_name].run(input_data, draw_boxes, predict_batch)
+                    else:
+                        await self.models_dict[model_name].run(input_data, draw_boxes, predict_batch)
+            except ApplicationError as e:
+                raise e
+        else:
+            try:
+                self.load_model(model_name)
+                return await self.run_model(model_name, input_data, draw_boxes, predict_batch)
+            except ApplicationError as e:
+                raise e
+
+    def list_models(self):
+        """
+        Lists all the available models.
+        :return: List of models
+        """
+        return [folder for folder in os.listdir(self.base_models_dir) if
+                os.path.isdir(os.path.join(self.base_models_dir, folder))]
+
+    def model_loaded(self, model_name):
+        """
+        Checks whether a model is already loaded.
+        :param model_name: Model name
+        :return: Boolean, True if the model is loaded
+        """
+        return model_name in self.models_dict.keys()
+
+    def get_labels(self, model_name):
+        """
+        Loads the model in case it's not loaded.
+        Returns the model's labels.
+        :param model_name: Model name
+        :return: List of model labels
+        """
+        if not self.model_loaded(model_name):
+            self.load_model(model_name)
+        return self.models_dict[model_name].labels
+
+    def get_labels_custom(self, model_name):
+        """
+        Hashes every label of a specific model.
+ :param model_name: Model name + :return: A list of mode's labels with their hashed values + """ + if re.match(r'[0-9a-fA-F]{8}\-[0-9a-fA-F]{4}\-[0-9a-fA-F]{4}\-[0-9a-fA-F]{4}\-[0-9a-fA-F]{12}', model_name, + flags=0): + for key, value in self.models_hash_dict.items(): + if value == model_name: + model_name = key + models = self.list_models() + if model_name not in self.labels_hash_dict: + model_dict = {} + for label in self.models_dict[model_name].labels: + model_dict[label] = str(uuid.uuid4()) + self.labels_hash_dict[model_name] = model_dict + for key in list(self.labels_hash_dict): + if key not in models: + del self.labels_hash_dict[key] + return self.labels_hash_dict[model_name] + + def get_config(self, model_name): + """ + Returns the model's configuration. + :param model_name: Model name + :return: List of model's configuration + """ + if not self.model_loaded(model_name): + self.load_model(model_name) + return self.models_dict[model_name].configuration diff --git a/src/main/fonts/DejaVuSans.ttf b/src/main/fonts/DejaVuSans.ttf new file mode 100644 index 0000000..e5f7eec Binary files /dev/null and b/src/main/fonts/DejaVuSans.ttf differ diff --git a/src/main/inference/ConfigurationSchema.json b/src/main/inference/ConfigurationSchema.json new file mode 100644 index 0000000..66c7699 --- /dev/null +++ b/src/main/inference/ConfigurationSchema.json @@ -0,0 +1,38 @@ +{ + "type": "object", + "properties": { + "inference_engine_name": { + "type": "string" + }, + "confidence": { + "type": "number", + "minimum": 0, + "maximum": 100 + }, + "predictions": { + "type": "number", + "minimum": 0 + }, + "number_of_classes": { + "type": "number" + }, + "framework": { + "type": "string" + }, + "type": { + "type": "string" + }, + "network": { + "type": "string" + } + }, + "required": [ + "inference_engine_name", + "confidence", + "predictions", + "number_of_classes", + "framework", + "type", + "network" + ] +} \ No newline at end of file diff --git a/src/main/inference/__init__.py b/src/main/inference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/main/inference/base_error.py b/src/main/inference/base_error.py new file mode 100644 index 0000000..e7d4fca --- /dev/null +++ b/src/main/inference/base_error.py @@ -0,0 +1,46 @@ +import logging +from datetime import datetime +from abc import ABC, abstractmethod + + +class AbstractError(ABC): + + def __init__(self): + """ + Sets the logger file, level, and format. + The logging file will contain the logging level, request date, request status, and model response. + """ + self.logger = logging.getLogger('logger') + date = datetime.now().strftime('%Y-%m-%d') + file_path = 'logs/tensorflow_inference_engine_' + date + '.log' + self.handler = logging.FileHandler(file_path) + self.handler.setLevel(logging.INFO) + self.handler.setFormatter(logging.Formatter("%(levelname)s;%(asctime)s;%(message)s")) + self.logger.addHandler(self.handler) + + @abstractmethod + def info(self, message): + """ + Logs an info message to the logging file. + :param message: Containing the request status and the model response + :return: + """ + pass + + @abstractmethod + def warning(self, message): + """ + Logs a warning message to the logging file. + :param message: Containing the request status and the model response + :return: + """ + pass + + @abstractmethod + def error(self, message): + """ + Logs an Error message to the logging file. 
+ :param message: Containing the request status and the model response + :return: + """ + pass diff --git a/src/main/inference/base_inference_engine.py b/src/main/inference/base_inference_engine.py new file mode 100644 index 0000000..4238c11 --- /dev/null +++ b/src/main/inference/base_inference_engine.py @@ -0,0 +1,92 @@ +from abc import ABC, abstractmethod +from inference.exceptions import InvalidModelConfiguration, ModelNotLoaded, ApplicationError + + +class AbstractInferenceEngine(ABC): + + def __init__(self, model_path): + """ + Takes a model path and calls the load function. + :param model_path: The model's path + :return: + """ + self.labels = [] + self.configuration = {} + self.model_path = model_path + try: + self.validate_configuration() + except ApplicationError as e: + raise e + try: + self.load() + except ApplicationError as e: + raise e + except Exception as e: + raise ModelNotLoaded() + + @abstractmethod + def load(self): + """ + Loads the model based on the underlying implementation. + """ + pass + + @abstractmethod + def free(self): + """ + Performs any manual memory implementation required to when unloading a model. + Will be called when the class's destructor is called. + """ + pass + + @abstractmethod + async def run(self, input_data, draw_boxes, predict_batch): + """ + Performs the required inference based on the underlying implementation of this class. + Could be used to return classification predictions, object detection coordinates... + :param predict_batch: Boolean + :param input_data: A single image + :param draw_boxes: Used to draw bounding boxes on image instead of returning them + :return: A bounding-box + """ + pass + + @abstractmethod + async def run_batch(self, input_data, draw_boxes, predict_batch): + """ + Iterates over images and returns a prediction for each one. + :param predict_batch: Boolean + :param input_data: List of images + :param draw_boxes: Used to draw bounding boxes on image instead of returning them + :return: List of bounding-boxes + """ + pass + + @abstractmethod + def validate_configuration(self): + """ + Validates that the model and its files are valid based on the underlying implementation's requirements. + Can check for configuration values, folder structure... + """ + pass + + @abstractmethod + def set_configuration(self, data): + """ + Takes the configuration from the config.json file + :param data: Json data + :return: + """ + pass + + @abstractmethod + def validate_json_configuration(self, data): + """ + Validates the configuration of the config.json file. + :param data: Json data + :return: + """ + pass + + def __del__(self): + self.free() diff --git a/src/main/inference/errors.py b/src/main/inference/errors.py new file mode 100644 index 0000000..ab49b87 --- /dev/null +++ b/src/main/inference/errors.py @@ -0,0 +1,47 @@ +import os +import logging +from datetime import datetime, date +from inference.base_error import AbstractError + + +class Error(AbstractError): + + def __init__(self): + if 'logs' not in os.listdir(): + os.mkdir('logs') + self.date = None + super().__init__() + + def info(self, message): + self.check_date() + self.logger.info(message) + + def warning(self, message): + self.check_date() + self.logger.warning(message) + + def error(self, message): + self.check_date() + self.logger.error(message) + + def check_date(self): + """ + Divides logging per day. Each logging file corresponds to a specific day. + It also removes all logging files exceeding one year. 
+ :return: + """ + self.date = datetime.now().strftime('%Y-%m-%d') + file_path = 'tensorflow_inference_engine_' + self.date + '.log' + if file_path not in os.listdir('logs'): + self.logger.removeHandler(self.handler) + self.handler = logging.FileHandler('logs/' + file_path) + self.handler.setLevel(logging.INFO) + self.handler.setFormatter(logging.Formatter("%(levelname)s;%(asctime)s;%(message)s")) + self.logger.addHandler(self.handler) + oldest_log_file = os.listdir('logs')[0] + oldest_date = oldest_log_file.split("_")[3].split('.')[0] + a = datetime.strptime(datetime.now().strftime('%Y-%m-%d'), '%Y-%m-%d') + b = datetime.strptime(oldest_date, '%Y-%m-%d') + delta = a - b + if delta.days > 365: + os.remove('logs/' + oldest_log_file) diff --git a/src/main/inference/exceptions.py b/src/main/inference/exceptions.py new file mode 100644 index 0000000..1dc4f4a --- /dev/null +++ b/src/main/inference/exceptions.py @@ -0,0 +1,56 @@ +__metaclass__ = type + + +class ApplicationError(Exception): + """Base class for other exceptions""" + + def __init__(self, default_message, additional_message=''): + self.default_message = default_message + self.additional_message = additional_message + + def __str__(self): + return self.get_message() + + def get_message(self): + return self.default_message if self.additional_message == '' else "{}: {}".format(self.default_message, + self.additional_message) + + +class InvalidModelConfiguration(ApplicationError): + """Raised when the model's configuration is corrupted""" + + def __init__(self, additional_message=''): + # super('Invalid model configuration', additional_message) + super().__init__('Invalid model configuration', additional_message) + + +class ModelNotFound(ApplicationError): + """Raised when the model is not found""" + + def __init__(self, additional_message=''): + # super('Model not found', additional_message) + super().__init__('Model not found', additional_message) + + +class ModelNotLoaded(ApplicationError): + """Raised when the model is not loaded""" + + def __init__(self, additional_message=''): + # super('Error loading model', additional_message) + super().__init__('Error loading model', additional_message) + + +class InvalidInputData(ApplicationError): + """Raised when the input data is corrupted""" + + def __init__(self, additional_message=''): + # super('Invalid input data', additional_message) + super().__init__('Invalid input data', additional_message) + + +class InferenceEngineNotFound(ApplicationError): + """Raised when the Inference Engine is not found""" + + def __init__(self, additional_message=''): + # super('Inference engine not found', additional_message) + super().__init__('Inference engine not found', additional_message) diff --git a/src/main/inference/inference_engines_factory.py b/src/main/inference/inference_engines_factory.py new file mode 100644 index 0000000..e09e114 --- /dev/null +++ b/src/main/inference/inference_engines_factory.py @@ -0,0 +1,32 @@ +import os +import json +from inference.exceptions import ModelNotFound, ApplicationError, InvalidModelConfiguration, InferenceEngineNotFound, ModelNotLoaded + + +class InferenceEngineFactory: + + @staticmethod + def get_engine(path_to_model): + """ + Reads the model's inference engine from the model's configuration and calls the right inference engine class. 
+ :param path_to_model: Model's path + :return: The model's instance + """ + if not os.path.exists(path_to_model): + raise ModelNotFound() + try: + configuration = json.loads(open(os.path.join(path_to_model, 'config.json')).read()) + except Exception: + raise InvalidModelConfiguration('config.json not found or corrupted') + try: + inference_engine_name = configuration['inference_engine_name'] + except Exception: + raise InvalidModelConfiguration('missing inference engine name in config.json') + try: + # import one of the available inference engine class (in this project there's only one), and return a + # model instance + return getattr(__import__(inference_engine_name), 'InferenceEngine')(path_to_model) + except ApplicationError as e: + raise e + except Exception as e: + raise InferenceEngineNotFound(inference_engine_name) diff --git a/src/main/inference/tensorflow_detection.py b/src/main/inference/tensorflow_detection.py new file mode 100644 index 0000000..614ae12 --- /dev/null +++ b/src/main/inference/tensorflow_detection.py @@ -0,0 +1,225 @@ +import os +import uuid +import jsonschema +import asyncio +import json +import numpy as np +import tensorflow as tf +from PIL import Image, ImageDraw, ImageFont +from object_detection.utils import label_map_util +from inference.base_inference_engine import AbstractInferenceEngine +from inference.exceptions import InvalidModelConfiguration, InvalidInputData, ApplicationError + + +class InferenceEngine(AbstractInferenceEngine): + + def __init__(self, model_path): + self.label_path = "" + self.NUM_CLASSES = None + self.sess = None + self.label_map = None + self.categories = None + self.category_index = None + self.detection_graph = None + self.image_tensor = None + self.d_boxes = None + self.d_scores = None + self.d_classes = None + self.num_d = None + self.font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 20) + super().__init__(model_path) + + def load(self): + with open(os.path.join(self.model_path, 'config.json')) as f: + data = json.load(f) + try: + self.validate_json_configuration(data) + self.set_configuration(data) + except ApplicationError as e: + raise e + + self.label_path = os.path.join(self.model_path, 'object-detection.pbtxt') + self.label_map = label_map_util.load_labelmap(self.label_path) + self.categories = label_map_util.convert_label_map_to_categories(self.label_map, + max_num_classes=self.NUM_CLASSES, + use_display_name=True) + for dict in self.categories: + self.labels.append(dict['name']) + + self.category_index = label_map_util.create_category_index(self.categories) + self.detection_graph = tf.Graph() + with self.detection_graph.as_default(): + od_graph_def = tf.GraphDef() + with tf.gfile.GFile(os.path.join(self.model_path, 'frozen_inference_graph.pb'), 'rb') as fid: + serialized_graph = fid.read() + od_graph_def.ParseFromString(serialized_graph) + tf.import_graph_def(od_graph_def, name='') + self.image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0') + self.d_boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0') + self.d_scores = self.detection_graph.get_tensor_by_name('detection_scores:0') + self.d_classes = self.detection_graph.get_tensor_by_name('detection_classes:0') + self.num_d = self.detection_graph.get_tensor_by_name('num_detections:0') + self.sess = tf.Session(graph=self.detection_graph) + img = Image.open("object_detection/image1.jpg") + img_expanded = np.expand_dims(img, axis=0) + (boxes, scores, classes, num) = self.sess.run( + [self.d_boxes, self.d_scores, 
self.d_classes, self.num_d],
+            feed_dict={self.image_tensor: img_expanded})
+
+    async def run(self, input_data, draw_boxes, predict_batch):
+        image_path = '/main/' + str(input_data.filename)
+        with open(image_path, 'wb') as image_file:
+            image_file.write(input_data.file.read())
+        try:
+            post_process = await self.processing(image_path, predict_batch)
+        except ApplicationError as e:
+            os.remove(image_path)
+            raise e
+        except Exception:
+            os.remove(image_path)
+            raise InvalidInputData()
+        if not draw_boxes:
+            os.remove(image_path)
+            return post_process
+        else:
+            try:
+                self.draw_bounding_boxes(input_data, post_process['bounding-boxes'])
+            except ApplicationError as e:
+                raise e
+            except Exception as e:
+                raise e
+
+    async def run_batch(self, input_data, draw_boxes, predict_batch):
+        result_list = []
+        for image in input_data:
+            post_process = await self.run(image, draw_boxes, predict_batch)
+            if post_process is not None:
+                result_list.append(post_process)
+        return result_list
+
+    def get_classification(self, img):
+        """
+        Runs the image through the detection graph and returns the raw tensors.
+        :param img: Processed image
+        :return: Tensors to form a prediction
+        """
+        # Bounding Box Detection.
+        with self.detection_graph.as_default():
+            # Expand dimensions since the model expects images to have shape [1, None, None, 3].
+            img_expanded = np.expand_dims(img, axis=0)
+            (boxes, scores, classes, num) = self.sess.run(
+                [self.d_boxes, self.d_scores, self.d_classes, self.num_d],
+                feed_dict={self.image_tensor: img_expanded})
+            classes_names = [self.category_index.get(i) for i in classes[0]]
+        return boxes, scores, classes, classes_names, num
+
+    async def processing(self, image_path, predict_batch):
+        """
+        Runs inference on an image and formats the prediction response.
+        :param predict_batch: Boolean
+        :param image_path: Image path
+        :return: Image prediction
+        """
+        # yield control briefly so the event loop can serve other requests
+        await asyncio.sleep(0.00001)
+        try:
+            with open(self.model_path + '/config.json') as f:
+                data = json.load(f)
+        except Exception:
+            raise InvalidModelConfiguration('config.json not found or corrupted')
+
+        json_confidence = data['confidence']
+        json_predictions = data['predictions']
+        image = Image.open(image_path).convert('RGB')
+        (boxes, scores, classes, classes_names, num) = self.get_classification(image)
+        names_start = []
+        for name in classes_names:
+            if name is not None:
+                names_start.append(name['name'])
+
+        width, height = image.size
+
+        names = []
+        confidence = []
+        ids = []
+        bounding_boxes = []
+
+        for i in range(json_predictions):
+            if scores[0][i] * 100 >= json_confidence:
+                # scale the normalized coordinates (y by height, x by width) and clamp them at 0;
+                # the x-coordinates were previously clamped against height by mistake
+                ymin = max(0, int(round(boxes[0][i][0] * height)))
+                xmin = max(0, int(round(boxes[0][i][1] * width)))
+                ymax = max(0, int(round(boxes[0][i][2] * height)))
+                xmax = max(0, int(round(boxes[0][i][3] * width)))
+                tmp = dict([('left', xmin), ('top', ymin), ('right', xmax), ('bottom', ymax)])
+                bounding_boxes.append(tmp)
+                confidence.append(float(scores[0][i] * 100))
+                ids.append(int(classes[0][i]))
+                names.append(names_start[i])
+
+        responses_list = zip(names, confidence, bounding_boxes, ids)
+
+        output = []
+
+        for response in responses_list:
+            tmp = dict([('ObjectClassName', response[0]), ('confidence', response[1]), ('coordinates', response[2]),
+                        ('ObjectClassId', response[3])])
+            output.append(tmp)
+        if predict_batch:
+            results = dict([('bounding-boxes', output), ('ImageName',
image_path.split('/')[2])]) + else: + results = dict([('bounding-boxes', output)]) + return results + + def draw_bounding_boxes(self, input_data, bboxes): + """ + Draws bounding boxes on image and saves it. + :param input_data: A single image + :param bboxes: Bounding boxes + :return: + """ + left = 0 + top = 0 + conf = 0 + # image_path = '/main/result.jpg' + image_path = '/main/' + str(input_data.filename) + # open(image_path, 'wb').write(input_data.file.read()) + image = Image.open(image_path) + draw = ImageDraw.Draw(image) + for bbox in bboxes: + draw.rectangle([bbox['coordinates']['left'], bbox['coordinates']['top'], bbox['coordinates']['right'], + bbox['coordinates']['bottom']], outline="red") + left = bbox['coordinates']['left'] + top = bbox['coordinates']['top'] + conf = "{0:.2f}".format(bbox['confidence']) + draw.text((int(left), int(top) - 20), str(conf) + "% " + str(bbox['ObjectClassName']), 'red', self.font) + os.remove(image_path) + image.save('/main/result.jpg', 'PNG') + + + def free(self): + pass + + def validate_configuration(self): + # check if weights file exists + if not os.path.exists(os.path.join(self.model_path, 'frozen_inference_graph.pb')): + raise InvalidModelConfiguration('frozen_inference_graph.pb not found') + # check if labels file exists + if not os.path.exists(os.path.join(self.model_path, 'object-detection.pbtxt')): + raise InvalidModelConfiguration('object-detection.pbtxt not found') + return True + + def set_configuration(self, data): + self.configuration['framework'] = data['framework'] + self.configuration['type'] = data['type'] + self.configuration['network'] = data['network'] + self.NUM_CLASSES = data['number_of_classes'] + + def validate_json_configuration(self, data): + with open(os.path.join('inference', 'ConfigurationSchema.json')) as f: + schema = json.load(f) + try: + jsonschema.validate(data, schema) + except Exception as e: + raise InvalidModelConfiguration(e) + diff --git a/src/main/model_hash.json b/src/main/model_hash.json new file mode 100644 index 0000000..a366785 --- /dev/null +++ b/src/main/model_hash.json @@ -0,0 +1 @@ +{"schriftzug5": "6ca289a2-2a93-4ceb-ae0b-527ca1ce8ed9"} \ No newline at end of file diff --git a/src/main/models.py b/src/main/models.py new file mode 100644 index 0000000..cc58851 --- /dev/null +++ b/src/main/models.py @@ -0,0 +1,12 @@ +class ApiResponse: + + def __init__(self, success=True, data=None, error=None): + """ + Defines the response shape + :param success: A boolean that returns if the request has succeeded or not + :param data: The model's response + :param error: The error in case an exception was raised + """ + self.data = data + self.error = error.get_message() if error is not None else '' + self.success = success diff --git a/src/main/object_detection/core/__init__.py b/src/main/object_detection/core/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/main/object_detection/core/__init__.py @@ -0,0 +1 @@ + diff --git a/src/main/object_detection/core/box_list.py b/src/main/object_detection/core/box_list.py new file mode 100644 index 0000000..c0196f0 --- /dev/null +++ b/src/main/object_detection/core/box_list.py @@ -0,0 +1,207 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Bounding Box List definition. + +BoxList represents a list of bounding boxes as tensorflow +tensors, where each bounding box is represented as a row of 4 numbers, +[y_min, x_min, y_max, x_max]. It is assumed that all bounding boxes +within a given list correspond to a single image. See also +box_list_ops.py for common box related operations (such as area, iou, etc). + +Optionally, users can add additional related fields (such as weights). +We assume the following things to be true about fields: +* they correspond to boxes in the box_list along the 0th dimension +* they have inferrable rank at graph construction time +* all dimensions except for possibly the 0th can be inferred + (i.e., not None) at graph construction time. + +Some other notes: + * Following tensorflow conventions, we use height, width ordering, + and correspondingly, y,x (or ymin, xmin, ymax, xmax) ordering + * Tensors are always provided as (flat) [N, 4] tensors. +""" + +import tensorflow as tf + + +class BoxList(object): + """Box collection.""" + + def __init__(self, boxes): + """Constructs box collection. + + Args: + boxes: a tensor of shape [N, 4] representing box corners + + Raises: + ValueError: if invalid dimensions for bbox data or if bbox data is not in + float32 format. + """ + if len(boxes.get_shape()) != 2 or boxes.get_shape()[-1] != 4: + raise ValueError('Invalid dimensions for box data.') + if boxes.dtype != tf.float32: + raise ValueError('Invalid tensor type: should be tf.float32') + self.data = {'boxes': boxes} + + def num_boxes(self): + """Returns number of boxes held in collection. + + Returns: + a tensor representing the number of boxes held in the collection. + """ + return tf.shape(self.data['boxes'])[0] + + def num_boxes_static(self): + """Returns number of boxes held in collection. + + This number is inferred at graph construction time rather than run-time. + + Returns: + Number of boxes held in collection (integer) or None if this is not + inferrable at graph construction time. + """ + return self.data['boxes'].get_shape()[0].value + + def get_all_fields(self): + """Returns all fields.""" + return self.data.keys() + + def get_extra_fields(self): + """Returns all non-box fields (i.e., everything not named 'boxes').""" + return [k for k in self.data.keys() if k != 'boxes'] + + def add_field(self, field, field_data): + """Add field to box list. + + This method can be used to add related box data such as + weights/labels, etc. + + Args: + field: a string key to access the data via `get` + field_data: a tensor containing the data to store in the BoxList + """ + self.data[field] = field_data + + def has_field(self, field): + return field in self.data + + def get(self): + """Convenience function for accessing box coordinates. + + Returns: + a tensor with shape [N, 4] representing box coordinates. + """ + return self.get_field('boxes') + + def set(self, boxes): + """Convenience function for setting box coordinates. 
+ + Args: + boxes: a tensor of shape [N, 4] representing box corners + + Raises: + ValueError: if invalid dimensions for bbox data + """ + if len(boxes.get_shape()) != 2 or boxes.get_shape()[-1] != 4: + raise ValueError('Invalid dimensions for box data.') + self.data['boxes'] = boxes + + def get_field(self, field): + """Accesses a box collection and associated fields. + + This function returns specified field with object; if no field is specified, + it returns the box coordinates. + + Args: + field: this optional string parameter can be used to specify + a related field to be accessed. + + Returns: + a tensor representing the box collection or an associated field. + + Raises: + ValueError: if invalid field + """ + if not self.has_field(field): + raise ValueError('field ' + str(field) + ' does not exist') + return self.data[field] + + def set_field(self, field, value): + """Sets the value of a field. + + Updates the field of a box_list with a given value. + + Args: + field: (string) name of the field to set value. + value: the value to assign to the field. + + Raises: + ValueError: if the box_list does not have specified field. + """ + if not self.has_field(field): + raise ValueError('field %s does not exist' % field) + self.data[field] = value + + def get_center_coordinates_and_sizes(self, scope=None): + """Computes the center coordinates, height and width of the boxes. + + Args: + scope: name scope of the function. + + Returns: + a list of 4 1-D tensors [ycenter, xcenter, height, width]. + """ + with tf.name_scope(scope, 'get_center_coordinates_and_sizes'): + box_corners = self.get() + ymin, xmin, ymax, xmax = tf.unstack(tf.transpose(box_corners)) + width = xmax - xmin + height = ymax - ymin + ycenter = ymin + height / 2. + xcenter = xmin + width / 2. + return [ycenter, xcenter, height, width] + + def transpose_coordinates(self, scope=None): + """Transpose the coordinate representation in a boxlist. + + Args: + scope: name scope of the function. + """ + with tf.name_scope(scope, 'transpose_coordinates'): + y_min, x_min, y_max, x_max = tf.split( + value=self.get(), num_or_size_splits=4, axis=1) + self.set(tf.concat([x_min, y_min, x_max, y_max], 1)) + + def as_tensor_dict(self, fields=None): + """Retrieves specified fields as a dictionary of tensors. + + Args: + fields: (optional) list of fields to return in the dictionary. + If None (default), all fields are returned. + + Returns: + tensor_dict: A dictionary of tensors specified by fields. + + Raises: + ValueError: if specified field is not contained in boxlist. + """ + tensor_dict = {} + if fields is None: + fields = self.get_all_fields() + for field in fields: + if not self.has_field(field): + raise ValueError('boxlist must contain all specified fields') + tensor_dict[field] = self.get_field(field) + return tensor_dict diff --git a/src/main/object_detection/core/box_list_ops.py b/src/main/object_detection/core/box_list_ops.py new file mode 100644 index 0000000..a755ef6 --- /dev/null +++ b/src/main/object_detection/core/box_list_ops.py @@ -0,0 +1,1061 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Bounding Box List operations.
+
+Example box operations that are supported:
+  * areas: compute bounding box areas
+  * iou: pairwise intersection-over-union scores
+  * sq_dist: pairwise distances between bounding boxes
+
+Whenever box_list_ops functions output a BoxList, the fields of the incoming
+BoxList are retained unless documented otherwise.
+"""
+import tensorflow as tf
+
+from object_detection.core import box_list
+from object_detection.utils import shape_utils
+
+
+class SortOrder(object):
+  """Enum class for sort order.
+
+  Attributes:
+    ascend: ascending order.
+    descend: descending order.
+  """
+  ascend = 1
+  descend = 2
+
+
+def area(boxlist, scope=None):
+  """Computes area of boxes.
+
+  Args:
+    boxlist: BoxList holding N boxes
+    scope: name scope.
+
+  Returns:
+    a tensor with shape [N] representing box areas.
+  """
+  with tf.name_scope(scope, 'Area'):
+    y_min, x_min, y_max, x_max = tf.split(
+        value=boxlist.get(), num_or_size_splits=4, axis=1)
+    return tf.squeeze((y_max - y_min) * (x_max - x_min), [1])
+
+
+def height_width(boxlist, scope=None):
+  """Computes height and width of boxes in boxlist.
+
+  Args:
+    boxlist: BoxList holding N boxes
+    scope: name scope.
+
+  Returns:
+    height: A tensor with shape [N] representing box heights.
+    width: A tensor with shape [N] representing box widths.
+  """
+  with tf.name_scope(scope, 'HeightWidth'):
+    y_min, x_min, y_max, x_max = tf.split(
+        value=boxlist.get(), num_or_size_splits=4, axis=1)
+    return tf.squeeze(y_max - y_min, [1]), tf.squeeze(x_max - x_min, [1])
+
+
+def scale(boxlist, y_scale, x_scale, scope=None):
+  """Scales box coordinates in x and y dimensions.
+
+  Args:
+    boxlist: BoxList holding N boxes
+    y_scale: (float) scalar tensor
+    x_scale: (float) scalar tensor
+    scope: name scope.
+
+  Returns:
+    boxlist: BoxList holding N boxes
+  """
+  with tf.name_scope(scope, 'Scale'):
+    y_scale = tf.cast(y_scale, tf.float32)
+    x_scale = tf.cast(x_scale, tf.float32)
+    y_min, x_min, y_max, x_max = tf.split(
+        value=boxlist.get(), num_or_size_splits=4, axis=1)
+    y_min = y_scale * y_min
+    y_max = y_scale * y_max
+    x_min = x_scale * x_min
+    x_max = x_scale * x_max
+    scaled_boxlist = box_list.BoxList(
+        tf.concat([y_min, x_min, y_max, x_max], 1))
+    return _copy_extra_fields(scaled_boxlist, boxlist)
+
+
+def clip_to_window(boxlist, window, filter_nonoverlapping=True, scope=None):
+  """Clip bounding boxes to a window.
+
+  This op clips any input bounding boxes (represented by bounding box
+  corners) to a window, optionally filtering out boxes that do not
+  overlap at all with the window.
+
+  Args:
+    boxlist: BoxList holding M_in boxes
+    window: a tensor of shape [4] representing the [y_min, x_min, y_max, x_max]
+      window to which the op should clip boxes.
+    filter_nonoverlapping: whether to filter out boxes that do not overlap at
+      all with the window.
+    scope: name scope.
+
+  Returns:
+    a BoxList holding M_out boxes where M_out <= M_in
+  """
+  with tf.name_scope(scope, 'ClipToWindow'):
+    y_min, x_min, y_max, x_max = tf.split(
+        value=boxlist.get(), num_or_size_splits=4, axis=1)
+    win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window)
+    y_min_clipped = tf.maximum(tf.minimum(y_min, win_y_max), win_y_min)
+    y_max_clipped = tf.maximum(tf.minimum(y_max, win_y_max), win_y_min)
+    x_min_clipped = tf.maximum(tf.minimum(x_min, win_x_max), win_x_min)
+    x_max_clipped = tf.maximum(tf.minimum(x_max, win_x_max), win_x_min)
+    clipped = box_list.BoxList(
+        tf.concat([y_min_clipped, x_min_clipped, y_max_clipped, x_max_clipped],
+                  1))
+    clipped = _copy_extra_fields(clipped, boxlist)
+    if filter_nonoverlapping:
+      areas = area(clipped)
+      nonzero_area_indices = tf.cast(
+          tf.reshape(tf.where(tf.greater(areas, 0.0)), [-1]), tf.int32)
+      clipped = gather(clipped, nonzero_area_indices)
+    return clipped
+
+
+def prune_outside_window(boxlist, window, scope=None):
+  """Prunes bounding boxes that fall outside a given window.
+
+  This function prunes bounding boxes that even partially fall outside the given
+  window. See also clip_to_window which only prunes bounding boxes that fall
+  completely outside the window, and clips any bounding boxes that partially
+  overflow.
+
+  Args:
+    boxlist: a BoxList holding M_in boxes.
+    window: a float tensor of shape [4] representing [ymin, xmin, ymax, xmax]
+      of the window
+    scope: name scope.
+
+  Returns:
+    pruned_boxlist: a BoxList holding M_out boxes, where M_out <= M_in.
+    valid_indices: a tensor with shape [M_out] indexing the valid bounding boxes
+      in the input tensor.
+  """
+  with tf.name_scope(scope, 'PruneOutsideWindow'):
+    y_min, x_min, y_max, x_max = tf.split(
+        value=boxlist.get(), num_or_size_splits=4, axis=1)
+    win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window)
+    coordinate_violations = tf.concat([
+        tf.less(y_min, win_y_min), tf.less(x_min, win_x_min),
+        tf.greater(y_max, win_y_max), tf.greater(x_max, win_x_max)
+    ], 1)
+    valid_indices = tf.reshape(
+        tf.where(tf.logical_not(tf.reduce_any(coordinate_violations, 1))), [-1])
+    return gather(boxlist, valid_indices), valid_indices
+
+
+def prune_completely_outside_window(boxlist, window, scope=None):
+  """Prunes bounding boxes that fall completely outside of the given window.
+
+  The function clip_to_window prunes bounding boxes that fall
+  completely outside the window, but also clips any bounding boxes that
+  partially overflow. This function does not clip partially overflowing boxes.
+
+  Args:
+    boxlist: a BoxList holding M_in boxes.
+    window: a float tensor of shape [4] representing [ymin, xmin, ymax, xmax]
+      of the window
+    scope: name scope.
+
+  Returns:
+    pruned_boxlist: a new BoxList with all bounding boxes partially or fully in
+      the window.
+    valid_indices: a tensor with shape [M_out] indexing the valid bounding boxes
+      in the input tensor.
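+
+  An illustrative sketch (the window values are assumptions):
+    window = tf.constant([0., 0., 1., 1.])
+    pruned_boxlist, valid_indices = prune_completely_outside_window(
+        boxlist, window)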
+ """ + with tf.name_scope(scope, 'PruneCompleteleyOutsideWindow'): + y_min, x_min, y_max, x_max = tf.split( + value=boxlist.get(), num_or_size_splits=4, axis=1) + win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window) + coordinate_violations = tf.concat([ + tf.greater_equal(y_min, win_y_max), tf.greater_equal(x_min, win_x_max), + tf.less_equal(y_max, win_y_min), tf.less_equal(x_max, win_x_min) + ], 1) + valid_indices = tf.reshape( + tf.where(tf.logical_not(tf.reduce_any(coordinate_violations, 1))), [-1]) + return gather(boxlist, valid_indices), valid_indices + + +def intersection(boxlist1, boxlist2, scope=None): + """Compute pairwise intersection areas between boxes. + + Args: + boxlist1: BoxList holding N boxes + boxlist2: BoxList holding M boxes + scope: name scope. + + Returns: + a tensor with shape [N, M] representing pairwise intersections + """ + with tf.name_scope(scope, 'Intersection'): + y_min1, x_min1, y_max1, x_max1 = tf.split( + value=boxlist1.get(), num_or_size_splits=4, axis=1) + y_min2, x_min2, y_max2, x_max2 = tf.split( + value=boxlist2.get(), num_or_size_splits=4, axis=1) + all_pairs_min_ymax = tf.minimum(y_max1, tf.transpose(y_max2)) + all_pairs_max_ymin = tf.maximum(y_min1, tf.transpose(y_min2)) + intersect_heights = tf.maximum(0.0, all_pairs_min_ymax - all_pairs_max_ymin) + all_pairs_min_xmax = tf.minimum(x_max1, tf.transpose(x_max2)) + all_pairs_max_xmin = tf.maximum(x_min1, tf.transpose(x_min2)) + intersect_widths = tf.maximum(0.0, all_pairs_min_xmax - all_pairs_max_xmin) + return intersect_heights * intersect_widths + + +def matched_intersection(boxlist1, boxlist2, scope=None): + """Compute intersection areas between corresponding boxes in two boxlists. + + Args: + boxlist1: BoxList holding N boxes + boxlist2: BoxList holding N boxes + scope: name scope. + + Returns: + a tensor with shape [N] representing pairwise intersections + """ + with tf.name_scope(scope, 'MatchedIntersection'): + y_min1, x_min1, y_max1, x_max1 = tf.split( + value=boxlist1.get(), num_or_size_splits=4, axis=1) + y_min2, x_min2, y_max2, x_max2 = tf.split( + value=boxlist2.get(), num_or_size_splits=4, axis=1) + min_ymax = tf.minimum(y_max1, y_max2) + max_ymin = tf.maximum(y_min1, y_min2) + intersect_heights = tf.maximum(0.0, min_ymax - max_ymin) + min_xmax = tf.minimum(x_max1, x_max2) + max_xmin = tf.maximum(x_min1, x_min2) + intersect_widths = tf.maximum(0.0, min_xmax - max_xmin) + return tf.reshape(intersect_heights * intersect_widths, [-1]) + + +def iou(boxlist1, boxlist2, scope=None): + """Computes pairwise intersection-over-union between box collections. + + Args: + boxlist1: BoxList holding N boxes + boxlist2: BoxList holding M boxes + scope: name scope. + + Returns: + a tensor with shape [N, M] representing pairwise iou scores. + """ + with tf.name_scope(scope, 'IOU'): + intersections = intersection(boxlist1, boxlist2) + areas1 = area(boxlist1) + areas2 = area(boxlist2) + unions = ( + tf.expand_dims(areas1, 1) + tf.expand_dims(areas2, 0) - intersections) + return tf.where( + tf.equal(intersections, 0.0), + tf.zeros_like(intersections), tf.truediv(intersections, unions)) + + +def matched_iou(boxlist1, boxlist2, scope=None): + """Compute intersection-over-union between corresponding boxes in boxlists. + + Args: + boxlist1: BoxList holding N boxes + boxlist2: BoxList holding N boxes + scope: name scope. + + Returns: + a tensor with shape [N] representing pairwise iou scores. 
+ """ + with tf.name_scope(scope, 'MatchedIOU'): + intersections = matched_intersection(boxlist1, boxlist2) + areas1 = area(boxlist1) + areas2 = area(boxlist2) + unions = areas1 + areas2 - intersections + return tf.where( + tf.equal(intersections, 0.0), + tf.zeros_like(intersections), tf.truediv(intersections, unions)) + + +def ioa(boxlist1, boxlist2, scope=None): + """Computes pairwise intersection-over-area between box collections. + + intersection-over-area (IOA) between two boxes box1 and box2 is defined as + their intersection area over box2's area. Note that ioa is not symmetric, + that is, ioa(box1, box2) != ioa(box2, box1). + + Args: + boxlist1: BoxList holding N boxes + boxlist2: BoxList holding M boxes + scope: name scope. + + Returns: + a tensor with shape [N, M] representing pairwise ioa scores. + """ + with tf.name_scope(scope, 'IOA'): + intersections = intersection(boxlist1, boxlist2) + areas = tf.expand_dims(area(boxlist2), 0) + return tf.truediv(intersections, areas) + + +def prune_non_overlapping_boxes( + boxlist1, boxlist2, min_overlap=0.0, scope=None): + """Prunes the boxes in boxlist1 that overlap less than thresh with boxlist2. + + For each box in boxlist1, we want its IOA to be more than minoverlap with + at least one of the boxes in boxlist2. If it does not, we remove it. + + Args: + boxlist1: BoxList holding N boxes. + boxlist2: BoxList holding M boxes. + min_overlap: Minimum required overlap between boxes, to count them as + overlapping. + scope: name scope. + + Returns: + new_boxlist1: A pruned boxlist with size [N', 4]. + keep_inds: A tensor with shape [N'] indexing kept bounding boxes in the + first input BoxList `boxlist1`. + """ + with tf.name_scope(scope, 'PruneNonOverlappingBoxes'): + ioa_ = ioa(boxlist2, boxlist1) # [M, N] tensor + ioa_ = tf.reduce_max(ioa_, reduction_indices=[0]) # [N] tensor + keep_bool = tf.greater_equal(ioa_, tf.constant(min_overlap)) + keep_inds = tf.squeeze(tf.where(keep_bool), squeeze_dims=[1]) + new_boxlist1 = gather(boxlist1, keep_inds) + return new_boxlist1, keep_inds + + +def prune_small_boxes(boxlist, min_side, scope=None): + """Prunes small boxes in the boxlist which have a side smaller than min_side. + + Args: + boxlist: BoxList holding N boxes. + min_side: Minimum width AND height of box to survive pruning. + scope: name scope. + + Returns: + A pruned boxlist. + """ + with tf.name_scope(scope, 'PruneSmallBoxes'): + height, width = height_width(boxlist) + is_valid = tf.logical_and(tf.greater_equal(width, min_side), + tf.greater_equal(height, min_side)) + return gather(boxlist, tf.reshape(tf.where(is_valid), [-1])) + + +def change_coordinate_frame(boxlist, window, scope=None): + """Change coordinate frame of the boxlist to be relative to window's frame. + + Given a window of the form [ymin, xmin, ymax, xmax], + changes bounding box coordinates from boxlist to be relative to this window + (e.g., the min corner maps to (0,0) and the max corner maps to (1,1)). + + An example use case is data augmentation: where we are given groundtruth + boxes (boxlist) and would like to randomly crop the image to some + window (window). In this case we need to change the coordinate frame of + each groundtruth box to be relative to this new window. + + Args: + boxlist: A BoxList object holding N boxes. + window: A rank 1 tensor [4]. + scope: name scope. + + Returns: + Returns a BoxList object with N boxes. 
+ """ + with tf.name_scope(scope, 'ChangeCoordinateFrame'): + win_height = window[2] - window[0] + win_width = window[3] - window[1] + boxlist_new = scale(box_list.BoxList( + boxlist.get() - [window[0], window[1], window[0], window[1]]), + 1.0 / win_height, 1.0 / win_width) + boxlist_new = _copy_extra_fields(boxlist_new, boxlist) + return boxlist_new + + +def sq_dist(boxlist1, boxlist2, scope=None): + """Computes the pairwise squared distances between box corners. + + This op treats each box as if it were a point in a 4d Euclidean space and + computes pairwise squared distances. + + Mathematically, we are given two matrices of box coordinates X and Y, + where X(i,:) is the i'th row of X, containing the 4 numbers defining the + corners of the i'th box in boxlist1. Similarly Y(j,:) corresponds to + boxlist2. We compute + Z(i,j) = ||X(i,:) - Y(j,:)||^2 + = ||X(i,:)||^2 + ||Y(j,:)||^2 - 2 X(i,:)' * Y(j,:), + + Args: + boxlist1: BoxList holding N boxes + boxlist2: BoxList holding M boxes + scope: name scope. + + Returns: + a tensor with shape [N, M] representing pairwise distances + """ + with tf.name_scope(scope, 'SqDist'): + sqnorm1 = tf.reduce_sum(tf.square(boxlist1.get()), 1, keep_dims=True) + sqnorm2 = tf.reduce_sum(tf.square(boxlist2.get()), 1, keep_dims=True) + innerprod = tf.matmul(boxlist1.get(), boxlist2.get(), + transpose_a=False, transpose_b=True) + return sqnorm1 + tf.transpose(sqnorm2) - 2.0 * innerprod + + +def boolean_mask(boxlist, indicator, fields=None, scope=None): + """Select boxes from BoxList according to indicator and return new BoxList. + + `boolean_mask` returns the subset of boxes that are marked as "True" by the + indicator tensor. By default, `boolean_mask` returns boxes corresponding to + the input index list, as well as all additional fields stored in the boxlist + (indexing into the first dimension). However one can optionally only draw + from a subset of fields. + + Args: + boxlist: BoxList holding N boxes + indicator: a rank-1 boolean tensor + fields: (optional) list of fields to also gather from. If None (default), + all fields are gathered from. Pass an empty fields list to only gather + the box coordinates. + scope: name scope. + + Returns: + subboxlist: a BoxList corresponding to the subset of the input BoxList + specified by indicator + Raises: + ValueError: if `indicator` is not a rank-1 boolean tensor. + """ + with tf.name_scope(scope, 'BooleanMask'): + if indicator.shape.ndims != 1: + raise ValueError('indicator should have rank 1') + if indicator.dtype != tf.bool: + raise ValueError('indicator should be a boolean tensor') + subboxlist = box_list.BoxList(tf.boolean_mask(boxlist.get(), indicator)) + if fields is None: + fields = boxlist.get_extra_fields() + for field in fields: + if not boxlist.has_field(field): + raise ValueError('boxlist must contain all specified fields') + subfieldlist = tf.boolean_mask(boxlist.get_field(field), indicator) + subboxlist.add_field(field, subfieldlist) + return subboxlist + + +def gather(boxlist, indices, fields=None, scope=None): + """Gather boxes from BoxList according to indices and return new BoxList. + + By default, `gather` returns boxes corresponding to the input index list, as + well as all additional fields stored in the boxlist (indexing into the + first dimension). However one can optionally only gather from a + subset of fields. + + Args: + boxlist: BoxList holding N boxes + indices: a rank-1 tensor of type int32 / int64 + fields: (optional) list of fields to also gather from. 
If None (default),
+      all fields are gathered from. Pass an empty fields list to only gather
+      the box coordinates.
+    scope: name scope.
+
+  Returns:
+    subboxlist: a BoxList corresponding to the subset of the input BoxList
+    specified by indices
+  Raises:
+    ValueError: if specified field is not contained in boxlist or if the
+      indices are not of type int32 or int64
+  """
+  with tf.name_scope(scope, 'Gather'):
+    if len(indices.shape.as_list()) != 1:
+      raise ValueError('indices should have rank 1')
+    if indices.dtype != tf.int32 and indices.dtype != tf.int64:
+      raise ValueError('indices should be an int32 / int64 tensor')
+    subboxlist = box_list.BoxList(tf.gather(boxlist.get(), indices))
+    if fields is None:
+      fields = boxlist.get_extra_fields()
+    for field in fields:
+      if not boxlist.has_field(field):
+        raise ValueError('boxlist must contain all specified fields')
+      subfieldlist = tf.gather(boxlist.get_field(field), indices)
+      subboxlist.add_field(field, subfieldlist)
+    return subboxlist
+
+
+def concatenate(boxlists, fields=None, scope=None):
+  """Concatenate list of BoxLists.
+
+  This op concatenates a list of input BoxLists into a larger BoxList. It also
+  handles concatenation of BoxList fields as long as the field tensor shapes
+  are equal except for the first dimension.
+
+  Args:
+    boxlists: list of BoxList objects
+    fields: optional list of fields to also concatenate. By default, all
+      fields from the first BoxList in the list are included in the
+      concatenation.
+    scope: name scope.
+
+  Returns:
+    a BoxList with number of boxes equal to
+      sum([boxlist.num_boxes() for boxlist in boxlists])
+  Raises:
+    ValueError: if boxlists is invalid (i.e., is not a list, is empty, or
+      contains non BoxList objects), or if requested fields are not contained in
+      all boxlists
+  """
+  with tf.name_scope(scope, 'Concatenate'):
+    if not isinstance(boxlists, list):
+      raise ValueError('boxlists should be a list')
+    if not boxlists:
+      raise ValueError('boxlists should have nonzero length')
+    for boxlist in boxlists:
+      if not isinstance(boxlist, box_list.BoxList):
+        raise ValueError('all elements of boxlists should be BoxList objects')
+    concatenated = box_list.BoxList(
+        tf.concat([boxlist.get() for boxlist in boxlists], 0))
+    if fields is None:
+      fields = boxlists[0].get_extra_fields()
+    for field in fields:
+      first_field_shape = boxlists[0].get_field(field).get_shape().as_list()
+      first_field_shape[0] = -1
+      if None in first_field_shape:
+        raise ValueError('field %s must have fully defined shape except for the'
+                         ' 0th dimension.' % field)
+      for boxlist in boxlists:
+        if not boxlist.has_field(field):
+          raise ValueError('boxlist must contain all requested fields')
+        field_shape = boxlist.get_field(field).get_shape().as_list()
+        field_shape[0] = -1
+        if field_shape != first_field_shape:
+          raise ValueError('field %s must have same shape for all boxlists '
+                           'except for the 0th dimension.' % field)
+      concatenated_field = tf.concat(
+          [boxlist.get_field(field) for boxlist in boxlists], 0)
+      concatenated.add_field(field, concatenated_field)
+    return concatenated
+
+
+def sort_by_field(boxlist, field, order=SortOrder.descend, scope=None):
+  """Sort boxes and associated fields according to a scalar field.
+
+  A common use case is reordering the boxes according to descending scores.
+
+  Args:
+    boxlist: BoxList holding N boxes.
+    field: A BoxList field for sorting and reordering the BoxList.
+    order: (Optional) descend or ascend. Default is descend.
+    scope: name scope.
+ + Returns: + sorted_boxlist: A sorted BoxList with the field in the specified order. + + Raises: + ValueError: if specified field does not exist + ValueError: if the order is not either descend or ascend + """ + with tf.name_scope(scope, 'SortByField'): + if order != SortOrder.descend and order != SortOrder.ascend: + raise ValueError('Invalid sort order') + + field_to_sort = boxlist.get_field(field) + if len(field_to_sort.shape.as_list()) != 1: + raise ValueError('Field should have rank 1') + + num_boxes = boxlist.num_boxes() + num_entries = tf.size(field_to_sort) + length_assert = tf.Assert( + tf.equal(num_boxes, num_entries), + ['Incorrect field size: actual vs expected.', num_entries, num_boxes]) + + with tf.control_dependencies([length_assert]): + # TODO(derekjchow): Remove with tf.device when top_k operation runs + # correctly on GPU. + with tf.device('/cpu:0'): + _, sorted_indices = tf.nn.top_k(field_to_sort, num_boxes, sorted=True) + + if order == SortOrder.ascend: + sorted_indices = tf.reverse_v2(sorted_indices, [0]) + + return gather(boxlist, sorted_indices) + + +def visualize_boxes_in_image(image, boxlist, normalized=False, scope=None): + """Overlay bounding box list on image. + + Currently this visualization plots a 1 pixel thick red bounding box on top + of the image. Note that tf.image.draw_bounding_boxes essentially is + 1 indexed. + + Args: + image: an image tensor with shape [height, width, 3] + boxlist: a BoxList + normalized: (boolean) specify whether corners are to be interpreted + as absolute coordinates in image space or normalized with respect to the + image size. + scope: name scope. + + Returns: + image_and_boxes: an image tensor with shape [height, width, 3] + """ + with tf.name_scope(scope, 'VisualizeBoxesInImage'): + if not normalized: + height, width, _ = tf.unstack(tf.shape(image)) + boxlist = scale(boxlist, + 1.0 / tf.cast(height, tf.float32), + 1.0 / tf.cast(width, tf.float32)) + corners = tf.expand_dims(boxlist.get(), 0) + image = tf.expand_dims(image, 0) + return tf.squeeze(tf.image.draw_bounding_boxes(image, corners), [0]) + + +def filter_field_value_equals(boxlist, field, value, scope=None): + """Filter to keep only boxes with field entries equal to the given value. + + Args: + boxlist: BoxList holding N boxes. + field: field name for filtering. + value: scalar value. + scope: name scope. + + Returns: + a BoxList holding M boxes where M <= N + + Raises: + ValueError: if boxlist not a BoxList object or if it does not have + the specified field. + """ + with tf.name_scope(scope, 'FilterFieldValueEquals'): + if not isinstance(boxlist, box_list.BoxList): + raise ValueError('boxlist must be a BoxList') + if not boxlist.has_field(field): + raise ValueError('boxlist must contain the specified field') + filter_field = boxlist.get_field(field) + gather_index = tf.reshape(tf.where(tf.equal(filter_field, value)), [-1]) + return gather(boxlist, gather_index) + + +def filter_greater_than(boxlist, thresh, scope=None): + """Filter to keep only boxes with score exceeding a given threshold. + + This op keeps the collection of boxes whose corresponding scores are + greater than the input threshold. + + TODO(jonathanhuang): Change function name to filter_scores_greater_than + + Args: + boxlist: BoxList holding N boxes. Must contain a 'scores' field + representing detection scores. + thresh: scalar threshold + scope: name scope. 
+
+  Returns:
+    a BoxList holding M boxes where M <= N
+
+  Raises:
+    ValueError: if boxlist not a BoxList object or if it does not
+      have a scores field
+  """
+  with tf.name_scope(scope, 'FilterGreaterThan'):
+    if not isinstance(boxlist, box_list.BoxList):
+      raise ValueError('boxlist must be a BoxList')
+    if not boxlist.has_field('scores'):
+      raise ValueError('input boxlist must have \'scores\' field')
+    scores = boxlist.get_field('scores')
+    if len(scores.shape.as_list()) > 2:
+      raise ValueError('Scores should have rank 1 or 2')
+    if len(scores.shape.as_list()) == 2 and scores.shape.as_list()[1] != 1:
+      raise ValueError('Scores should have rank 1 or have shape '
+                       'consistent with [None, 1]')
+    high_score_indices = tf.cast(tf.reshape(
+        tf.where(tf.greater(scores, thresh)),
+        [-1]), tf.int32)
+    return gather(boxlist, high_score_indices)
+
+
+def non_max_suppression(boxlist, thresh, max_output_size, scope=None):
+  """Non maximum suppression.
+
+  This op greedily selects a subset of detection bounding boxes, pruning
+  away boxes that have high IOU (intersection over union) overlap (> thresh)
+  with already selected boxes. Note that this only works for a single class ---
+  to apply NMS to multi-class predictions, use MultiClassNonMaxSuppression.
+
+  Args:
+    boxlist: BoxList holding N boxes. Must contain a 'scores' field
+      representing detection scores.
+    thresh: scalar threshold
+    max_output_size: maximum number of retained boxes
+    scope: name scope.
+
+  Returns:
+    a BoxList holding M boxes where M <= max_output_size
+  Raises:
+    ValueError: if thresh is not in [0, 1]
+  """
+  with tf.name_scope(scope, 'NonMaxSuppression'):
+    if not 0 <= thresh <= 1.0:
+      raise ValueError('thresh must be between 0 and 1')
+    if not isinstance(boxlist, box_list.BoxList):
+      raise ValueError('boxlist must be a BoxList')
+    if not boxlist.has_field('scores'):
+      raise ValueError('input boxlist must have \'scores\' field')
+    selected_indices = tf.image.non_max_suppression(
+        boxlist.get(), boxlist.get_field('scores'),
+        max_output_size, iou_threshold=thresh)
+    return gather(boxlist, selected_indices)
+
+
+def _copy_extra_fields(boxlist_to_copy_to, boxlist_to_copy_from):
+  """Copies the extra fields of boxlist_to_copy_from to boxlist_to_copy_to.
+
+  Args:
+    boxlist_to_copy_to: BoxList to which extra fields are copied.
+    boxlist_to_copy_from: BoxList from which fields are copied.
+
+  Returns:
+    boxlist_to_copy_to with extra fields.
+  """
+  for field in boxlist_to_copy_from.get_extra_fields():
+    boxlist_to_copy_to.add_field(field, boxlist_to_copy_from.get_field(field))
+  return boxlist_to_copy_to
+
+
+def to_normalized_coordinates(boxlist, height, width,
+                              check_range=True, scope=None):
+  """Converts absolute box coordinates to normalized coordinates in [0, 1].
+
+  Usually one uses the dynamic shape of the image or conv-layer tensor:
+    boxlist = box_list_ops.to_normalized_coordinates(boxlist,
+                                                     tf.shape(images)[1],
+                                                     tf.shape(images)[2])
+
+  This function raises an assertion failed error at graph execution time when
+  the maximum coordinate is smaller than 1.01 (which means that coordinates are
+  already normalized). The value 1.01 is to deal with small rounding errors.
+
+  Args:
+    boxlist: BoxList with coordinates in terms of pixel-locations.
+    height: Maximum value for height of absolute box coordinates.
+    width: Maximum value for width of absolute box coordinates.
+    check_range: If True, checks if the coordinates are normalized or not.
+    scope: name scope.
+
+  Returns:
+    boxlist with normalized coordinates in [0, 1].
+ """ + with tf.name_scope(scope, 'ToNormalizedCoordinates'): + height = tf.cast(height, tf.float32) + width = tf.cast(width, tf.float32) + + if check_range: + max_val = tf.reduce_max(boxlist.get()) + max_assert = tf.Assert(tf.greater(max_val, 1.01), + ['max value is lower than 1.01: ', max_val]) + with tf.control_dependencies([max_assert]): + width = tf.identity(width) + + return scale(boxlist, 1 / height, 1 / width) + + +def to_absolute_coordinates(boxlist, + height, + width, + check_range=True, + maximum_normalized_coordinate=1.1, + scope=None): + """Converts normalized box coordinates to absolute pixel coordinates. + + This function raises an assertion failed error when the maximum box coordinate + value is larger than maximum_normalized_coordinate (in which case coordinates + are already absolute). + + Args: + boxlist: BoxList with coordinates in range [0, 1]. + height: Maximum value for height of absolute box coordinates. + width: Maximum value for width of absolute box coordinates. + check_range: If True, checks if the coordinates are normalized or not. + maximum_normalized_coordinate: Maximum coordinate value to be considered + as normalized, default to 1.1. + scope: name scope. + + Returns: + boxlist with absolute coordinates in terms of the image size. + + """ + with tf.name_scope(scope, 'ToAbsoluteCoordinates'): + height = tf.cast(height, tf.float32) + width = tf.cast(width, tf.float32) + + # Ensure range of input boxes is correct. + if check_range: + box_maximum = tf.reduce_max(boxlist.get()) + max_assert = tf.Assert( + tf.greater_equal(maximum_normalized_coordinate, box_maximum), + ['maximum box coordinate value is larger ' + 'than %f: ' % maximum_normalized_coordinate, box_maximum]) + with tf.control_dependencies([max_assert]): + width = tf.identity(width) + + return scale(boxlist, height, width) + + +def refine_boxes_multi_class(pool_boxes, + num_classes, + nms_iou_thresh, + nms_max_detections, + voting_iou_thresh=0.5): + """Refines a pool of boxes using non max suppression and box voting. + + Box refinement is done independently for each class. + + Args: + pool_boxes: (BoxList) A collection of boxes to be refined. pool_boxes must + have a rank 1 'scores' field and a rank 1 'classes' field. + num_classes: (int scalar) Number of classes. + nms_iou_thresh: (float scalar) iou threshold for non max suppression (NMS). + nms_max_detections: (int scalar) maximum output size for NMS. + voting_iou_thresh: (float scalar) iou threshold for box voting. + + Returns: + BoxList of refined boxes. + + Raises: + ValueError: if + a) nms_iou_thresh or voting_iou_thresh is not in [0, 1]. + b) pool_boxes is not a BoxList. + c) pool_boxes does not have a scores and classes field. 
+ """ + if not 0.0 <= nms_iou_thresh <= 1.0: + raise ValueError('nms_iou_thresh must be between 0 and 1') + if not 0.0 <= voting_iou_thresh <= 1.0: + raise ValueError('voting_iou_thresh must be between 0 and 1') + if not isinstance(pool_boxes, box_list.BoxList): + raise ValueError('pool_boxes must be a BoxList') + if not pool_boxes.has_field('scores'): + raise ValueError('pool_boxes must have a \'scores\' field') + if not pool_boxes.has_field('classes'): + raise ValueError('pool_boxes must have a \'classes\' field') + + refined_boxes = [] + for i in range(num_classes): + boxes_class = filter_field_value_equals(pool_boxes, 'classes', i) + refined_boxes_class = refine_boxes(boxes_class, nms_iou_thresh, + nms_max_detections, voting_iou_thresh) + refined_boxes.append(refined_boxes_class) + return sort_by_field(concatenate(refined_boxes), 'scores') + + +def refine_boxes(pool_boxes, + nms_iou_thresh, + nms_max_detections, + voting_iou_thresh=0.5): + """Refines a pool of boxes using non max suppression and box voting. + + Args: + pool_boxes: (BoxList) A collection of boxes to be refined. pool_boxes must + have a rank 1 'scores' field. + nms_iou_thresh: (float scalar) iou threshold for non max suppression (NMS). + nms_max_detections: (int scalar) maximum output size for NMS. + voting_iou_thresh: (float scalar) iou threshold for box voting. + + Returns: + BoxList of refined boxes. + + Raises: + ValueError: if + a) nms_iou_thresh or voting_iou_thresh is not in [0, 1]. + b) pool_boxes is not a BoxList. + c) pool_boxes does not have a scores field. + """ + if not 0.0 <= nms_iou_thresh <= 1.0: + raise ValueError('nms_iou_thresh must be between 0 and 1') + if not 0.0 <= voting_iou_thresh <= 1.0: + raise ValueError('voting_iou_thresh must be between 0 and 1') + if not isinstance(pool_boxes, box_list.BoxList): + raise ValueError('pool_boxes must be a BoxList') + if not pool_boxes.has_field('scores'): + raise ValueError('pool_boxes must have a \'scores\' field') + + nms_boxes = non_max_suppression( + pool_boxes, nms_iou_thresh, nms_max_detections) + return box_voting(nms_boxes, pool_boxes, voting_iou_thresh) + + +def box_voting(selected_boxes, pool_boxes, iou_thresh=0.5): + """Performs box voting as described in S. Gidaris and N. Komodakis, ICCV 2015. + + Performs box voting as described in 'Object detection via a multi-region & + semantic segmentation-aware CNN model', Gidaris and Komodakis, ICCV 2015. For + each box 'B' in selected_boxes, we find the set 'S' of boxes in pool_boxes + with iou overlap >= iou_thresh. The location of B is set to the weighted + average location of boxes in S (scores are used for weighting). And the score + of B is set to the average score of boxes in S. + + Args: + selected_boxes: BoxList containing a subset of boxes in pool_boxes. These + boxes are usually selected from pool_boxes using non max suppression. + pool_boxes: BoxList containing a set of (possibly redundant) boxes. + iou_thresh: (float scalar) iou threshold for matching boxes in + selected_boxes and pool_boxes. + + Returns: + BoxList containing averaged locations and scores for each box in + selected_boxes. + + Raises: + ValueError: if + a) selected_boxes or pool_boxes is not a BoxList. + b) if iou_thresh is not in [0, 1]. + c) pool_boxes does not have a scores field. 
+ """ + if not 0.0 <= iou_thresh <= 1.0: + raise ValueError('iou_thresh must be between 0 and 1') + if not isinstance(selected_boxes, box_list.BoxList): + raise ValueError('selected_boxes must be a BoxList') + if not isinstance(pool_boxes, box_list.BoxList): + raise ValueError('pool_boxes must be a BoxList') + if not pool_boxes.has_field('scores'): + raise ValueError('pool_boxes must have a \'scores\' field') + + iou_ = iou(selected_boxes, pool_boxes) + match_indicator = tf.to_float(tf.greater(iou_, iou_thresh)) + num_matches = tf.reduce_sum(match_indicator, 1) + # TODO(kbanoop): Handle the case where some boxes in selected_boxes do not + # match to any boxes in pool_boxes. For such boxes without any matches, we + # should return the original boxes without voting. + match_assert = tf.Assert( + tf.reduce_all(tf.greater(num_matches, 0)), + ['Each box in selected_boxes must match with at least one box ' + 'in pool_boxes.']) + + scores = tf.expand_dims(pool_boxes.get_field('scores'), 1) + scores_assert = tf.Assert( + tf.reduce_all(tf.greater_equal(scores, 0)), + ['Scores must be non negative.']) + + with tf.control_dependencies([scores_assert, match_assert]): + sum_scores = tf.matmul(match_indicator, scores) + averaged_scores = tf.reshape(sum_scores, [-1]) / num_matches + + box_locations = tf.matmul(match_indicator, + pool_boxes.get() * scores) / sum_scores + averaged_boxes = box_list.BoxList(box_locations) + _copy_extra_fields(averaged_boxes, selected_boxes) + averaged_boxes.add_field('scores', averaged_scores) + return averaged_boxes + + +def pad_or_clip_box_list(boxlist, num_boxes, scope=None): + """Pads or clips all fields of a BoxList. + + Args: + boxlist: A BoxList with arbitrary of number of boxes. + num_boxes: First num_boxes in boxlist are kept. + The fields are zero-padded if num_boxes is bigger than the + actual number of boxes. + scope: name scope. + + Returns: + BoxList with all fields padded or clipped. + """ + with tf.name_scope(scope, 'PadOrClipBoxList'): + subboxlist = box_list.BoxList(shape_utils.pad_or_clip_tensor( + boxlist.get(), num_boxes)) + for field in boxlist.get_extra_fields(): + subfield = shape_utils.pad_or_clip_tensor( + boxlist.get_field(field), num_boxes) + subboxlist.add_field(field, subfield) + return subboxlist + + +def select_random_box(boxlist, + default_box=None, + seed=None, + scope=None): + """Selects a random bounding box from a `BoxList`. + + Args: + boxlist: A BoxList. + default_box: A [1, 4] float32 tensor. If no boxes are present in `boxlist`, + this default box will be returned. If None, will use a default box of + [[-1., -1., -1., -1.]]. + seed: Random seed. + scope: Name scope. + + Returns: + bbox: A [1, 4] tensor with a random bounding box. + valid: A bool tensor indicating whether a valid bounding box is returned + (True) or whether the default box is returned (False). 
+ """ + with tf.name_scope(scope, 'SelectRandomBox'): + bboxes = boxlist.get() + combined_shape = shape_utils.combined_static_and_dynamic_shape(bboxes) + number_of_boxes = combined_shape[0] + default_box = default_box or tf.constant([[-1., -1., -1., -1.]]) + + def select_box(): + random_index = tf.random_uniform([], + maxval=number_of_boxes, + dtype=tf.int32, + seed=seed) + return tf.expand_dims(bboxes[random_index], axis=0), tf.constant(True) + + return tf.cond( + tf.greater_equal(number_of_boxes, 1), + true_fn=select_box, + false_fn=lambda: (default_box, tf.constant(False))) + + +def get_minimal_coverage_box(boxlist, + default_box=None, + scope=None): + """Creates a single bounding box which covers all boxes in the boxlist. + + Args: + boxlist: A Boxlist. + default_box: A [1, 4] float32 tensor. If no boxes are present in `boxlist`, + this default box will be returned. If None, will use a default box of + [[0., 0., 1., 1.]]. + scope: Name scope. + + Returns: + A [1, 4] float32 tensor with a bounding box that tightly covers all the + boxes in the box list. If the boxlist does not contain any boxes, the + default box is returned. + """ + with tf.name_scope(scope, 'CreateCoverageBox'): + num_boxes = boxlist.num_boxes() + + def coverage_box(bboxes): + y_min, x_min, y_max, x_max = tf.split( + value=bboxes, num_or_size_splits=4, axis=1) + y_min_coverage = tf.reduce_min(y_min, axis=0) + x_min_coverage = tf.reduce_min(x_min, axis=0) + y_max_coverage = tf.reduce_max(y_max, axis=0) + x_max_coverage = tf.reduce_max(x_max, axis=0) + return tf.stack( + [y_min_coverage, x_min_coverage, y_max_coverage, x_max_coverage], + axis=1) + + default_box = default_box or tf.constant([[0., 0., 1., 1.]]) + return tf.cond( + tf.greater_equal(num_boxes, 1), + true_fn=lambda: coverage_box(boxlist.get()), + false_fn=lambda: default_box) diff --git a/src/main/object_detection/core/standard_fields.py b/src/main/object_detection/core/standard_fields.py new file mode 100644 index 0000000..11282da --- /dev/null +++ b/src/main/object_detection/core/standard_fields.py @@ -0,0 +1,227 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Contains classes specifying naming conventions used for object detection. + + +Specifies: + InputDataFields: standard fields used by reader/preprocessor/batcher. + DetectionResultFields: standard fields returned by object detector. + BoxListFields: standard field used by BoxList + TfExampleFields: standard fields for tf-example data format (go/tf-example). +""" + + +class InputDataFields(object): + """Names for the input tensors. + + Holds the standard data field names to use for identifying input tensors. This + should be used by the decoder to identify keys for the returned tensor_dict + containing input tensors. And it should be used by the model to identify the + tensors it needs. + + Attributes: + image: image. 
+    image_additional_channels: additional channels.
+    original_image: image in the original input size.
+    key: unique key corresponding to image.
+    source_id: source of the original image.
+    filename: original filename of the dataset (without common path).
+    groundtruth_image_classes: image-level class labels.
+    groundtruth_boxes: coordinates of the ground truth boxes in the image.
+    groundtruth_classes: box-level class labels.
+    groundtruth_label_types: box-level label types (e.g. explicit negative).
+    groundtruth_is_crowd: [DEPRECATED, use groundtruth_group_of instead]
+      is the groundtruth a single object or a crowd.
+    groundtruth_area: area of a groundtruth segment.
+    groundtruth_difficult: is a `difficult` object.
+    groundtruth_group_of: is a `group_of` objects, e.g. multiple objects of the
+      same class, forming a connected group, where instances are heavily
+      occluding each other.
+    proposal_boxes: coordinates of object proposal boxes.
+    proposal_objectness: objectness score of each proposal.
+    groundtruth_instance_masks: ground truth instance masks.
+    groundtruth_instance_boundaries: ground truth instance boundaries.
+    groundtruth_instance_classes: instance mask-level class labels.
+    groundtruth_keypoints: ground truth keypoints.
+    groundtruth_keypoint_visibilities: ground truth keypoint visibilities.
+    groundtruth_label_scores: groundtruth label scores.
+    groundtruth_weights: groundtruth weight factor for bounding boxes.
+    num_groundtruth_boxes: number of groundtruth boxes.
+    true_image_shape: true shape of the image within the resized image, as
+      resized images can be padded with zeros.
+    verified_labels: list of human-verified image-level labels (note that a
+      label can be verified both as positive and negative).
+    multiclass_scores: the label score per class for each box.
+  """
+  image = 'image'
+  image_additional_channels = 'image_additional_channels'
+  original_image = 'original_image'
+  key = 'key'
+  source_id = 'source_id'
+  filename = 'filename'
+  groundtruth_image_classes = 'groundtruth_image_classes'
+  groundtruth_boxes = 'groundtruth_boxes'
+  groundtruth_classes = 'groundtruth_classes'
+  groundtruth_label_types = 'groundtruth_label_types'
+  groundtruth_is_crowd = 'groundtruth_is_crowd'
+  groundtruth_area = 'groundtruth_area'
+  groundtruth_difficult = 'groundtruth_difficult'
+  groundtruth_group_of = 'groundtruth_group_of'
+  proposal_boxes = 'proposal_boxes'
+  proposal_objectness = 'proposal_objectness'
+  groundtruth_instance_masks = 'groundtruth_instance_masks'
+  groundtruth_instance_boundaries = 'groundtruth_instance_boundaries'
+  groundtruth_instance_classes = 'groundtruth_instance_classes'
+  groundtruth_keypoints = 'groundtruth_keypoints'
+  groundtruth_keypoint_visibilities = 'groundtruth_keypoint_visibilities'
+  groundtruth_label_scores = 'groundtruth_label_scores'
+  groundtruth_weights = 'groundtruth_weights'
+  num_groundtruth_boxes = 'num_groundtruth_boxes'
+  true_image_shape = 'true_image_shape'
+  verified_labels = 'verified_labels'
+  multiclass_scores = 'multiclass_scores'
+
+
+class DetectionResultFields(object):
+  """Naming conventions for storing the output of the detector.
+
+  Attributes:
+    source_id: source of the original image.
+    key: unique key corresponding to image.
+    detection_boxes: coordinates of the detection boxes in the image.
+    detection_scores: detection scores for the detection boxes in the image.
+    detection_classes: detection-level class labels.
+    detection_masks: contains a segmentation mask for each detection box.
+ detection_boundaries: contains an object boundary for each detection box. + detection_keypoints: contains detection keypoints for each detection box. + num_detections: number of detections in the batch. + """ + + source_id = 'source_id' + key = 'key' + detection_boxes = 'detection_boxes' + detection_scores = 'detection_scores' + detection_classes = 'detection_classes' + detection_masks = 'detection_masks' + detection_boundaries = 'detection_boundaries' + detection_keypoints = 'detection_keypoints' + num_detections = 'num_detections' + + +class BoxListFields(object): + """Naming conventions for BoxLists. + + Attributes: + boxes: bounding box coordinates. + classes: classes per bounding box. + scores: scores per bounding box. + weights: sample weights per bounding box. + objectness: objectness score per bounding box. + masks: masks per bounding box. + boundaries: boundaries per bounding box. + keypoints: keypoints per bounding box. + keypoint_heatmaps: keypoint heatmaps per bounding box. + is_crowd: is_crowd annotation per bounding box. + """ + boxes = 'boxes' + classes = 'classes' + scores = 'scores' + weights = 'weights' + objectness = 'objectness' + masks = 'masks' + boundaries = 'boundaries' + keypoints = 'keypoints' + keypoint_heatmaps = 'keypoint_heatmaps' + is_crowd = 'is_crowd' + + +class TfExampleFields(object): + """TF-example proto feature names for object detection. + + Holds the standard feature names to load from an Example proto for object + detection. + + Attributes: + image_encoded: JPEG encoded string + image_format: image format, e.g. "JPEG" + filename: filename + channels: number of channels of image + colorspace: colorspace, e.g. "RGB" + height: height of image in pixels, e.g. 462 + width: width of image in pixels, e.g. 581 + source_id: original source of the image + image_class_text: image-level label in text format + image_class_label: image-level label in numerical format + object_class_text: labels in text format, e.g. ["person", "cat"] + object_class_label: labels in numbers, e.g. [16, 8] + object_bbox_xmin: xmin coordinates of groundtruth box, e.g. 10, 30 + object_bbox_xmax: xmax coordinates of groundtruth box, e.g. 50, 40 + object_bbox_ymin: ymin coordinates of groundtruth box, e.g. 40, 50 + object_bbox_ymax: ymax coordinates of groundtruth box, e.g. 80, 70 + object_view: viewpoint of object, e.g. ["frontal", "left"] + object_truncated: is object truncated, e.g. [true, false] + object_occluded: is object occluded, e.g. [true, false] + object_difficult: is object difficult, e.g. [true, false] + object_group_of: is object a single object or a group of objects + object_depiction: is object a depiction + object_is_crowd: [DEPRECATED, use object_group_of instead] + is the object a single object or a crowd + object_segment_area: the area of the segment. + object_weight: a weight factor for the object's bounding box. + instance_masks: instance segmentation masks. + instance_boundaries: instance boundaries. + instance_classes: Classes for each instance segmentation mask. + detection_class_label: class label in numbers. + detection_bbox_ymin: ymin coordinates of a detection box. + detection_bbox_xmin: xmin coordinates of a detection box. + detection_bbox_ymax: ymax coordinates of a detection box. + detection_bbox_xmax: xmax coordinates of a detection box. + detection_score: detection score for the class label and box. 
+ """ + image_encoded = 'image/encoded' + image_format = 'image/format' # format is reserved keyword + filename = 'image/filename' + channels = 'image/channels' + colorspace = 'image/colorspace' + height = 'image/height' + width = 'image/width' + source_id = 'image/source_id' + image_class_text = 'image/class/text' + image_class_label = 'image/class/label' + object_class_text = 'image/object/class/text' + object_class_label = 'image/object/class/label' + object_bbox_ymin = 'image/object/bbox/ymin' + object_bbox_xmin = 'image/object/bbox/xmin' + object_bbox_ymax = 'image/object/bbox/ymax' + object_bbox_xmax = 'image/object/bbox/xmax' + object_view = 'image/object/view' + object_truncated = 'image/object/truncated' + object_occluded = 'image/object/occluded' + object_difficult = 'image/object/difficult' + object_group_of = 'image/object/group_of' + object_depiction = 'image/object/depiction' + object_is_crowd = 'image/object/is_crowd' + object_segment_area = 'image/object/segment/area' + object_weight = 'image/object/weight' + instance_masks = 'image/segmentation/object' + instance_boundaries = 'image/boundaries/object' + instance_classes = 'image/segmentation/object/class' + detection_class_label = 'image/detection/label' + detection_bbox_ymin = 'image/detection/bbox/ymin' + detection_bbox_xmin = 'image/detection/bbox/xmin' + detection_bbox_ymax = 'image/detection/bbox/ymax' + detection_bbox_xmax = 'image/detection/bbox/xmax' + detection_score = 'image/detection/score' diff --git a/src/main/object_detection/image1.jpg b/src/main/object_detection/image1.jpg new file mode 100644 index 0000000..8b20d8a Binary files /dev/null and b/src/main/object_detection/image1.jpg differ diff --git a/src/main/object_detection/protos/string_int_label_map_pb2.py b/src/main/object_detection/protos/string_int_label_map_pb2.py new file mode 100644 index 0000000..381d552 --- /dev/null +++ b/src/main/object_detection/protos/string_int_label_map_pb2.py @@ -0,0 +1,123 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: object_detection/protos/string_int_label_map.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='object_detection/protos/string_int_label_map.proto', + package='object_detection.protos', + syntax='proto2', + serialized_pb=_b('\n2object_detection/protos/string_int_label_map.proto\x12\x17object_detection.protos\"G\n\x15StringIntLabelMapItem\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02id\x18\x02 \x01(\x05\x12\x14\n\x0c\x64isplay_name\x18\x03 \x01(\t\"Q\n\x11StringIntLabelMap\x12<\n\x04item\x18\x01 \x03(\x0b\x32..object_detection.protos.StringIntLabelMapItem') +) + + + + +_STRINGINTLABELMAPITEM = _descriptor.Descriptor( + name='StringIntLabelMapItem', + full_name='object_detection.protos.StringIntLabelMapItem', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='object_detection.protos.StringIntLabelMapItem.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='id', full_name='object_detection.protos.StringIntLabelMapItem.id', index=1, + number=2, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='display_name', full_name='object_detection.protos.StringIntLabelMapItem.display_name', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=79, + serialized_end=150, +) + + +_STRINGINTLABELMAP = _descriptor.Descriptor( + name='StringIntLabelMap', + full_name='object_detection.protos.StringIntLabelMap', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='item', full_name='object_detection.protos.StringIntLabelMap.item', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=152, + serialized_end=233, +) + +_STRINGINTLABELMAP.fields_by_name['item'].message_type = _STRINGINTLABELMAPITEM +DESCRIPTOR.message_types_by_name['StringIntLabelMapItem'] = _STRINGINTLABELMAPITEM +DESCRIPTOR.message_types_by_name['StringIntLabelMap'] = _STRINGINTLABELMAP 
+_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +StringIntLabelMapItem = _reflection.GeneratedProtocolMessageType('StringIntLabelMapItem', (_message.Message,), dict( + DESCRIPTOR = _STRINGINTLABELMAPITEM, + __module__ = 'object_detection.protos.string_int_label_map_pb2' + # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMapItem) + )) +_sym_db.RegisterMessage(StringIntLabelMapItem) + +StringIntLabelMap = _reflection.GeneratedProtocolMessageType('StringIntLabelMap', (_message.Message,), dict( + DESCRIPTOR = _STRINGINTLABELMAP, + __module__ = 'object_detection.protos.string_int_label_map_pb2' + # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMap) + )) +_sym_db.RegisterMessage(StringIntLabelMap) + + +# @@protoc_insertion_point(module_scope) diff --git a/src/main/object_detection/utils/label_map_util.py b/src/main/object_detection/utils/label_map_util.py new file mode 100644 index 0000000..aef46c1 --- /dev/null +++ b/src/main/object_detection/utils/label_map_util.py @@ -0,0 +1,181 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Label map utility functions.""" + +import logging + +import tensorflow as tf +from google.protobuf import text_format +from object_detection.protos import string_int_label_map_pb2 + + +def _validate_label_map(label_map): + """Checks if a label map is valid. + + Args: + label_map: StringIntLabelMap to validate. + + Raises: + ValueError: if label map is invalid. + """ + for item in label_map.item: + if item.id < 0: + raise ValueError('Label map ids should be >= 0.') + if (item.id == 0 and item.name != 'background' and + item.display_name != 'background'): + raise ValueError('Label map id 0 is reserved for the background label') + + +def create_category_index(categories): + """Creates dictionary of COCO compatible categories keyed by category id. + + Args: + categories: a list of dicts, each of which has the following keys: + 'id': (required) an integer id uniquely identifying this category. + 'name': (required) string representing category name + e.g., 'cat', 'dog', 'pizza'. + + Returns: + category_index: a dict containing the same entries as categories, but keyed + by the 'id' field of each category. + """ + category_index = {} + for cat in categories: + category_index[cat['id']] = cat + return category_index + + +def get_max_label_map_index(label_map): + """Get maximum index in label map. + + Args: + label_map: a StringIntLabelMapProto + + Returns: + an integer + """ + return max([item.id for item in label_map.item]) + + +def convert_label_map_to_categories(label_map, + max_num_classes, + use_display_name=True): + """Loads label map proto and returns categories list compatible with eval. + + This function loads a label map and returns a list of dicts, each of which + has the following keys: + 'id': (required) an integer id uniquely identifying this category. 
+    'name': (required) string representing category name
+      e.g., 'cat', 'dog', 'pizza'.
+  We only allow a class into the list if its id is between 1 and
+  max_num_classes (inclusive), i.e. id - label_id_offset lies between 0
+  (inclusive) and max_num_classes (exclusive) for a label_id_offset of 1.
+  If there are several items mapping to the same id in the label map,
+  we will only keep the first one in the categories list.
+
+  Args:
+    label_map: a StringIntLabelMapProto or None. If None, a default categories
+      list is created with max_num_classes categories.
+    max_num_classes: maximum number of (consecutive) label indices to include.
+    use_display_name: (boolean) choose whether to load 'display_name' field
+      as category name. If False or if the display_name field does not exist,
+      uses 'name' field as category names instead.
+  Returns:
+    categories: a list of dictionaries representing all possible categories.
+  """
+  categories = []
+  list_of_ids_already_added = []
+  if not label_map:
+    label_id_offset = 1
+    for class_id in range(max_num_classes):
+      categories.append({
+          'id': class_id + label_id_offset,
+          'name': 'category_{}'.format(class_id + label_id_offset)
+      })
+    return categories
+  for item in label_map.item:
+    if not 0 < item.id <= max_num_classes:
+      logging.info('Ignore item %d since it falls outside of requested '
+                   'label range.', item.id)
+      continue
+    if use_display_name and item.HasField('display_name'):
+      name = item.display_name
+    else:
+      name = item.name
+    if item.id not in list_of_ids_already_added:
+      list_of_ids_already_added.append(item.id)
+      categories.append({'id': item.id, 'name': name})
+  return categories
+
+
+def load_labelmap(path):
+  """Loads label map proto.
+
+  Args:
+    path: path to StringIntLabelMap proto text file.
+  Returns:
+    a StringIntLabelMapProto
+  """
+  with tf.gfile.GFile(path, 'r') as fid:
+    label_map_string = fid.read()
+    label_map = string_int_label_map_pb2.StringIntLabelMap()
+    try:
+      text_format.Merge(label_map_string, label_map)
+    except text_format.ParseError:
+      label_map.ParseFromString(label_map_string)
+  _validate_label_map(label_map)
+  return label_map
+
+
+def get_label_map_dict(label_map_path, use_display_name=False):
+  """Reads a label map and returns a dictionary of label names to id.
+
+  Args:
+    label_map_path: path to label_map.
+    use_display_name: whether to use the label map items' display names as keys.
+
+  Returns:
+    A dictionary mapping label names to id.
+  """
+  label_map = load_labelmap(label_map_path)
+  label_map_dict = {}
+  for item in label_map.item:
+    if use_display_name:
+      label_map_dict[item.display_name] = item.id
+    else:
+      label_map_dict[item.name] = item.id
+  return label_map_dict
+
+
+def create_category_index_from_labelmap(label_map_path):
+  """Reads a label map and returns a category index.
+
+  Args:
+    label_map_path: Path to `StringIntLabelMap` proto text file.
+
+  Returns:
+    A category index, which is a dictionary that maps integer ids to dicts
+    containing categories, e.g.
+ {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...}
+ """
+ label_map = load_labelmap(label_map_path)
+ max_num_classes = max(item.id for item in label_map.item)
+ categories = convert_label_map_to_categories(label_map, max_num_classes)
+ return create_category_index(categories)
+
+
+def create_class_agnostic_category_index():
+ """Creates a category index with a single `object` class."""
+ return {1: {'id': 1, 'name': 'object'}}
diff --git a/src/main/object_detection/utils/ops.py b/src/main/object_detection/utils/ops.py
new file mode 100644
index 0000000..662f7e0
--- /dev/null
+++ b/src/main/object_detection/utils/ops.py
@@ -0,0 +1,959 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A module for helper tensorflow ops."""
+import math
+import numpy as np
+import six
+
+import tensorflow as tf
+
+from object_detection.core import box_list
+from object_detection.core import box_list_ops
+from object_detection.core import standard_fields as fields
+from object_detection.utils import shape_utils
+from object_detection.utils import static_shape
+
+
+def expanded_shape(orig_shape, start_dim, num_dims):
+ """Inserts multiple ones into a shape vector.
+
+ Inserts an all-1 vector of length num_dims at position start_dim into a shape.
+ Can be combined with tf.reshape to generalize tf.expand_dims.
+
+ Args:
+ orig_shape: the shape into which the all-1 vector is added (int32 vector)
+ start_dim: insertion position (int scalar)
+ num_dims: length of the inserted all-1 vector (int scalar)
+ Returns:
+ An int32 vector of length tf.size(orig_shape) + num_dims.
+ """
+ with tf.name_scope('ExpandedShape'):
+ start_dim = tf.expand_dims(start_dim, 0) # scalar to rank-1
+ before = tf.slice(orig_shape, [0], start_dim)
+ add_shape = tf.ones(tf.reshape(num_dims, [1]), dtype=tf.int32)
+ after = tf.slice(orig_shape, start_dim, [-1])
+ new_shape = tf.concat([before, add_shape, after], 0)
+ return new_shape
+
+
+def normalized_to_image_coordinates(normalized_boxes, image_shape,
+ parallel_iterations=32):
+ """Converts a batch of boxes from normalized to image coordinates.
+
+ Args:
+ normalized_boxes: a float32 tensor of shape [None, num_boxes, 4] in
+ normalized coordinates.
+ image_shape: a float32 tensor of shape [4] containing the image shape.
+ parallel_iterations: parallelism for the map_fn op.
+
+ Returns:
+ absolute_boxes: a float32 tensor of shape [None, num_boxes, 4] containing
+ the boxes in image coordinates.
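+
+ Example (an illustrative sketch, not from the original codebase; assumes
+ a batch of one image with height 600 and width 800):
+
+ boxes = tf.constant([[[0.1, 0.1, 0.5, 0.5]]]) # normalized [y1, x1, y2, x2]
+ image_shape = tf.constant([1., 600., 800., 3.])
+ abs_boxes = normalized_to_image_coordinates(boxes, image_shape)
+ # abs_boxes evaluates to [[[60., 80., 300., 400.]]] in pixel coordinates.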
+ """ + def _to_absolute_coordinates(normalized_boxes): + return box_list_ops.to_absolute_coordinates( + box_list.BoxList(normalized_boxes), + image_shape[1], image_shape[2], check_range=False).get() + + absolute_boxes = shape_utils.static_or_dynamic_map_fn( + _to_absolute_coordinates, + elems=(normalized_boxes), + dtype=tf.float32, + parallel_iterations=parallel_iterations, + back_prop=True) + return absolute_boxes + + +def meshgrid(x, y): + """Tiles the contents of x and y into a pair of grids. + + Multidimensional analog of numpy.meshgrid, giving the same behavior if x and y + are vectors. Generally, this will give: + + xgrid(i1, ..., i_m, j_1, ..., j_n) = x(j_1, ..., j_n) + ygrid(i1, ..., i_m, j_1, ..., j_n) = y(i_1, ..., i_m) + + Keep in mind that the order of the arguments and outputs is reverse relative + to the order of the indices they go into, done for compatibility with numpy. + The output tensors have the same shapes. Specifically: + + xgrid.get_shape() = y.get_shape().concatenate(x.get_shape()) + ygrid.get_shape() = y.get_shape().concatenate(x.get_shape()) + + Args: + x: A tensor of arbitrary shape and rank. xgrid will contain these values + varying in its last dimensions. + y: A tensor of arbitrary shape and rank. ygrid will contain these values + varying in its first dimensions. + Returns: + A tuple of tensors (xgrid, ygrid). + """ + with tf.name_scope('Meshgrid'): + x = tf.convert_to_tensor(x) + y = tf.convert_to_tensor(y) + x_exp_shape = expanded_shape(tf.shape(x), 0, tf.rank(y)) + y_exp_shape = expanded_shape(tf.shape(y), tf.rank(y), tf.rank(x)) + + xgrid = tf.tile(tf.reshape(x, x_exp_shape), y_exp_shape) + ygrid = tf.tile(tf.reshape(y, y_exp_shape), x_exp_shape) + new_shape = y.get_shape().concatenate(x.get_shape()) + xgrid.set_shape(new_shape) + ygrid.set_shape(new_shape) + + return xgrid, ygrid + + +def fixed_padding(inputs, kernel_size, rate=1): + """Pads the input along the spatial dimensions independently of input size. + + Args: + inputs: A tensor of size [batch, height_in, width_in, channels]. + kernel_size: The kernel to be used in the conv2d or max_pool2d operation. + Should be a positive integer. + rate: An integer, rate for atrous convolution. + + Returns: + output: A tensor of size [batch, height_out, width_out, channels] with the + input, either intact (if kernel_size == 1) or padded (if kernel_size > 1). + """ + kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) + pad_total = kernel_size_effective - 1 + pad_beg = pad_total // 2 + pad_end = pad_total - pad_beg + padded_inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end], + [pad_beg, pad_end], [0, 0]]) + return padded_inputs + + +def pad_to_multiple(tensor, multiple): + """Returns the tensor zero padded to the specified multiple. + + Appends 0s to the end of the first and second dimension (height and width) of + the tensor until both dimensions are a multiple of the input argument + 'multiple'. E.g. given an input tensor of shape [1, 3, 5, 1] and an input + multiple of 4, PadToMultiple will append 0s so that the resulting tensor will + be of shape [1, 4, 8, 1]. + + Args: + tensor: rank 4 float32 tensor, where + tensor -> [batch_size, height, width, channels]. + multiple: the multiple to pad to. + + Returns: + padded_tensor: the tensor zero padded to the specified multiple. 
+ """ + tensor_shape = tensor.get_shape() + batch_size = static_shape.get_batch_size(tensor_shape) + tensor_height = static_shape.get_height(tensor_shape) + tensor_width = static_shape.get_width(tensor_shape) + tensor_depth = static_shape.get_depth(tensor_shape) + + if batch_size is None: + batch_size = tf.shape(tensor)[0] + + if tensor_height is None: + tensor_height = tf.shape(tensor)[1] + padded_tensor_height = tf.to_int32( + tf.ceil(tf.to_float(tensor_height) / tf.to_float(multiple))) * multiple + else: + padded_tensor_height = int( + math.ceil(float(tensor_height) / multiple) * multiple) + + if tensor_width is None: + tensor_width = tf.shape(tensor)[2] + padded_tensor_width = tf.to_int32( + tf.ceil(tf.to_float(tensor_width) / tf.to_float(multiple))) * multiple + else: + padded_tensor_width = int( + math.ceil(float(tensor_width) / multiple) * multiple) + + if tensor_depth is None: + tensor_depth = tf.shape(tensor)[3] + + # Use tf.concat instead of tf.pad to preserve static shape + if padded_tensor_height != tensor_height: + height_pad = tf.zeros([ + batch_size, padded_tensor_height - tensor_height, tensor_width, + tensor_depth + ]) + tensor = tf.concat([tensor, height_pad], 1) + if padded_tensor_width != tensor_width: + width_pad = tf.zeros([ + batch_size, padded_tensor_height, padded_tensor_width - tensor_width, + tensor_depth + ]) + tensor = tf.concat([tensor, width_pad], 2) + + return tensor + + +def padded_one_hot_encoding(indices, depth, left_pad): + """Returns a zero padded one-hot tensor. + + This function converts a sparse representation of indices (e.g., [4]) to a + zero padded one-hot representation (e.g., [0, 0, 0, 0, 1] with depth = 4 and + left_pad = 1). If `indices` is empty, the result will simply be a tensor of + shape (0, depth + left_pad). If depth = 0, then this function just returns + `None`. + + Args: + indices: an integer tensor of shape [num_indices]. + depth: depth for the one-hot tensor (integer). + left_pad: number of zeros to left pad the one-hot tensor with (integer). + + Returns: + padded_onehot: a tensor with shape (num_indices, depth + left_pad). Returns + `None` if the depth is zero. + + Raises: + ValueError: if `indices` does not have rank 1 or if `left_pad` or `depth are + either negative or non-integers. + + TODO(rathodv): add runtime checks for depth and indices. + """ + if depth < 0 or not isinstance(depth, six.integer_types): + raise ValueError('`depth` must be a non-negative integer.') + if left_pad < 0 or not isinstance(left_pad, six.integer_types): + raise ValueError('`left_pad` must be a non-negative integer.') + if depth == 0: + return None + + rank = len(indices.get_shape().as_list()) + if rank != 1: + raise ValueError('`indices` must have rank 1, but has rank=%s' % rank) + + def one_hot_and_pad(): + one_hot = tf.cast(tf.one_hot(tf.cast(indices, tf.int64), depth, + on_value=1, off_value=0), tf.float32) + return tf.pad(one_hot, [[0, 0], [left_pad, 0]], mode='CONSTANT') + result = tf.cond(tf.greater(tf.size(indices), 0), one_hot_and_pad, + lambda: tf.zeros((depth + left_pad, 0))) + return tf.reshape(result, [-1, depth + left_pad]) + + +def dense_to_sparse_boxes(dense_locations, dense_num_boxes, num_classes): + """Converts bounding boxes from dense to sparse form. + + Args: + dense_locations: a [max_num_boxes, 4] tensor in which only the first k rows + are valid bounding box location coordinates, where k is the sum of + elements in dense_num_boxes. 
+ dense_num_boxes: a [max_num_classes] tensor indicating the counts of
+ various bounding box classes e.g. [1, 0, 0, 2] means that the first
+ bounding box is of class 0 and the second and third bounding boxes are
+ of class 3. The sum of elements in this tensor is the number of valid
+ bounding boxes.
+ num_classes: number of classes
+
+ Returns:
+ box_locations: a [num_boxes, 4] tensor containing only valid bounding
+ boxes (i.e. the first num_boxes rows of dense_locations)
+ box_classes: a [num_boxes] tensor containing the classes of each bounding
+ box (e.g. dense_num_boxes = [1, 0, 0, 2] => box_classes = [0, 3, 3])
+ """
+
+ num_valid_boxes = tf.reduce_sum(dense_num_boxes)
+ box_locations = tf.slice(dense_locations,
+ tf.constant([0, 0]), tf.stack([num_valid_boxes, 4]))
+ tiled_classes = [tf.tile([i], tf.expand_dims(dense_num_boxes[i], 0))
+ for i in range(num_classes)]
+ box_classes = tf.concat(tiled_classes, 0)
+ box_locations.set_shape([None, 4])
+ return box_locations, box_classes
+
+
+def indices_to_dense_vector(indices,
+ size,
+ indices_value=1.,
+ default_value=0,
+ dtype=tf.float32):
+ """Creates dense vector with indices set to specific value and rest to zeros.
+
+ This function exists because it is unclear if it is safe to use
+ tf.sparse_to_dense(indices, [size], 1, validate_indices=False)
+ with indices which are not ordered.
+ This function accepts a dynamic size (e.g. tf.shape(tensor)[0])
+
+ Args:
+ indices: 1d Tensor with integer indices which are to be set to
+ indices_value.
+ size: scalar with size (integer) of output Tensor.
+ indices_value: values of elements specified by indices in the output vector
+ default_value: values of other elements in the output vector.
+ dtype: data type.
+
+ Returns:
+ dense 1D Tensor of shape [size] with indices set to indices_value and the
+ rest set to default_value.
+ """
+ size = tf.to_int32(size)
+ zeros = tf.ones([size], dtype=dtype) * default_value
+ values = tf.ones_like(indices, dtype=dtype) * indices_value
+
+ return tf.dynamic_stitch([tf.range(size), tf.to_int32(indices)],
+ [zeros, values])
+
+
+def reduce_sum_trailing_dimensions(tensor, ndims):
+ """Computes sum across all dimensions following first `ndims` dimensions."""
+ return tf.reduce_sum(tensor, axis=tuple(range(ndims, tensor.shape.ndims)))
+
+
+def retain_groundtruth(tensor_dict, valid_indices):
+ """Retains groundtruth by valid indices.
+
+ Args:
+ tensor_dict: a dictionary of following groundtruth tensors -
+ fields.InputDataFields.groundtruth_boxes
+ fields.InputDataFields.groundtruth_classes
+ fields.InputDataFields.groundtruth_keypoints
+ fields.InputDataFields.groundtruth_instance_masks
+ fields.InputDataFields.groundtruth_is_crowd
+ fields.InputDataFields.groundtruth_area
+ fields.InputDataFields.groundtruth_label_types
+ fields.InputDataFields.groundtruth_difficult
+ valid_indices: a tensor with valid indices for the box-level groundtruth.
+
+ Returns:
+ a dictionary of tensors containing only the groundtruth for valid_indices.
+
+ Raises:
+ ValueError: If the shape of valid_indices is invalid.
+ ValueError: field fields.InputDataFields.groundtruth_boxes is
+ not present in tensor_dict.
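+
+ Example (a hedged sketch with assumed values; keeps only boxes 0 and 2):
+
+ tensor_dict = {
+ fields.InputDataFields.groundtruth_boxes:
+ tf.constant([[0., 0., .5, .5], [.1, .1, .2, .2], [.3, .3, .9, .9]]),
+ fields.InputDataFields.groundtruth_classes: tf.constant([1, 2, 3]),
+ }
+ valid = retain_groundtruth(tensor_dict, tf.constant([0, 2]))
+ # valid[fields.InputDataFields.groundtruth_classes] evaluates to [1, 3].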
+ """ + input_shape = valid_indices.get_shape().as_list() + if not (len(input_shape) == 1 or + (len(input_shape) == 2 and input_shape[1] == 1)): + raise ValueError('The shape of valid_indices is invalid.') + valid_indices = tf.reshape(valid_indices, [-1]) + valid_dict = {} + if fields.InputDataFields.groundtruth_boxes in tensor_dict: + # Prevents reshape failure when num_boxes is 0. + num_boxes = tf.maximum(tf.shape( + tensor_dict[fields.InputDataFields.groundtruth_boxes])[0], 1) + for key in tensor_dict: + if key in [fields.InputDataFields.groundtruth_boxes, + fields.InputDataFields.groundtruth_classes, + fields.InputDataFields.groundtruth_keypoints, + fields.InputDataFields.groundtruth_instance_masks]: + valid_dict[key] = tf.gather(tensor_dict[key], valid_indices) + # Input decoder returns empty tensor when these fields are not provided. + # Needs to reshape into [num_boxes, -1] for tf.gather() to work. + elif key in [fields.InputDataFields.groundtruth_is_crowd, + fields.InputDataFields.groundtruth_area, + fields.InputDataFields.groundtruth_difficult, + fields.InputDataFields.groundtruth_label_types]: + valid_dict[key] = tf.reshape( + tf.gather(tf.reshape(tensor_dict[key], [num_boxes, -1]), + valid_indices), [-1]) + # Fields that are not associated with boxes. + else: + valid_dict[key] = tensor_dict[key] + else: + raise ValueError('%s not present in input tensor dict.' % ( + fields.InputDataFields.groundtruth_boxes)) + return valid_dict + + +def retain_groundtruth_with_positive_classes(tensor_dict): + """Retains only groundtruth with positive class ids. + + Args: + tensor_dict: a dictionary of following groundtruth tensors - + fields.InputDataFields.groundtruth_boxes + fields.InputDataFields.groundtruth_classes + fields.InputDataFields.groundtruth_keypoints + fields.InputDataFields.groundtruth_instance_masks + fields.InputDataFields.groundtruth_is_crowd + fields.InputDataFields.groundtruth_area + fields.InputDataFields.groundtruth_label_types + fields.InputDataFields.groundtruth_difficult + + Returns: + a dictionary of tensors containing only the groundtruth with positive + classes. + + Raises: + ValueError: If groundtruth_classes tensor is not in tensor_dict. + """ + if fields.InputDataFields.groundtruth_classes not in tensor_dict: + raise ValueError('`groundtruth classes` not in tensor_dict.') + keep_indices = tf.where(tf.greater( + tensor_dict[fields.InputDataFields.groundtruth_classes], 0)) + return retain_groundtruth(tensor_dict, keep_indices) + + +def replace_nan_groundtruth_label_scores_with_ones(label_scores): + """Replaces nan label scores with 1.0. + + Args: + label_scores: a tensor containing object annoation label scores. + + Returns: + a tensor where NaN label scores have been replaced by ones. + """ + return tf.where( + tf.is_nan(label_scores), tf.ones(tf.shape(label_scores)), label_scores) + + +def filter_groundtruth_with_crowd_boxes(tensor_dict): + """Filters out groundtruth with boxes corresponding to crowd. + + Args: + tensor_dict: a dictionary of following groundtruth tensors - + fields.InputDataFields.groundtruth_boxes + fields.InputDataFields.groundtruth_classes + fields.InputDataFields.groundtruth_keypoints + fields.InputDataFields.groundtruth_instance_masks + fields.InputDataFields.groundtruth_is_crowd + fields.InputDataFields.groundtruth_area + fields.InputDataFields.groundtruth_label_types + + Returns: + a dictionary of tensors containing only the groundtruth that have bounding + boxes. 
+ """ + if fields.InputDataFields.groundtruth_is_crowd in tensor_dict: + is_crowd = tensor_dict[fields.InputDataFields.groundtruth_is_crowd] + is_not_crowd = tf.logical_not(is_crowd) + is_not_crowd_indices = tf.where(is_not_crowd) + tensor_dict = retain_groundtruth(tensor_dict, is_not_crowd_indices) + return tensor_dict + + +def filter_groundtruth_with_nan_box_coordinates(tensor_dict): + """Filters out groundtruth with no bounding boxes. + + Args: + tensor_dict: a dictionary of following groundtruth tensors - + fields.InputDataFields.groundtruth_boxes + fields.InputDataFields.groundtruth_classes + fields.InputDataFields.groundtruth_keypoints + fields.InputDataFields.groundtruth_instance_masks + fields.InputDataFields.groundtruth_is_crowd + fields.InputDataFields.groundtruth_area + fields.InputDataFields.groundtruth_label_types + + Returns: + a dictionary of tensors containing only the groundtruth that have bounding + boxes. + """ + groundtruth_boxes = tensor_dict[fields.InputDataFields.groundtruth_boxes] + nan_indicator_vector = tf.greater(tf.reduce_sum(tf.to_int32( + tf.is_nan(groundtruth_boxes)), reduction_indices=[1]), 0) + valid_indicator_vector = tf.logical_not(nan_indicator_vector) + valid_indices = tf.where(valid_indicator_vector) + + return retain_groundtruth(tensor_dict, valid_indices) + + +def normalize_to_target(inputs, + target_norm_value, + dim, + epsilon=1e-7, + trainable=True, + scope='NormalizeToTarget', + summarize=True): + """L2 normalizes the inputs across the specified dimension to a target norm. + + This op implements the L2 Normalization layer introduced in + Liu, Wei, et al. "SSD: Single Shot MultiBox Detector." + and Liu, Wei, Andrew Rabinovich, and Alexander C. Berg. + "Parsenet: Looking wider to see better." and is useful for bringing + activations from multiple layers in a convnet to a standard scale. + + Note that the rank of `inputs` must be known and the dimension to which + normalization is to be applied should be statically defined. + + TODO(jonathanhuang): Add option to scale by L2 norm of the entire input. + + Args: + inputs: A `Tensor` of arbitrary size. + target_norm_value: A float value that specifies an initial target norm or + a list of floats (whose length must be equal to the depth along the + dimension to be normalized) specifying a per-dimension multiplier + after normalization. + dim: The dimension along which the input is normalized. + epsilon: A small value to add to the inputs to avoid dividing by zero. + trainable: Whether the norm is trainable or not + scope: Optional scope for variable_scope. + summarize: Whether or not to add a tensorflow summary for the op. + + Returns: + The input tensor normalized to the specified target norm. + + Raises: + ValueError: If dim is smaller than the number of dimensions in 'inputs'. + ValueError: If target_norm_value is not a float or a list of floats with + length equal to the depth along the dimension to be normalized. 
+ """ + with tf.variable_scope(scope, 'NormalizeToTarget', [inputs]): + if not inputs.get_shape(): + raise ValueError('The input rank must be known.') + input_shape = inputs.get_shape().as_list() + input_rank = len(input_shape) + if dim < 0 or dim >= input_rank: + raise ValueError( + 'dim must be non-negative but smaller than the input rank.') + if not input_shape[dim]: + raise ValueError('input shape should be statically defined along ' + 'the specified dimension.') + depth = input_shape[dim] + if not (isinstance(target_norm_value, float) or + (isinstance(target_norm_value, list) and + len(target_norm_value) == depth) and + all([isinstance(val, float) for val in target_norm_value])): + raise ValueError('target_norm_value must be a float or a list of floats ' + 'with length equal to the depth along the dimension to ' + 'be normalized.') + if isinstance(target_norm_value, float): + initial_norm = depth * [target_norm_value] + else: + initial_norm = target_norm_value + target_norm = tf.contrib.framework.model_variable( + name='weights', dtype=tf.float32, + initializer=tf.constant(initial_norm, dtype=tf.float32), + trainable=trainable) + if summarize: + mean = tf.reduce_mean(target_norm) + mean = tf.Print(mean, ['NormalizeToTarget:', mean]) + tf.summary.scalar(tf.get_variable_scope().name, mean) + lengths = epsilon + tf.sqrt(tf.reduce_sum(tf.square(inputs), dim, True)) + mult_shape = input_rank*[1] + mult_shape[dim] = depth + return tf.reshape(target_norm, mult_shape) * tf.truediv(inputs, lengths) + + +def position_sensitive_crop_regions(image, + boxes, + box_ind, + crop_size, + num_spatial_bins, + global_pool, + extrapolation_value=None): + """Position-sensitive crop and pool rectangular regions from a feature grid. + + The output crops are split into `spatial_bins_y` vertical bins + and `spatial_bins_x` horizontal bins. For each intersection of a vertical + and a horizontal bin the output values are gathered by performing + `tf.image.crop_and_resize` (bilinear resampling) on a a separate subset of + channels of the image. This reduces `depth` by a factor of + `(spatial_bins_y * spatial_bins_x)`. + + When global_pool is True, this function implements a differentiable version + of position-sensitive RoI pooling used in + [R-FCN detection system](https://arxiv.org/abs/1605.06409). + + When global_pool is False, this function implements a differentiable version + of position-sensitive assembling operation used in + [instance FCN](https://arxiv.org/abs/1603.08678). + + Args: + image: A `Tensor`. Must be one of the following types: `uint8`, `int8`, + `int16`, `int32`, `int64`, `half`, `float32`, `float64`. + A 4-D tensor of shape `[batch, image_height, image_width, depth]`. + Both `image_height` and `image_width` need to be positive. + boxes: A `Tensor` of type `float32`. + A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor + specifies the coordinates of a box in the `box_ind[i]` image and is + specified in normalized coordinates `[y1, x1, y2, x2]`. A normalized + coordinate value of `y` is mapped to the image coordinate at + `y * (image_height - 1)`, so as the `[0, 1]` interval of normalized image + height is mapped to `[0, image_height - 1] in image height coordinates. + We do allow y1 > y2, in which case the sampled crop is an up-down flipped + version of the original image. The width dimension is treated similarly. + Normalized coordinates outside the `[0, 1]` range are allowed, in which + case we use `extrapolation_value` to extrapolate the input image values. 
+ box_ind: A `Tensor` of type `int32`. + A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. + The value of `box_ind[i]` specifies the image that the `i`-th box refers + to. + crop_size: A list of two integers `[crop_height, crop_width]`. All + cropped image patches are resized to this size. The aspect ratio of the + image content is not preserved. Both `crop_height` and `crop_width` need + to be positive. + num_spatial_bins: A list of two integers `[spatial_bins_y, spatial_bins_x]`. + Represents the number of position-sensitive bins in y and x directions. + Both values should be >= 1. `crop_height` should be divisible by + `spatial_bins_y`, and similarly for width. + The number of image channels should be divisible by + (spatial_bins_y * spatial_bins_x). + Suggested value from R-FCN paper: [3, 3]. + global_pool: A boolean variable. + If True, we perform average global pooling on the features assembled from + the position-sensitive score maps. + If False, we keep the position-pooled features without global pooling + over the spatial coordinates. + Note that using global_pool=True is equivalent to but more efficient than + running the function with global_pool=False and then performing global + average pooling. + extrapolation_value: An optional `float`. Defaults to `0`. + Value used for extrapolation, when applicable. + Returns: + position_sensitive_features: A 4-D tensor of shape + `[num_boxes, K, K, crop_channels]`, + where `crop_channels = depth / (spatial_bins_y * spatial_bins_x)`, + where K = 1 when global_pool is True (Average-pooled cropped regions), + and K = crop_size when global_pool is False. + Raises: + ValueError: Raised in four situations: + `num_spatial_bins` is not >= 1; + `num_spatial_bins` does not divide `crop_size`; + `(spatial_bins_y*spatial_bins_x)` does not divide `depth`; + `bin_crop_size` is not square when global_pool=False due to the + constraint in function space_to_depth. + """ + total_bins = 1 + bin_crop_size = [] + + for (num_bins, crop_dim) in zip(num_spatial_bins, crop_size): + if num_bins < 1: + raise ValueError('num_spatial_bins should be >= 1') + + if crop_dim % num_bins != 0: + raise ValueError('crop_size should be divisible by num_spatial_bins') + + total_bins *= num_bins + bin_crop_size.append(crop_dim // num_bins) + + if not global_pool and bin_crop_size[0] != bin_crop_size[1]: + raise ValueError('Only support square bin crop size for now.') + + ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=1) + spatial_bins_y, spatial_bins_x = num_spatial_bins + + # Split each box into spatial_bins_y * spatial_bins_x bins. + position_sensitive_boxes = [] + for bin_y in range(spatial_bins_y): + step_y = (ymax - ymin) / spatial_bins_y + for bin_x in range(spatial_bins_x): + step_x = (xmax - xmin) / spatial_bins_x + box_coordinates = [ymin + bin_y * step_y, + xmin + bin_x * step_x, + ymin + (bin_y + 1) * step_y, + xmin + (bin_x + 1) * step_x, + ] + position_sensitive_boxes.append(tf.stack(box_coordinates, axis=1)) + + image_splits = tf.split(value=image, num_or_size_splits=total_bins, axis=3) + + image_crops = [] + for (split, box) in zip(image_splits, position_sensitive_boxes): + crop = tf.image.crop_and_resize(split, box, box_ind, bin_crop_size, + extrapolation_value=extrapolation_value) + image_crops.append(crop) + + if global_pool: + # Average over all bins. + position_sensitive_features = tf.add_n(image_crops) / len(image_crops) + # Then average over spatial positions within the bins. 
+ position_sensitive_features = tf.reduce_mean(
+ position_sensitive_features, [1, 2], keep_dims=True)
+ else:
+ # Reorder height/width to depth channel.
+ block_size = bin_crop_size[0]
+ if block_size >= 2:
+ image_crops = [tf.space_to_depth(
+ crop, block_size=block_size) for crop in image_crops]
+
+ # Pack image_crops so that first dimension is for position-sensitive boxes.
+ position_sensitive_features = tf.stack(image_crops, axis=0)
+
+ # Unroll the position-sensitive boxes to spatial positions.
+ position_sensitive_features = tf.squeeze(
+ tf.batch_to_space_nd(position_sensitive_features,
+ block_shape=[1] + num_spatial_bins,
+ crops=tf.zeros((3, 2), dtype=tf.int32)),
+ squeeze_dims=[0])
+
+ # Reorder back the depth channel.
+ if block_size >= 2:
+ position_sensitive_features = tf.depth_to_space(
+ position_sensitive_features, block_size=block_size)
+
+ return position_sensitive_features
+
+
+def reframe_box_masks_to_image_masks(box_masks, boxes, image_height,
+ image_width):
+ """Transforms the box masks back to full image masks.
+
+ Embeds masks in bounding boxes of larger masks whose shapes correspond to
+ image shape.
+
+ Args:
+ box_masks: A tf.float32 tensor of size [num_masks, mask_height, mask_width].
+ boxes: A tf.float32 tensor of size [num_masks, 4] containing the box
+ corners. Row i contains [ymin, xmin, ymax, xmax] of the box
+ corresponding to mask i. Note that the box corners are in
+ normalized coordinates.
+ image_height: Image height. The output mask will have the same height as
+ the image height.
+ image_width: Image width. The output mask will have the same width as the
+ image width.
+
+ Returns:
+ A tf.float32 tensor of size [num_masks, image_height, image_width].
+ """
+ # TODO(rathodv): Make this a public function.
+ def reframe_box_masks_to_image_masks_default():
+ """The default function when there are more than 0 box masks."""
+ def transform_boxes_relative_to_boxes(boxes, reference_boxes):
+ boxes = tf.reshape(boxes, [-1, 2, 2])
+ min_corner = tf.expand_dims(reference_boxes[:, 0:2], 1)
+ max_corner = tf.expand_dims(reference_boxes[:, 2:4], 1)
+ transformed_boxes = (boxes - min_corner) / (max_corner - min_corner)
+ return tf.reshape(transformed_boxes, [-1, 4])
+
+ box_masks_expanded = tf.expand_dims(box_masks, axis=3)
+ num_boxes = tf.shape(box_masks_expanded)[0]
+ unit_boxes = tf.concat(
+ [tf.zeros([num_boxes, 2]), tf.ones([num_boxes, 2])], axis=1)
+ reverse_boxes = transform_boxes_relative_to_boxes(unit_boxes, boxes)
+ return tf.image.crop_and_resize(
+ image=box_masks_expanded,
+ boxes=reverse_boxes,
+ box_ind=tf.range(num_boxes),
+ crop_size=[image_height, image_width],
+ extrapolation_value=0.0)
+ image_masks = tf.cond(
+ tf.shape(box_masks)[0] > 0,
+ reframe_box_masks_to_image_masks_default,
+ lambda: tf.zeros([0, image_height, image_width, 1], dtype=tf.float32))
+ return tf.squeeze(image_masks, axis=3)
+
+
+def merge_boxes_with_multiple_labels(boxes, classes, num_classes):
+ """Merges boxes with same coordinates and returns K-hot encoded classes.
+
+ Args:
+ boxes: A tf.float32 tensor with shape [N, 4] holding N boxes.
+ classes: A tf.int32 tensor with shape [N] holding class indices.
+ The class index starts at 0.
+ num_classes: total number of classes to use for K-hot encoding.
+
+ Returns:
+ merged_boxes: A tf.float32 tensor with shape [N', 4] holding boxes,
+ where N' <= N.
+ class_encodings: A tf.int32 tensor with shape [N', num_classes] holding
+ k-hot encodings for the merged boxes.
+ merged_box_indices: A tf.int32 tensor with shape [N'] holding original + indices of the boxes. + """ + def merge_numpy_boxes(boxes, classes, num_classes): + """Python function to merge numpy boxes.""" + if boxes.size < 1: + return (np.zeros([0, 4], dtype=np.float32), + np.zeros([0, num_classes], dtype=np.int32), + np.zeros([0], dtype=np.int32)) + box_to_class_indices = {} + for box_index in range(boxes.shape[0]): + box = tuple(boxes[box_index, :].tolist()) + class_index = classes[box_index] + if box not in box_to_class_indices: + box_to_class_indices[box] = [box_index, np.zeros([num_classes])] + box_to_class_indices[box][1][class_index] = 1 + merged_boxes = np.vstack(box_to_class_indices.keys()).astype(np.float32) + class_encodings = [item[1] for item in box_to_class_indices.values()] + class_encodings = np.vstack(class_encodings).astype(np.int32) + merged_box_indices = [item[0] for item in box_to_class_indices.values()] + merged_box_indices = np.array(merged_box_indices).astype(np.int32) + return merged_boxes, class_encodings, merged_box_indices + + merged_boxes, class_encodings, merged_box_indices = tf.py_func( + merge_numpy_boxes, [boxes, classes, num_classes], + [tf.float32, tf.int32, tf.int32]) + merged_boxes = tf.reshape(merged_boxes, [-1, 4]) + class_encodings = tf.reshape(class_encodings, [-1, num_classes]) + merged_box_indices = tf.reshape(merged_box_indices, [-1]) + return merged_boxes, class_encodings, merged_box_indices + + +def nearest_neighbor_upsampling(input_tensor, scale): + """Nearest neighbor upsampling implementation. + + Nearest neighbor upsampling function that maps input tensor with shape + [batch_size, height, width, channels] to [batch_size, height * scale + , width * scale, channels]. This implementation only uses reshape and + broadcasting to make it TPU compatible. + + Args: + input_tensor: A float32 tensor of size [batch, height_in, width_in, + channels]. + scale: An integer multiple to scale resolution of input data. + Returns: + data_up: A float32 tensor of size + [batch, height_in*scale, width_in*scale, channels]. + """ + with tf.name_scope('nearest_neighbor_upsampling'): + (batch_size, height, width, + channels) = shape_utils.combined_static_and_dynamic_shape(input_tensor) + output_tensor = tf.reshape( + input_tensor, [batch_size, height, 1, width, 1, channels]) * tf.ones( + [1, 1, scale, 1, scale, 1], dtype=input_tensor.dtype) + return tf.reshape(output_tensor, + [batch_size, height * scale, width * scale, channels]) + + +def matmul_gather_on_zeroth_axis(params, indices, scope=None): + """Matrix multiplication based implementation of tf.gather on zeroth axis. + + TODO(rathodv, jonathanhuang): enable sparse matmul option. + + Args: + params: A float32 Tensor. The tensor from which to gather values. + Must be at least rank 1. + indices: A Tensor. Must be one of the following types: int32, int64. + Must be in range [0, params.shape[0]) + scope: A name for the operation (optional). + + Returns: + A Tensor. Has the same type as params. Values from params gathered + from indices given by indices, with shape indices.shape + params.shape[1:]. 
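+
+ Example (a minimal sketch; mirrors tf.gather on the zeroth axis):
+
+ params = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
+ indices = tf.constant([2, 0])
+ gathered = matmul_gather_on_zeroth_axis(params, indices)
+ # gathered evaluates to [[5., 6.], [1., 2.]].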
+ """ + with tf.name_scope(scope, 'MatMulGather'): + params_shape = shape_utils.combined_static_and_dynamic_shape(params) + indices_shape = shape_utils.combined_static_and_dynamic_shape(indices) + params2d = tf.reshape(params, [params_shape[0], -1]) + indicator_matrix = tf.one_hot(indices, params_shape[0]) + gathered_result_flattened = tf.matmul(indicator_matrix, params2d) + return tf.reshape(gathered_result_flattened, + tf.stack(indices_shape + params_shape[1:])) + + +def matmul_crop_and_resize(image, boxes, crop_size, scope=None): + """Matrix multiplication based implementation of the crop and resize op. + + Extracts crops from the input image tensor and bilinearly resizes them + (possibly with aspect ratio change) to a common output size specified by + crop_size. This is more general than the crop_to_bounding_box op which + extracts a fixed size slice from the input image and does not allow + resizing or aspect ratio change. + + Returns a tensor with crops from the input image at positions defined at + the bounding box locations in boxes. The cropped boxes are all resized + (with bilinear interpolation) to a fixed size = `[crop_height, crop_width]`. + The result is a 4-D tensor `[num_boxes, crop_height, crop_width, depth]`. + + Running time complexity: + O((# channels) * (# boxes) * (crop_size)^2 * M), where M is the number + of pixels of the longer edge of the image. + + Note that this operation is meant to replicate the behavior of the standard + tf.image.crop_and_resize operation but there are a few differences. + Specifically: + 1) The extrapolation value (the values that are interpolated from outside + the bounds of the image window) is always zero + 2) Only XLA supported operations are used (e.g., matrix multiplication). + 3) There is no `box_indices` argument --- to run this op on multiple images, + one must currently call this op independently on each image. + 4) All shapes and the `crop_size` parameter are assumed to be statically + defined. Moreover, the number of boxes must be strictly nonzero. + + Args: + image: A `Tensor`. Must be one of the following types: `uint8`, `int8`, + `int16`, `int32`, `int64`, `half`, `float32`, `float64`. + A 4-D tensor of shape `[batch, image_height, image_width, depth]`. + Both `image_height` and `image_width` need to be positive. + boxes: A `Tensor` of type `float32`. + A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor + specifies the coordinates of a box in the `box_ind[i]` image and is + specified in normalized coordinates `[y1, x1, y2, x2]`. A normalized + coordinate value of `y` is mapped to the image coordinate at + `y * (image_height - 1)`, so as the `[0, 1]` interval of normalized image + height is mapped to `[0, image_height - 1] in image height coordinates. + We do allow y1 > y2, in which case the sampled crop is an up-down flipped + version of the original image. The width dimension is treated similarly. + Normalized coordinates outside the `[0, 1]` range are allowed, in which + case we use `extrapolation_value` to extrapolate the input image values. + crop_size: A list of two integers `[crop_height, crop_width]`. All + cropped image patches are resized to this size. The aspect ratio of the + image content is not preserved. Both `crop_height` and `crop_width` need + to be positive. + scope: A name for the operation (optional). 
+
+ Returns:
+ A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`
+
+ Raises:
+ ValueError: if image tensor does not have shape
+ `[1, image_height, image_width, depth]` with all dimensions statically
+ defined.
+ ValueError: if boxes tensor does not have shape `[num_boxes, 4]` where
+ num_boxes > 0.
+ ValueError: if crop_size is not a list of two positive integers
+ """
+ img_shape = image.shape.as_list()
+ boxes_shape = boxes.shape.as_list()
+ _, img_height, img_width, _ = img_shape
+ if not isinstance(crop_size, list) or len(crop_size) != 2:
+ raise ValueError('`crop_size` must be a list of length 2')
+ dimensions = img_shape + crop_size + boxes_shape
+ if not all([isinstance(dim, int) for dim in dimensions]):
+ raise ValueError('all input shapes must be statically defined')
+ if len(boxes_shape) != 2 or boxes_shape[1] != 4:
+ raise ValueError('`boxes` should have shape `[num_boxes, 4]`')
+ if len(img_shape) != 4 or img_shape[0] != 1:
+ raise ValueError('image should have shape '
+ '`[1, image_height, image_width, depth]`')
+ num_crops = boxes_shape[0]
+ if not num_crops > 0:
+ raise ValueError('number of boxes must be > 0')
+ if not (crop_size[0] > 0 and crop_size[1] > 0):
+ raise ValueError('`crop_size` must be a list of two positive integers.')
+
+ def _lin_space_weights(num, img_size):
+ if num > 1:
+ alpha = (img_size - 1) / float(num - 1)
+ indices = np.reshape(np.arange(num), (1, num))
+ start_weights = alpha * (num - 1 - indices)
+ stop_weights = alpha * indices
+ else:
+ start_weights = num * [.5 * (img_size - 1)]
+ stop_weights = num * [.5 * (img_size - 1)]
+ return (tf.constant(start_weights, dtype=tf.float32),
+ tf.constant(stop_weights, dtype=tf.float32))
+
+ with tf.name_scope(scope, 'MatMulCropAndResize'):
+ y1_weights, y2_weights = _lin_space_weights(crop_size[0], img_height)
+ x1_weights, x2_weights = _lin_space_weights(crop_size[1], img_width)
+ [y1, x1, y2, x2] = tf.split(value=boxes, num_or_size_splits=4, axis=1)
+
+ # Pixel centers of input image and grid points along height and width
+ image_idx_h = tf.constant(
+ np.reshape(np.arange(img_height), (1, 1, img_height)), dtype=tf.float32)
+ image_idx_w = tf.constant(
+ np.reshape(np.arange(img_width), (1, 1, img_width)), dtype=tf.float32)
+ grid_pos_h = tf.expand_dims(y1 * y1_weights + y2 * y2_weights, 2)
+ grid_pos_w = tf.expand_dims(x1 * x1_weights + x2 * x2_weights, 2)
+
+ # Create kernel matrices of pairwise kernel evaluations between pixel
+ # centers of image and grid points.
+ kernel_h = tf.nn.relu(1 - tf.abs(image_idx_h - grid_pos_h))
+ kernel_w = tf.nn.relu(1 - tf.abs(image_idx_w - grid_pos_w))
+
+ # TODO(jonathanhuang): investigate whether all channels can be processed
+ # without the explicit unstack --- possibly with a permute and map_fn call.
+ result_channels = []
+ for channel in tf.unstack(image, axis=3):
+ result_channels.append(
+ tf.matmul(
+ tf.matmul(kernel_h, tf.tile(channel, [num_crops, 1, 1])),
+ kernel_w, transpose_b=True))
+ return tf.stack(result_channels, axis=3)
diff --git a/src/main/object_detection/utils/shape_utils.py b/src/main/object_detection/utils/shape_utils.py
new file mode 100644
index 0000000..06f389a
--- /dev/null
+++ b/src/main/object_detection/utils/shape_utils.py
@@ -0,0 +1,309 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utils used to manipulate tensor shapes."""
+
+import tensorflow as tf
+
+from object_detection.utils import static_shape
+
+
+def _is_tensor(t):
+ """Returns a boolean indicating whether the input is a tensor.
+
+ Args:
+ t: the input to be tested.
+
+ Returns:
+ a boolean that indicates whether t is a tensor.
+ """
+ return isinstance(t, (tf.Tensor, tf.SparseTensor, tf.Variable))
+
+
+def _set_dim_0(t, d0):
+ """Sets the 0-th dimension of the input tensor.
+
+ Args:
+ t: the input tensor, assuming the rank is at least 1.
+ d0: an integer indicating the 0-th dimension of the input tensor.
+
+ Returns:
+ the tensor t with the 0-th dimension set.
+ """
+ t_shape = t.get_shape().as_list()
+ t_shape[0] = d0
+ t.set_shape(t_shape)
+ return t
+
+
+def pad_tensor(t, length):
+ """Pads the input tensor with 0s along the first dimension up to the length.
+
+ Args:
+ t: the input tensor, assuming the rank is at least 1.
+ length: a tensor of shape [1] or an integer, indicating the first dimension
+ of the input tensor t after padding, assuming length >= t.shape[0].
+
+ Returns:
+ padded_t: the padded tensor, whose first dimension is length. If the length
+ is an integer, the first dimension of padded_t is set to length
+ statically.
+ """
+ t_rank = tf.rank(t)
+ t_shape = tf.shape(t)
+ t_d0 = t_shape[0]
+ pad_d0 = tf.expand_dims(length - t_d0, 0)
+ pad_shape = tf.cond(
+ tf.greater(t_rank, 1), lambda: tf.concat([pad_d0, t_shape[1:]], 0),
+ lambda: tf.expand_dims(length - t_d0, 0))
+ padded_t = tf.concat([t, tf.zeros(pad_shape, dtype=t.dtype)], 0)
+ if not _is_tensor(length):
+ padded_t = _set_dim_0(padded_t, length)
+ return padded_t
+
+
+def clip_tensor(t, length):
+ """Clips the input tensor along the first dimension up to the length.
+
+ Args:
+ t: the input tensor, assuming the rank is at least 1.
+ length: a tensor of shape [1] or an integer, indicating the first dimension
+ of the input tensor t after clipping, assuming length <= t.shape[0].
+
+ Returns:
+ clipped_t: the clipped tensor, whose first dimension is length. If the
+ length is an integer, the first dimension of clipped_t is set to length
+ statically.
+ """
+ clipped_t = tf.gather(t, tf.range(length))
+ if not _is_tensor(length):
+ clipped_t = _set_dim_0(clipped_t, length)
+ return clipped_t
+
+
+def pad_or_clip_tensor(t, length):
+ """Pad or clip the input tensor along the first dimension.
+
+ Args:
+ t: the input tensor, assuming the rank is at least 1.
+ length: a tensor of shape [1] or an integer, indicating the first dimension
+ of the input tensor t after processing.
+
+ Returns:
+ processed_t: the processed tensor, whose first dimension is length. If the
+ length is an integer, the first dimension of the processed tensor is set
+ to length statically.
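+
+ Example (a minimal sketch):
+
+ t = tf.constant([1, 2, 3])
+ pad_or_clip_tensor(t, 5) # -> [1, 2, 3, 0, 0]
+ pad_or_clip_tensor(t, 2) # -> [1, 2]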
+ """ + processed_t = tf.cond( + tf.greater(tf.shape(t)[0], length), + lambda: clip_tensor(t, length), + lambda: pad_tensor(t, length)) + if not _is_tensor(length): + processed_t = _set_dim_0(processed_t, length) + return processed_t + + +def combined_static_and_dynamic_shape(tensor): + """Returns a list containing static and dynamic values for the dimensions. + + Returns a list of static and dynamic values for shape dimensions. This is + useful to preserve static shapes when available in reshape operation. + + Args: + tensor: A tensor of any type. + + Returns: + A list of size tensor.shape.ndims containing integers or a scalar tensor. + """ + static_tensor_shape = tensor.shape.as_list() + dynamic_tensor_shape = tf.shape(tensor) + combined_shape = [] + for index, dim in enumerate(static_tensor_shape): + if dim is not None: + combined_shape.append(dim) + else: + combined_shape.append(dynamic_tensor_shape[index]) + return combined_shape + + +def static_or_dynamic_map_fn(fn, elems, dtype=None, + parallel_iterations=32, back_prop=True): + """Runs map_fn as a (static) for loop when possible. + + This function rewrites the map_fn as an explicit unstack input -> for loop + over function calls -> stack result combination. This allows our graphs to + be acyclic when the batch size is static. + For comparison, see https://www.tensorflow.org/api_docs/python/tf/map_fn. + + Note that `static_or_dynamic_map_fn` currently is not *fully* interchangeable + with the default tf.map_fn function as it does not accept nested inputs (only + Tensors or lists of Tensors). Likewise, the output of `fn` can only be a + Tensor or list of Tensors. + + TODO(jonathanhuang): make this function fully interchangeable with tf.map_fn. + + Args: + fn: The callable to be performed. It accepts one argument, which will have + the same structure as elems. Its output must have the + same structure as elems. + elems: A tensor or list of tensors, each of which will + be unpacked along their first dimension. The sequence of the + resulting slices will be applied to fn. + dtype: (optional) The output type(s) of fn. If fn returns a structure of + Tensors differing from the structure of elems, then dtype is not optional + and must have the same structure as the output of fn. + parallel_iterations: (optional) number of batch items to process in + parallel. This flag is only used if the native tf.map_fn is used + and defaults to 32 instead of 10 (unlike the standard tf.map_fn default). + back_prop: (optional) True enables support for back propagation. + This flag is only used if the native tf.map_fn is used. + + Returns: + A tensor or sequence of tensors. Each tensor packs the + results of applying fn to tensors unpacked from elems along the first + dimension, from first to last. + Raises: + ValueError: if `elems` a Tensor or a list of Tensors. + ValueError: if `fn` does not return a Tensor or list of Tensors + """ + if isinstance(elems, list): + for elem in elems: + if not isinstance(elem, tf.Tensor): + raise ValueError('`elems` must be a Tensor or list of Tensors.') + + elem_shapes = [elem.shape.as_list() for elem in elems] + # Fall back on tf.map_fn if shapes of each entry of `elems` are None or fail + # to all be the same size along the batch dimension. 
+ for elem_shape in elem_shapes:
+ if (not elem_shape or not elem_shape[0]
+ or elem_shape[0] != elem_shapes[0][0]):
+ return tf.map_fn(fn, elems, dtype, parallel_iterations, back_prop)
+ arg_tuples = zip(*[tf.unstack(elem) for elem in elems])
+ outputs = [fn(arg_tuple) for arg_tuple in arg_tuples]
+ else:
+ if not isinstance(elems, tf.Tensor):
+ raise ValueError('`elems` must be a Tensor or list of Tensors.')
+ elems_shape = elems.shape.as_list()
+ if not elems_shape or not elems_shape[0]:
+ return tf.map_fn(fn, elems, dtype, parallel_iterations, back_prop)
+ outputs = [fn(arg) for arg in tf.unstack(elems)]
+ # Stack `outputs`, which is a list of Tensors or list of lists of Tensors
+ if all([isinstance(output, tf.Tensor) for output in outputs]):
+ return tf.stack(outputs)
+ else:
+ if all([isinstance(output, list) for output in outputs]):
+ if all([all(
+ [isinstance(entry, tf.Tensor) for entry in output_list])
+ for output_list in outputs]):
+ return [tf.stack(output_tuple) for output_tuple in zip(*outputs)]
+ raise ValueError('`fn` should return a Tensor or a list of Tensors.')
+
+
+def check_min_image_dim(min_dim, image_tensor):
+ """Checks that the image width/height are greater than some number.
+
+ This function is used to check that the width and height of an image are above
+ a certain value. If the image shape is static, this function will perform the
+ check at graph construction time. Otherwise, if the image shape varies, an
+ Assertion control dependency will be added to the graph.
+
+ Args:
+ min_dim: The minimum number of pixels along the width and height of the
+ image.
+ image_tensor: The image tensor to check size for.
+
+ Returns:
+ If `image_tensor` has dynamic size, return `image_tensor` with an Assert
+ control dependency. Otherwise returns image_tensor.
+
+ Raises:
+ ValueError: if `image_tensor`'s width or height is smaller than `min_dim`.
+ """
+ image_shape = image_tensor.get_shape()
+ image_height = static_shape.get_height(image_shape)
+ image_width = static_shape.get_width(image_shape)
+ if image_height is None or image_width is None:
+ shape_assert = tf.Assert(
+ tf.logical_and(tf.greater_equal(tf.shape(image_tensor)[1], min_dim),
+ tf.greater_equal(tf.shape(image_tensor)[2], min_dim)),
+ ['image size must be >= {} in both height and width.'.format(min_dim)])
+ with tf.control_dependencies([shape_assert]):
+ return tf.identity(image_tensor)
+
+ if image_height < min_dim or image_width < min_dim:
+ raise ValueError(
+ 'image size must be >= %d in both height and width; image dim = %d,%d' %
+ (min_dim, image_height, image_width))
+
+ return image_tensor
+
+
+def assert_shape_equal(shape_a, shape_b):
+ """Asserts that shape_a and shape_b are equal.
+
+ If the shapes are static, raises a ValueError when the shapes
+ mismatch.
+
+ If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes
+ mismatch.
+
+ Args:
+ shape_a: a list containing shape of the first tensor.
+ shape_b: a list containing shape of the second tensor.
+
+ Returns:
+ Either a tf.no_op() when shapes are all static or a tf.assert_equal() op
+ when the shapes are dynamic.
+
+ Raises:
+ ValueError: When shapes are both static and unequal.
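+
+ Example (a sketch; `tensor_a` and `tensor_b` are assumed inputs, and the
+ shape lists typically come from combined_static_and_dynamic_shape):
+
+ shape_a = combined_static_and_dynamic_shape(tensor_a)
+ shape_b = combined_static_and_dynamic_shape(tensor_b)
+ with tf.control_dependencies([assert_shape_equal(shape_a, shape_b)]):
+ total = tensor_a + tensor_b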
+ """ + if (all(isinstance(dim, int) for dim in shape_a) and + all(isinstance(dim, int) for dim in shape_b)): + if shape_a != shape_b: + raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b)) + else: return tf.no_op() + else: + return tf.assert_equal(shape_a, shape_b) + + +def assert_shape_equal_along_first_dimension(shape_a, shape_b): + """Asserts that shape_a and shape_b are the same along the 0th-dimension. + + If the shapes are static, raises a ValueError when the shapes + mismatch. + + If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes + mismatch. + + Args: + shape_a: a list containing shape of the first tensor. + shape_b: a list containing shape of the second tensor. + + Returns: + Either a tf.no_op() when shapes are all static and a tf.assert_equal() op + when the shapes are dynamic. + + Raises: + ValueError: When shapes are both static and unequal. + """ + if isinstance(shape_a[0], int) and isinstance(shape_b[0], int): + if shape_a[0] != shape_b[0]: + raise ValueError('Unequal first dimension {}, {}'.format( + shape_a[0], shape_b[0])) + else: return tf.no_op() + else: + return tf.assert_equal(shape_a[0], shape_b[0]) + diff --git a/src/main/object_detection/utils/static_shape.py b/src/main/object_detection/utils/static_shape.py new file mode 100644 index 0000000..8e4e522 --- /dev/null +++ b/src/main/object_detection/utils/static_shape.py @@ -0,0 +1,71 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Helper functions to access TensorShape values. + +The rank 4 tensor_shape must be of the form [batch_size, height, width, depth]. +""" + + +def get_batch_size(tensor_shape): + """Returns batch size from the tensor shape. + + Args: + tensor_shape: A rank 4 TensorShape. + + Returns: + An integer representing the batch size of the tensor. + """ + tensor_shape.assert_has_rank(rank=4) + return tensor_shape[0].value + + +def get_height(tensor_shape): + """Returns height from the tensor shape. + + Args: + tensor_shape: A rank 4 TensorShape. + + Returns: + An integer representing the height of the tensor. + """ + tensor_shape.assert_has_rank(rank=4) + return tensor_shape[1].value + + +def get_width(tensor_shape): + """Returns width from the tensor shape. + + Args: + tensor_shape: A rank 4 TensorShape. + + Returns: + An integer representing the width of the tensor. + """ + tensor_shape.assert_has_rank(rank=4) + return tensor_shape[2].value + + +def get_depth(tensor_shape): + """Returns depth from the tensor shape. + + Args: + tensor_shape: A rank 4 TensorShape. + + Returns: + An integer representing the depth of the tensor. 
+ """ + tensor_shape.assert_has_rank(rank=4) + return tensor_shape[3].value diff --git a/src/main/object_detection/utils/visualization_utils.py b/src/main/object_detection/utils/visualization_utils.py new file mode 100644 index 0000000..79e1825 --- /dev/null +++ b/src/main/object_detection/utils/visualization_utils.py @@ -0,0 +1,733 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""A set of functions that are used for visualization. + +These functions often receive an image, perform some visualization on the image. +The functions do not return a value, instead they modify the image itself. + +""" +import collections +import functools +# Set headless-friendly backend. +import matplotlib; matplotlib.use('Agg') # pylint: disable=multiple-statements +import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top +import numpy as np +import PIL.Image as Image +import PIL.ImageColor as ImageColor +import PIL.ImageDraw as ImageDraw +import PIL.ImageFont as ImageFont +import six +import tensorflow as tf + +from object_detection.core import standard_fields as fields + + +_TITLE_LEFT_MARGIN = 10 +_TITLE_TOP_MARGIN = 10 +STANDARD_COLORS = [ + 'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque', + 'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite', + 'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan', + 'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange', + 'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet', + 'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite', + 'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod', + 'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki', + 'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue', + 'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey', + 'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue', + 'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime', + 'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid', + 'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen', + 'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin', + 'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed', + 'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed', + 'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple', + 'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown', + 'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue', + 'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow', + 'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White', + 'WhiteSmoke', 'Yellow', 'YellowGreen' +] + + +def save_image_array_as_png(image, 
output_path): + """Saves an image (represented as a numpy array) to PNG. + + Args: + image: a numpy array with shape [height, width, 3]. + output_path: path to which image should be written. + """ + image_pil = Image.fromarray(np.uint8(image)).convert('RGB') + with tf.gfile.Open(output_path, 'w') as fid: + image_pil.save(fid, 'PNG') + + +def encode_image_array_as_png_str(image): + """Encodes a numpy array into a PNG string. + + Args: + image: a numpy array with shape [height, width, 3]. + + Returns: + PNG encoded image string. + """ + image_pil = Image.fromarray(np.uint8(image)) + output = six.BytesIO() + image_pil.save(output, format='PNG') + png_string = output.getvalue() + output.close() + return png_string + + +def draw_bounding_box_on_image_array(image, + ymin, + xmin, + ymax, + xmax, + color='red', + thickness=4, + display_str_list=(), + use_normalized_coordinates=True): + """Adds a bounding box to an image (numpy array). + + Bounding box coordinates can be specified in either absolute (pixel) or + normalized coordinates by setting the use_normalized_coordinates argument. + + Args: + image: a numpy array with shape [height, width, 3]. + ymin: ymin of bounding box. + xmin: xmin of bounding box. + ymax: ymax of bounding box. + xmax: xmax of bounding box. + color: color to draw bounding box. Default is red. + thickness: line thickness. Default value is 4. + display_str_list: list of strings to display in box + (each to be shown on its own line). + use_normalized_coordinates: If True (default), treat coordinates + ymin, xmin, ymax, xmax as relative to the image. Otherwise treat + coordinates as absolute. + """ + image_pil = Image.fromarray(np.uint8(image)).convert('RGB') + draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, color, + thickness, display_str_list, + use_normalized_coordinates) + np.copyto(image, np.array(image_pil)) + + +def draw_bounding_box_on_image(image, + ymin, + xmin, + ymax, + xmax, + color='red', + thickness=4, + display_str_list=(), + use_normalized_coordinates=True): + """Adds a bounding box to an image. + + Bounding box coordinates can be specified in either absolute (pixel) or + normalized coordinates by setting the use_normalized_coordinates argument. + + Each string in display_str_list is displayed on a separate line above the + bounding box in black text on a rectangle filled with the input 'color'. + If the top of the bounding box extends to the edge of the image, the strings + are displayed below the bounding box. + + Args: + image: a PIL.Image object. + ymin: ymin of bounding box. + xmin: xmin of bounding box. + ymax: ymax of bounding box. + xmax: xmax of bounding box. + color: color to draw bounding box. Default is red. + thickness: line thickness. Default value is 4. + display_str_list: list of strings to display in box + (each to be shown on its own line). + use_normalized_coordinates: If True (default), treat coordinates + ymin, xmin, ymax, xmax as relative to the image. Otherwise treat + coordinates as absolute. 
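+
+ Example (illustrative only; draws one normalized box with a caption on a
+ blank PIL image):
+
+ image = Image.new('RGB', (640, 480))
+ draw_bounding_box_on_image(image, 0.1, 0.1, 0.5, 0.5, color='blue',
+ thickness=2, display_str_list=['dog: 87%'])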
+ """ + draw = ImageDraw.Draw(image) + im_width, im_height = image.size + if use_normalized_coordinates: + (left, right, top, bottom) = (xmin * im_width, xmax * im_width, + ymin * im_height, ymax * im_height) + else: + (left, right, top, bottom) = (xmin, xmax, ymin, ymax) + draw.line([(left, top), (left, bottom), (right, bottom), + (right, top), (left, top)], width=thickness, fill=color) + try: + font = ImageFont.truetype('arial.ttf', 24) + except IOError: + font = ImageFont.load_default() + + # If the total height of the display strings added to the top of the bounding + # box exceeds the top of the image, stack the strings below the bounding box + # instead of above. + display_str_heights = [font.getsize(ds)[1] for ds in display_str_list] + # Each display_str has a top and bottom margin of 0.05x. + total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights) + + if top > total_display_str_height: + text_bottom = top + else: + text_bottom = bottom + total_display_str_height + # Reverse list and print from bottom to top. + for display_str in display_str_list[::-1]: + text_width, text_height = font.getsize(display_str) + margin = np.ceil(0.05 * text_height) + draw.rectangle( + [(left, text_bottom - text_height - 2 * margin), (left + text_width, + text_bottom)], + fill=color) + draw.text( + (left + margin, text_bottom - text_height - margin), + display_str, + fill='black', + font=font) + text_bottom -= text_height - 2 * margin + + +def draw_bounding_boxes_on_image_array(image, + boxes, + color='red', + thickness=4, + display_str_list_list=()): + """Draws bounding boxes on image (numpy array). + + Args: + image: a numpy array object. + boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax). + The coordinates are in normalized format between [0, 1]. + color: color to draw bounding box. Default is red. + thickness: line thickness. Default value is 4. + display_str_list_list: list of list of strings. + a list of strings for each bounding box. + The reason to pass a list of strings for a + bounding box is that it might contain + multiple labels. + + Raises: + ValueError: if boxes is not a [N, 4] array + """ + image_pil = Image.fromarray(image) + draw_bounding_boxes_on_image(image_pil, boxes, color, thickness, + display_str_list_list) + np.copyto(image, np.array(image_pil)) + + +def draw_bounding_boxes_on_image(image, + boxes, + color='red', + thickness=4, + display_str_list_list=()): + """Draws bounding boxes on image. + + Args: + image: a PIL.Image object. + boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax). + The coordinates are in normalized format between [0, 1]. + color: color to draw bounding box. Default is red. + thickness: line thickness. Default value is 4. + display_str_list_list: list of list of strings. + a list of strings for each bounding box. + The reason to pass a list of strings for a + bounding box is that it might contain + multiple labels. 
+
+  Raises:
+    ValueError: if boxes is not a [N, 4] array
+  """
+  boxes_shape = boxes.shape
+  if not boxes_shape:
+    return
+  if len(boxes_shape) != 2 or boxes_shape[1] != 4:
+    raise ValueError('Input must be of size [N, 4]')
+  for i in range(boxes_shape[0]):
+    display_str_list = ()
+    if display_str_list_list:
+      display_str_list = display_str_list_list[i]
+    draw_bounding_box_on_image(image, boxes[i, 0], boxes[i, 1], boxes[i, 2],
+                               boxes[i, 3], color, thickness, display_str_list)
+
+
+def _visualize_boxes(image, boxes, classes, scores, category_index, **kwargs):
+  return visualize_boxes_and_labels_on_image_array(
+      image, boxes, classes, scores, category_index=category_index, **kwargs)
+
+
+def _visualize_boxes_and_masks(image, boxes, classes, scores, masks,
+                               category_index, **kwargs):
+  return visualize_boxes_and_labels_on_image_array(
+      image,
+      boxes,
+      classes,
+      scores,
+      category_index=category_index,
+      instance_masks=masks,
+      **kwargs)
+
+
+def _visualize_boxes_and_keypoints(image, boxes, classes, scores, keypoints,
+                                   category_index, **kwargs):
+  return visualize_boxes_and_labels_on_image_array(
+      image,
+      boxes,
+      classes,
+      scores,
+      category_index=category_index,
+      keypoints=keypoints,
+      **kwargs)
+
+
+def _visualize_boxes_and_masks_and_keypoints(
+    image, boxes, classes, scores, masks, keypoints, category_index, **kwargs):
+  return visualize_boxes_and_labels_on_image_array(
+      image,
+      boxes,
+      classes,
+      scores,
+      category_index=category_index,
+      instance_masks=masks,
+      keypoints=keypoints,
+      **kwargs)
+
+
+def draw_bounding_boxes_on_image_tensors(images,
+                                         boxes,
+                                         classes,
+                                         scores,
+                                         category_index,
+                                         instance_masks=None,
+                                         keypoints=None,
+                                         max_boxes_to_draw=20,
+                                         min_score_thresh=0.2,
+                                         use_normalized_coordinates=True):
+  """Draws bounding boxes, masks, and keypoints on batch of image tensors.
+
+  Args:
+    images: A 4D uint8 image tensor of shape [N, H, W, C]. If C > 3, additional
+      channels will be ignored.
+    boxes: [N, max_detections, 4] float32 tensor of detection boxes.
+    classes: [N, max_detections] int tensor of detection classes. Note that
+      classes are 1-indexed.
+    scores: [N, max_detections] float32 tensor of detection scores.
+    category_index: a dict that maps integer ids to category dicts, e.g.
+      {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...}
+    instance_masks: A 4D uint8 tensor of shape [N, max_detections, H, W] with
+      instance masks.
+    keypoints: A 4D float32 tensor of shape [N, max_detections, num_keypoints,
+      2] with keypoints.
+    max_boxes_to_draw: Maximum number of boxes to draw on an image. Default 20.
+    min_score_thresh: Minimum score threshold for visualization. Default 0.2.
+    use_normalized_coordinates: Whether to assume boxes and keypoints are in
+      normalized coordinates (as opposed to absolute coordinates).
+      Default is True.
+
+  Returns:
+    4D image tensor of type uint8, with boxes drawn on top.
+  """
+  # Additional channels are being ignored.
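+  # After the slicing below drops any extra channels, the branches bind
+  # category_index and the shared keyword arguments into a partial function
+  # whose positional inputs line up with the `elems` list given to tf.map_fn.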
+  images = images[:, :, :, 0:3]
+  visualization_keyword_args = {
+      'use_normalized_coordinates': use_normalized_coordinates,
+      'max_boxes_to_draw': max_boxes_to_draw,
+      'min_score_thresh': min_score_thresh,
+      'agnostic_mode': False,
+      'line_thickness': 4
+  }
+
+  if instance_masks is not None and keypoints is None:
+    visualize_boxes_fn = functools.partial(
+        _visualize_boxes_and_masks,
+        category_index=category_index,
+        **visualization_keyword_args)
+    elems = [images, boxes, classes, scores, instance_masks]
+  elif instance_masks is None and keypoints is not None:
+    visualize_boxes_fn = functools.partial(
+        _visualize_boxes_and_keypoints,
+        category_index=category_index,
+        **visualization_keyword_args)
+    elems = [images, boxes, classes, scores, keypoints]
+  elif instance_masks is not None and keypoints is not None:
+    visualize_boxes_fn = functools.partial(
+        _visualize_boxes_and_masks_and_keypoints,
+        category_index=category_index,
+        **visualization_keyword_args)
+    elems = [images, boxes, classes, scores, instance_masks, keypoints]
+  else:
+    visualize_boxes_fn = functools.partial(
+        _visualize_boxes,
+        category_index=category_index,
+        **visualization_keyword_args)
+    elems = [images, boxes, classes, scores]
+
+  def draw_boxes(image_and_detections):
+    """Draws boxes on image."""
+    image_with_boxes = tf.py_func(visualize_boxes_fn, image_and_detections,
+                                  tf.uint8)
+    return image_with_boxes
+
+  images = tf.map_fn(draw_boxes, elems, dtype=tf.uint8, back_prop=False)
+  return images
+
+
+def draw_side_by_side_evaluation_image(eval_dict,
+                                       category_index,
+                                       max_boxes_to_draw=20,
+                                       min_score_thresh=0.2,
+                                       use_normalized_coordinates=True):
+  """Creates a side-by-side image with detections and groundtruth.
+
+  Bounding boxes (and instance masks, if available) are visualized on both
+  subimages.
+
+  Args:
+    eval_dict: The evaluation dictionary returned by
+      eval_util.result_dict_for_single_example().
+    category_index: A category index (dictionary) produced from a labelmap.
+    max_boxes_to_draw: The maximum number of boxes to draw for detections.
+    min_score_thresh: The minimum score threshold for showing detections.
+    use_normalized_coordinates: Whether to assume boxes and keypoints are in
+      normalized coordinates (as opposed to absolute coordinates).
+      Default is True.
+
+  Returns:
+    A [1, H, 2 * W, C] uint8 tensor. The subimage on the left corresponds to
+    detections, while the subimage on the right corresponds to groundtruth.
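+
+  Example (sketch; assumes `eval_dict` came from
+  eval_util.result_dict_for_single_example()):
+    side_by_side = draw_side_by_side_evaluation_image(eval_dict,
+                                                      category_index)
+    tf.summary.image('Detections_Left_Groundtruth_Right', side_by_side)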
+ """ + detection_fields = fields.DetectionResultFields() + input_data_fields = fields.InputDataFields() + instance_masks = None + if detection_fields.detection_masks in eval_dict: + instance_masks = tf.cast( + tf.expand_dims(eval_dict[detection_fields.detection_masks], axis=0), + tf.uint8) + keypoints = None + if detection_fields.detection_keypoints in eval_dict: + keypoints = tf.expand_dims( + eval_dict[detection_fields.detection_keypoints], axis=0) + groundtruth_instance_masks = None + if input_data_fields.groundtruth_instance_masks in eval_dict: + groundtruth_instance_masks = tf.cast( + tf.expand_dims( + eval_dict[input_data_fields.groundtruth_instance_masks], axis=0), + tf.uint8) + images_with_detections = draw_bounding_boxes_on_image_tensors( + eval_dict[input_data_fields.original_image], + tf.expand_dims(eval_dict[detection_fields.detection_boxes], axis=0), + tf.expand_dims(eval_dict[detection_fields.detection_classes], axis=0), + tf.expand_dims(eval_dict[detection_fields.detection_scores], axis=0), + category_index, + instance_masks=instance_masks, + keypoints=keypoints, + max_boxes_to_draw=max_boxes_to_draw, + min_score_thresh=min_score_thresh, + use_normalized_coordinates=use_normalized_coordinates) + images_with_groundtruth = draw_bounding_boxes_on_image_tensors( + eval_dict[input_data_fields.original_image], + tf.expand_dims(eval_dict[input_data_fields.groundtruth_boxes], axis=0), + tf.expand_dims(eval_dict[input_data_fields.groundtruth_classes], axis=0), + tf.expand_dims( + tf.ones_like( + eval_dict[input_data_fields.groundtruth_classes], + dtype=tf.float32), + axis=0), + category_index, + instance_masks=groundtruth_instance_masks, + keypoints=None, + max_boxes_to_draw=None, + min_score_thresh=0.0, + use_normalized_coordinates=use_normalized_coordinates) + return tf.concat([images_with_detections, images_with_groundtruth], axis=2) + + +def draw_keypoints_on_image_array(image, + keypoints, + color='red', + radius=2, + use_normalized_coordinates=True): + """Draws keypoints on an image (numpy array). + + Args: + image: a numpy array with shape [height, width, 3]. + keypoints: a numpy array with shape [num_keypoints, 2]. + color: color to draw the keypoints with. Default is red. + radius: keypoint radius. Default value is 2. + use_normalized_coordinates: if True (default), treat keypoint values as + relative to the image. Otherwise treat them as absolute. + """ + image_pil = Image.fromarray(np.uint8(image)).convert('RGB') + draw_keypoints_on_image(image_pil, keypoints, color, radius, + use_normalized_coordinates) + np.copyto(image, np.array(image_pil)) + + +def draw_keypoints_on_image(image, + keypoints, + color='red', + radius=2, + use_normalized_coordinates=True): + """Draws keypoints on an image. + + Args: + image: a PIL.Image object. + keypoints: a numpy array with shape [num_keypoints, 2]. + color: color to draw the keypoints with. Default is red. + radius: keypoint radius. Default value is 2. + use_normalized_coordinates: if True (default), treat keypoint values as + relative to the image. Otherwise treat them as absolute. 
+ """ + draw = ImageDraw.Draw(image) + im_width, im_height = image.size + keypoints_x = [k[1] for k in keypoints] + keypoints_y = [k[0] for k in keypoints] + if use_normalized_coordinates: + keypoints_x = tuple([im_width * x for x in keypoints_x]) + keypoints_y = tuple([im_height * y for y in keypoints_y]) + for keypoint_x, keypoint_y in zip(keypoints_x, keypoints_y): + draw.ellipse([(keypoint_x - radius, keypoint_y - radius), + (keypoint_x + radius, keypoint_y + radius)], + outline=color, fill=color) + + +def draw_mask_on_image_array(image, mask, color='red', alpha=0.4): + """Draws mask on an image. + + Args: + image: uint8 numpy array with shape (img_height, img_height, 3) + mask: a uint8 numpy array of shape (img_height, img_height) with + values between either 0 or 1. + color: color to draw the keypoints with. Default is red. + alpha: transparency value between 0 and 1. (default: 0.4) + + Raises: + ValueError: On incorrect data type for image or masks. + """ + if image.dtype != np.uint8: + raise ValueError('`image` not of type np.uint8') + if mask.dtype != np.uint8: + raise ValueError('`mask` not of type np.uint8') + if np.any(np.logical_and(mask != 1, mask != 0)): + raise ValueError('`mask` elements should be in [0, 1]') + if image.shape[:2] != mask.shape: + raise ValueError('The image has spatial dimensions %s but the mask has ' + 'dimensions %s' % (image.shape[:2], mask.shape)) + rgb = ImageColor.getrgb(color) + pil_image = Image.fromarray(image) + + solid_color = np.expand_dims( + np.ones_like(mask), axis=2) * np.reshape(list(rgb), [1, 1, 3]) + pil_solid_color = Image.fromarray(np.uint8(solid_color)).convert('RGBA') + pil_mask = Image.fromarray(np.uint8(255.0*alpha*mask)).convert('L') + pil_image = Image.composite(pil_solid_color, pil_image, pil_mask) + np.copyto(image, np.array(pil_image.convert('RGB'))) + + +def visualize_boxes_and_labels_on_image_array( + image, + boxes, + classes, + scores, + category_index, + instance_masks=None, + instance_boundaries=None, + keypoints=None, + use_normalized_coordinates=False, + max_boxes_to_draw=20, + min_score_thresh=.5, + agnostic_mode=False, + line_thickness=4, + groundtruth_box_visualization_color='black', + skip_scores=False, + skip_labels=False): + """Overlay labeled boxes on an image with formatted scores and label names. + + This function groups boxes that correspond to the same location + and creates a display string for each detection and overlays these + on the image. Note that this function modifies the image in place, and returns + that same image. + + Args: + image: uint8 numpy array with shape (img_height, img_width, 3) + boxes: a numpy array of shape [N, 4] + classes: a numpy array of shape [N]. Note that class indices are 1-based, + and match the keys in the label map. + scores: a numpy array of shape [N] or None. If scores=None, then + this function assumes that the boxes to be plotted are groundtruth + boxes and plot all boxes as black with no classes or scores. + category_index: a dict containing category dictionaries (each holding + category index `id` and category name `name`) keyed by category indices. + instance_masks: a numpy array of shape [N, image_height, image_width] with + values ranging between 0 and 1, can be None. + instance_boundaries: a numpy array of shape [N, image_height, image_width] + with values ranging between 0 and 1, can be None. 
+ keypoints: a numpy array of shape [N, num_keypoints, 2], can + be None + use_normalized_coordinates: whether boxes is to be interpreted as + normalized coordinates or not. + max_boxes_to_draw: maximum number of boxes to visualize. If None, draw + all boxes. + min_score_thresh: minimum score threshold for a box to be visualized + agnostic_mode: boolean (default: False) controlling whether to evaluate in + class-agnostic mode or not. This mode will display scores but ignore + classes. + line_thickness: integer (default: 4) controlling line width of the boxes. + groundtruth_box_visualization_color: box color for visualizing groundtruth + boxes + skip_scores: whether to skip score when drawing a single detection + skip_labels: whether to skip label when drawing a single detection + + Returns: + uint8 numpy array with shape (img_height, img_width, 3) with overlaid boxes. + """ + # Create a display string (and color) for every box location, group any boxes + # that correspond to the same location. + box_to_display_str_map = collections.defaultdict(list) + box_to_color_map = collections.defaultdict(str) + box_to_instance_masks_map = {} + box_to_instance_boundaries_map = {} + box_to_keypoints_map = collections.defaultdict(list) + if not max_boxes_to_draw: + max_boxes_to_draw = boxes.shape[0] + for i in range(min(max_boxes_to_draw, boxes.shape[0])): + if scores is None or scores[i] > min_score_thresh: + box = tuple(boxes[i].tolist()) + if instance_masks is not None: + box_to_instance_masks_map[box] = instance_masks[i] + if instance_boundaries is not None: + box_to_instance_boundaries_map[box] = instance_boundaries[i] + if keypoints is not None: + box_to_keypoints_map[box].extend(keypoints[i]) + if scores is None: + box_to_color_map[box] = groundtruth_box_visualization_color + else: + display_str = '' + if not skip_labels: + if not agnostic_mode: + if classes[i] in category_index.keys(): + class_name = category_index[classes[i]]['name'] + else: + class_name = 'N/A' + display_str = str(class_name) + if not skip_scores: + if not display_str: + display_str = '{}%'.format(int(100*scores[i])) + else: + display_str = '{}: {}%'.format(display_str, int(100*scores[i])) + box_to_display_str_map[box].append(display_str) + if agnostic_mode: + box_to_color_map[box] = 'DarkOrange' + else: + box_to_color_map[box] = STANDARD_COLORS[ + classes[i] % len(STANDARD_COLORS)] + + # Draw all boxes onto image. + for box, color in box_to_color_map.items(): + ymin, xmin, ymax, xmax = box + if instance_masks is not None: + draw_mask_on_image_array( + image, + box_to_instance_masks_map[box], + color=color + ) + if instance_boundaries is not None: + draw_mask_on_image_array( + image, + box_to_instance_boundaries_map[box], + color='red', + alpha=1.0 + ) + draw_bounding_box_on_image_array( + image, + ymin, + xmin, + ymax, + xmax, + color=color, + thickness=line_thickness, + display_str_list=box_to_display_str_map[box], + use_normalized_coordinates=use_normalized_coordinates) + if keypoints is not None: + draw_keypoints_on_image_array( + image, + box_to_keypoints_map[box], + color=color, + radius=line_thickness / 2, + use_normalized_coordinates=use_normalized_coordinates) + + return image + + +def add_cdf_image_summary(values, name): + """Adds a tf.summary.image for a CDF plot of the values. + + Normalizes `values` such that they sum to 1, plots the cumulative distribution + function and creates a tf image summary. + + Args: + values: a 1-D float32 tensor containing the values. + name: name for the image summary. 
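+
+  Example (illustrative; `loss_tensor` is a hypothetical 1-D float32 tensor):
+    add_cdf_image_summary(loss_tensor, 'PerExampleLossCDF')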
+ """ + def cdf_plot(values): + """Numpy function to plot CDF.""" + normalized_values = values / np.sum(values) + sorted_values = np.sort(normalized_values) + cumulative_values = np.cumsum(sorted_values) + fraction_of_examples = (np.arange(cumulative_values.size, dtype=np.float32) + / cumulative_values.size) + fig = plt.figure(frameon=False) + ax = fig.add_subplot('111') + ax.plot(fraction_of_examples, cumulative_values) + ax.set_ylabel('cumulative normalized values') + ax.set_xlabel('fraction of examples') + fig.canvas.draw() + width, height = fig.get_size_inches() * fig.get_dpi() + image = np.fromstring(fig.canvas.tostring_rgb(), dtype='uint8').reshape( + 1, int(height), int(width), 3) + return image + cdf_plot = tf.py_func(cdf_plot, [values], tf.uint8) + tf.summary.image(name, cdf_plot) + + +def add_hist_image_summary(values, bins, name): + """Adds a tf.summary.image for a histogram plot of the values. + + Plots the histogram of values and creates a tf image summary. + + Args: + values: a 1-D float32 tensor containing the values. + bins: bin edges which will be directly passed to np.histogram. + name: name for the image summary. + """ + + def hist_plot(values, bins): + """Numpy function to plot hist.""" + fig = plt.figure(frameon=False) + ax = fig.add_subplot('111') + y, x = np.histogram(values, bins=bins) + ax.plot(x[:-1], y) + ax.set_ylabel('count') + ax.set_xlabel('value') + fig.canvas.draw() + width, height = fig.get_size_inches() * fig.get_dpi() + image = np.fromstring( + fig.canvas.tostring_rgb(), dtype='uint8').reshape( + 1, int(height), int(width), 3) + return image + hist_plot = tf.py_func(hist_plot, [values, bins], tf.uint8) + tf.summary.image(name, hist_plot) diff --git a/src/main/start.py b/src/main/start.py new file mode 100644 index 0000000..51d5dc5 --- /dev/null +++ b/src/main/start.py @@ -0,0 +1,192 @@ +import sys +from starlette.responses import FileResponse +from models import ApiResponse +from typing import List +from fastapi import FastAPI, Form, File, UploadFile, Header, BackgroundTasks +from starlette.staticfiles import StaticFiles +from starlette.middleware.cors import CORSMiddleware +from deep_learning_service import DeepLearningService +from inference.exceptions import ModelNotFound, InvalidModelConfiguration, ApplicationError, ModelNotLoaded, \ + InferenceEngineNotFound, InvalidInputData +from inference.errors import Error + +sys.path.append('./inference') + +dl_service = DeepLearningService() +error_logging = Error() +app = FastAPI(version='3.1.0', title='BMW InnovationLab tensorflow cpu inference Automation', + description="API for performing tensorflow cpu inference

" + "Contact the developers:
" + "Antoine Charbel: antoine.charbel@inmind.ai
" + "BMW Innovation Lab: innovation-lab@bmw.de") + + +# app.mount("/public", StaticFiles(directory="/main/public"), name="public") + +# app.add_middleware( +# CORSMiddleware, +# allow_origins=["*"], +# allow_credentials=True, +# allow_methods=["*"], +# allow_headers=["*"], +# ) + + +@app.get('/load') +def load_custom(): + """ + Loads all the available models. + :return: All the available models with their respective hashed values + """ + try: + return dl_service.load_all_models() + except ApplicationError as e: + return ApiResponse(success=False, error=e) + except Exception: + return ApiResponse(success=False, error='unexpected server error') + + +@app.post('/detect') +async def detect_custom(model: str = Form(...), image: UploadFile = File(...)): + """ + Performs a prediction for a specified image using one of the available models. + :param model: Model name or model hash + :param image: Image file + :return: Model's Bounding boxes + """ + draw_boxes = False + predict_batch = False + try: + output = await dl_service.run_model(model, image, draw_boxes, predict_batch) + error_logging.info('request successful;' + str(output)) + return output + except ApplicationError as e: + error_logging.warning(model + ';' + str(e)) + return ApiResponse(success=False, error=e) + except Exception as e: + error_logging.error(model + ' ' + str(e)) + return ApiResponse(success=False, error='unexpected server error') + + +@app.post('/get_labels') +def get_labels_custom(model: str = Form(...)): + """ + Lists the model's labels with their hashed values. + :param model: Model name or model hash + :return: A list of the model's labels with their hashed values + """ + return dl_service.get_labels_custom(model) + + +@app.get('/models/{model_name}/load') +async def load(model_name: str, force: bool = False): + """ + Loads a model specified as a query parameter. + :param model_name: Model name + :param force: Boolean for model force reload on each call + :return: APIResponse + """ + try: + dl_service.load_model(model_name, force) + return ApiResponse(success=True) + except ApplicationError as e: + return ApiResponse(success=False, error=e) + + +@app.get('/models') +async def list_models(user_agent: str = Header(None)): + """ + Lists all available models. + :param user_agent: + :return: APIResponse + """ + return ApiResponse(data={'models': dl_service.list_models()}) + + +@app.post('/models/{model_name}/predict') +async def run_model(model_name: str, input_data: UploadFile = File(...)): + """ + Performs a prediction by giving both model name and image file. + :param model_name: Model name + :param input_data: An image file + :return: APIResponse containing the prediction's bounding boxes + """ + draw_boxes = False + predict_batch = False + try: + output = await dl_service.run_model(model_name, input_data, draw_boxes, predict_batch) + error_logging.info('request successful;' + str(output)) + return ApiResponse(data=output) + except ApplicationError as e: + error_logging.warning(model_name + ';' + str(e)) + return ApiResponse(success=False, error=e) + except Exception as e: + error_logging.error(model_name + ' ' + str(e)) + return ApiResponse(success=False, error='unexpected server error') + + +@app.post('/models/{model_name}/predict_batch', include_in_schema=False) +async def run_model_batch(model_name: str, input_data: List[UploadFile] = File(...)): + """ + Performs a prediction by giving both model name and image file(s). 
+    :param model_name: Model name
+    :param input_data: A batch of image files or a single image file
+    :return: APIResponse containing the predictions' bounding boxes
+    """
+    draw_boxes = False
+    predict_batch = True
+    try:
+        output = await dl_service.run_model(model_name, input_data, draw_boxes, predict_batch)
+        error_logging.info('request successful;' + str(output))
+        return ApiResponse(data=output)
+    except ApplicationError as e:
+        error_logging.warning(model_name + ';' + str(e))
+        return ApiResponse(success=False, error=e)
+    except Exception as e:
+        error_logging.error(model_name + ' ' + str(e))
+        return ApiResponse(success=False, error='unexpected server error')
+
+
+@app.post('/models/{model_name}/predict_image')
+async def run_model_image(model_name: str, input_data: UploadFile = File(...)):
+    """
+    Draws bounding box(es) on image and returns it.
+    :param model_name: Model name
+    :param input_data: Image file
+    :return: Image file
+    """
+    draw_boxes = True
+    predict_batch = False
+    try:
+        output = await dl_service.run_model(model_name, input_data, draw_boxes, predict_batch)
+        error_logging.info('request successful;' + str(output))
+        return FileResponse("/main/result.jpg", media_type="image/jpeg")
+    except ApplicationError as e:
+        error_logging.warning(model_name + ';' + str(e))
+        return ApiResponse(success=False, error=e)
+    except Exception as e:
+        error_logging.error(model_name + ' ' + str(e))
+        return ApiResponse(success=False, error='unexpected server error')
+
+
+@app.get('/models/{model_name}/labels')
+async def list_model_labels(model_name: str):
+    """
+    Lists all the model's labels.
+    :param model_name: Model name
+    :return: List of the model's labels
+    """
+    labels = dl_service.get_labels(model_name)
+    return ApiResponse(data=labels)
+
+
+@app.get('/models/{model_name}/config')
+async def list_model_config(model_name: str):
+    """
+    Lists the model's configuration.
+    :param model_name: Model name
+    :return: The model's configuration
+    """
+    config = dl_service.get_config(model_name)
+    return ApiResponse(data=config)
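+
+
+# Example client call (illustrative only; host, port, model name, and image
+# path are deployment-specific assumptions, not part of this API definition):
+#
+#   import requests
+#   with open('test.jpg', 'rb') as f:
+#       response = requests.post('http://localhost:4343/models/my_model/predict',
+#                                files={'input_data': f})
+#   print(response.json())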