diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ef63710 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.idea +old diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..a612ad9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,373 @@ +Mozilla Public License Version 2.0 +================================== + +1. Definitions +-------------- + +1.1. "Contributor" + means each individual or legal entity that creates, contributes to + the creation of, or owns Covered Software. + +1.2. "Contributor Version" + means the combination of the Contributions of others (if any) used + by a Contributor and that particular Contributor's Contribution. + +1.3. "Contribution" + means Covered Software of a particular Contributor. + +1.4. "Covered Software" + means Source Code Form to which the initial Contributor has attached + the notice in Exhibit A, the Executable Form of such Source Code + Form, and Modifications of such Source Code Form, in each case + including portions thereof. + +1.5. "Incompatible With Secondary Licenses" + means + + (a) that the initial Contributor has attached the notice described + in Exhibit B to the Covered Software; or + + (b) that the Covered Software was made available under the terms of + version 1.1 or earlier of the License, but not also under the + terms of a Secondary License. + +1.6. "Executable Form" + means any form of the work other than Source Code Form. + +1.7. "Larger Work" + means a work that combines Covered Software with other material, in + a separate file or files, that is not Covered Software. + +1.8. "License" + means this document. + +1.9. "Licensable" + means having the right to grant, to the maximum extent possible, + whether at the time of the initial grant or subsequently, any and + all of the rights conveyed by this License. + +1.10. "Modifications" + means any of the following: + + (a) any file in Source Code Form that results from an addition to, + deletion from, or modification of the contents of Covered + Software; or + + (b) any new file in Source Code Form that contains any Covered + Software. + +1.11. "Patent Claims" of a Contributor + means any patent claim(s), including without limitation, method, + process, and apparatus claims, in any patent Licensable by such + Contributor that would be infringed, but for the grant of the + License, by the making, using, selling, offering for sale, having + made, import, or transfer of either its Contributions or its + Contributor Version. + +1.12. "Secondary License" + means either the GNU General Public License, Version 2.0, the GNU + Lesser General Public License, Version 2.1, the GNU Affero General + Public License, Version 3.0, or any later versions of those + licenses. + +1.13. "Source Code Form" + means the form of the work preferred for making modifications. + +1.14. "You" (or "Your") + means an individual or a legal entity exercising rights under this + License. For legal entities, "You" includes any entity that + controls, is controlled by, or is under common control with You. For + purposes of this definition, "control" means (a) the power, direct + or indirect, to cause the direction or management of such entity, + whether by contract or otherwise, or (b) ownership of more than + fifty percent (50%) of the outstanding shares or beneficial + ownership of such entity. + +2. License Grants and Conditions +-------------------------------- + +2.1. 
Grants + +Each Contributor hereby grants You a world-wide, royalty-free, +non-exclusive license: + +(a) under intellectual property rights (other than patent or trademark) + Licensable by such Contributor to use, reproduce, make available, + modify, display, perform, distribute, and otherwise exploit its + Contributions, either on an unmodified basis, with Modifications, or + as part of a Larger Work; and + +(b) under Patent Claims of such Contributor to make, use, sell, offer + for sale, have made, import, and otherwise transfer either its + Contributions or its Contributor Version. + +2.2. Effective Date + +The licenses granted in Section 2.1 with respect to any Contribution +become effective for each Contribution on the date the Contributor first +distributes such Contribution. + +2.3. Limitations on Grant Scope + +The licenses granted in this Section 2 are the only rights granted under +this License. No additional rights or licenses will be implied from the +distribution or licensing of Covered Software under this License. +Notwithstanding Section 2.1(b) above, no patent license is granted by a +Contributor: + +(a) for any code that a Contributor has removed from Covered Software; + or + +(b) for infringements caused by: (i) Your and any other third party's + modifications of Covered Software, or (ii) the combination of its + Contributions with other software (except as part of its Contributor + Version); or + +(c) under Patent Claims infringed by Covered Software in the absence of + its Contributions. + +This License does not grant any rights in the trademarks, service marks, +or logos of any Contributor (except as may be necessary to comply with +the notice requirements in Section 3.4). + +2.4. Subsequent Licenses + +No Contributor makes additional grants as a result of Your choice to +distribute the Covered Software under a subsequent version of this +License (see Section 10.2) or under the terms of a Secondary License (if +permitted under the terms of Section 3.3). + +2.5. Representation + +Each Contributor represents that the Contributor believes its +Contributions are its original creation(s) or it has sufficient rights +to grant the rights to its Contributions conveyed by this License. + +2.6. Fair Use + +This License is not intended to limit any rights You have under +applicable copyright doctrines of fair use, fair dealing, or other +equivalents. + +2.7. Conditions + +Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted +in Section 2.1. + +3. Responsibilities +------------------- + +3.1. Distribution of Source Form + +All distribution of Covered Software in Source Code Form, including any +Modifications that You create or to which You contribute, must be under +the terms of this License. You must inform recipients that the Source +Code Form of the Covered Software is governed by the terms of this +License, and how they can obtain a copy of this License. You may not +attempt to alter or restrict the recipients' rights in the Source Code +Form. + +3.2. 
Distribution of Executable Form + +If You distribute Covered Software in Executable Form then: + +(a) such Covered Software must also be made available in Source Code + Form, as described in Section 3.1, and You must inform recipients of + the Executable Form how they can obtain a copy of such Source Code + Form by reasonable means in a timely manner, at a charge no more + than the cost of distribution to the recipient; and + +(b) You may distribute such Executable Form under the terms of this + License, or sublicense it under different terms, provided that the + license for the Executable Form does not attempt to limit or alter + the recipients' rights in the Source Code Form under this License. + +3.3. Distribution of a Larger Work + +You may create and distribute a Larger Work under terms of Your choice, +provided that You also comply with the requirements of this License for +the Covered Software. If the Larger Work is a combination of Covered +Software with a work governed by one or more Secondary Licenses, and the +Covered Software is not Incompatible With Secondary Licenses, this +License permits You to additionally distribute such Covered Software +under the terms of such Secondary License(s), so that the recipient of +the Larger Work may, at their option, further distribute the Covered +Software under the terms of either this License or such Secondary +License(s). + +3.4. Notices + +You may not remove or alter the substance of any license notices +(including copyright notices, patent notices, disclaimers of warranty, +or limitations of liability) contained within the Source Code Form of +the Covered Software, except that You may alter any license notices to +the extent required to remedy known factual inaccuracies. + +3.5. Application of Additional Terms + +You may choose to offer, and to charge a fee for, warranty, support, +indemnity or liability obligations to one or more recipients of Covered +Software. However, You may do so only on Your own behalf, and not on +behalf of any Contributor. You must make it absolutely clear that any +such warranty, support, indemnity, or liability obligation is offered by +You alone, and You hereby agree to indemnify every Contributor for any +liability incurred by such Contributor as a result of warranty, support, +indemnity or liability terms You offer. You may include additional +disclaimers of warranty and limitations of liability specific to any +jurisdiction. + +4. Inability to Comply Due to Statute or Regulation +--------------------------------------------------- + +If it is impossible for You to comply with any of the terms of this +License with respect to some or all of the Covered Software due to +statute, judicial order, or regulation then You must: (a) comply with +the terms of this License to the maximum extent possible; and (b) +describe the limitations and the code they affect. Such description must +be placed in a text file included with all distributions of the Covered +Software under this License. Except to the extent prohibited by statute +or regulation, such description must be sufficiently detailed for a +recipient of ordinary skill to be able to understand it. + +5. Termination +-------------- + +5.1. The rights granted under this License will terminate automatically +if You fail to comply with any of its terms. 
However, if You become +compliant, then the rights granted under this License from a particular +Contributor are reinstated (a) provisionally, unless and until such +Contributor explicitly and finally terminates Your grants, and (b) on an +ongoing basis, if such Contributor fails to notify You of the +non-compliance by some reasonable means prior to 60 days after You have +come back into compliance. Moreover, Your grants from a particular +Contributor are reinstated on an ongoing basis if such Contributor +notifies You of the non-compliance by some reasonable means, this is the +first time You have received notice of non-compliance with this License +from such Contributor, and You become compliant prior to 30 days after +Your receipt of the notice. + +5.2. If You initiate litigation against any entity by asserting a patent +infringement claim (excluding declaratory judgment actions, +counter-claims, and cross-claims) alleging that a Contributor Version +directly or indirectly infringes any patent, then the rights granted to +You by any and all Contributors for the Covered Software under Section +2.1 of this License shall terminate. + +5.3. In the event of termination under Sections 5.1 or 5.2 above, all +end user license agreements (excluding distributors and resellers) which +have been validly granted by You or Your distributors under this License +prior to termination shall survive termination. + +************************************************************************ +* * +* 6. Disclaimer of Warranty * +* ------------------------- * +* * +* Covered Software is provided under this License on an "as is" * +* basis, without warranty of any kind, either expressed, implied, or * +* statutory, including, without limitation, warranties that the * +* Covered Software is free of defects, merchantable, fit for a * +* particular purpose or non-infringing. The entire risk as to the * +* quality and performance of the Covered Software is with You. * +* Should any Covered Software prove defective in any respect, You * +* (not any Contributor) assume the cost of any necessary servicing, * +* repair, or correction. This disclaimer of warranty constitutes an * +* essential part of this License. No use of any Covered Software is * +* authorized under this License except under this disclaimer. * +* * +************************************************************************ + +************************************************************************ +* * +* 7. Limitation of Liability * +* -------------------------- * +* * +* Under no circumstances and under no legal theory, whether tort * +* (including negligence), contract, or otherwise, shall any * +* Contributor, or anyone who distributes Covered Software as * +* permitted above, be liable to You for any direct, indirect, * +* special, incidental, or consequential damages of any character * +* including, without limitation, damages for lost profits, loss of * +* goodwill, work stoppage, computer failure or malfunction, or any * +* and all other commercial damages or losses, even if such party * +* shall have been informed of the possibility of such damages. This * +* limitation of liability shall not apply to liability for death or * +* personal injury resulting from such party's negligence to the * +* extent applicable law prohibits such limitation. Some * +* jurisdictions do not allow the exclusion or limitation of * +* incidental or consequential damages, so this exclusion and * +* limitation may not apply to You. 
* +* * +************************************************************************ + +8. Litigation +------------- + +Any litigation relating to this License may be brought only in the +courts of a jurisdiction where the defendant maintains its principal +place of business and such litigation shall be governed by laws of that +jurisdiction, without reference to its conflict-of-law provisions. +Nothing in this Section shall prevent a party's ability to bring +cross-claims or counter-claims. + +9. Miscellaneous +---------------- + +This License represents the complete agreement concerning the subject +matter hereof. If any provision of this License is held to be +unenforceable, such provision shall be reformed only to the extent +necessary to make it enforceable. Any law or regulation which provides +that the language of a contract shall be construed against the drafter +shall not be used to construe this License against a Contributor. + +10. Versions of the License +--------------------------- + +10.1. New Versions + +Mozilla Foundation is the license steward. Except as provided in Section +10.3, no one other than the license steward has the right to modify or +publish new versions of this License. Each version will be given a +distinguishing version number. + +10.2. Effect of New Versions + +You may distribute the Covered Software under the terms of the version +of the License under which You originally received the Covered Software, +or under the terms of any subsequent version published by the license +steward. + +10.3. Modified Versions + +If you create software not governed by this License, and you want to +create a new license for such software, you may create and use a +modified version of this License if you rename the license and remove +any references to the name of the license steward (except to note that +such modified license differs from this License). + +10.4. Distributing Source Code Form that is Incompatible With Secondary +Licenses + +If You choose to distribute Source Code Form that is Incompatible With +Secondary Licenses under the terms of this version of the License, the +notice described in Exhibit B of this License must be attached. + +Exhibit A - Source Code Form License Notice +------------------------------------------- + + This Source Code Form is subject to the terms of the Mozilla Public + License, v. 2.0. If a copy of the MPL was not distributed with this + file, You can obtain one at http://mozilla.org/MPL/2.0/. + +If it is not possible or desirable to put the notice in a particular +file, then You may include the notice in a location (such as a LICENSE +file in a relevant directory) where a recipient would be likely to look +for such a notice. + +You may add additional accurate notices of copyright ownership. + +Exhibit B - "Incompatible With Secondary Licenses" Notice +--------------------------------------------------------- + + This Source Code Form is "Incompatible With Secondary Licenses", as + defined by the Mozilla Public License, v. 2.0. diff --git a/README.md b/README.md new file mode 100644 index 0000000..20fe261 --- /dev/null +++ b/README.md @@ -0,0 +1,109 @@ +# Decoder Modulation for Indoor Depth Completion + +

+<p align="center">
+  <i>Sample results on Matterport3D: color image / raw depth / predicted depth / ground truth (see <code>images/</code>)</i>
+</p>
+
+> **Decoder Modulation for Indoor Depth Completion**<br>
+> [Dmitry Senushkin](https://github.com/senush), +> [Ilia Belikov](https://github.com/ferluht), +> [Anton Konushin](https://scholar.google.com/citations?user=ZT_k-wMAAAAJ) +>
+> Samsung AI Center Moscow
https://arxiv.org/abs/20??.?????
+
+> **Abstract**: *Accurate depth map estimation is an essential step in scene spatial mapping for AR applications and 3D modeling. Current depth sensors provide time-synchronized depth and color images in real-time, but have limited range and suffer from missing and erroneous depth values on transparent or glossy surfaces. We investigate the task of depth completion that aims at improving the accuracy of depth measurements and recovering the missing depth values using additional information from corresponding color images. Surprisingly, we find that a simple baseline model based on modern encoder-decoder architecture for semantic segmentation achieves state-of-the-art accuracy on standard depth completion benchmarks. Then, we show that the accuracy can be further improved by taking into account a mask of missing depth values. The main contributions of our work are two-fold. First, we propose a modified decoder architecture, where features from raw depth and color are modulated by features from the mask via Spatially-Adaptive Denormalization (SPADE). Second, we introduce a new loss function for depth estimation based on direct comparison of log depth prediction with ground truth values. The resulting model outperforms current state-of-the-art by a large margin on the challenging Matterport3D dataset.*
+
+## Installation
+This implementation is based on Python 3+ and PyTorch 1.4+. We provide two ways of setting up an environment. If you are using `Anaconda`, the following commands perform the necessary installation:
+```.bash
+conda env create -f environment.yaml
+conda activate depth-completion
+python setup.py install
+```
+The same procedure can be done with `pip`:
+```.bash
+pip3 install -r requirements.txt
+python setup.py install
+```
+
+## Training
+We provide code for training on [Matterport3D](https://github.com/patrickwu2/Depth-Completion/blob/master/doc/data.md). Download the Matterport3D dataset and arrange your root folder as follows:
+```bash
+ROOT/
+  ├── data/
+  └── splits/
+        ├── train.txt
+        ├── val.txt
+        └── test.txt
+```
+The `data` directory should be configured in [this order](https://github.com/patrickwu2/Depth-Completion/blob/master/doc/data.md). Make sure the ROOT path in [matterport.py](https://github.sec.samsung.net/d-senushkin/saic_depth_completion_public/blob/master/saic_depth_completion/data/datasets/matterport.py) is valid.
+Now you can start training with one of the following commands:
+```.bash
+# for LRN decoder with efficientnet-b4 backbone
+python train_matterport.py --default_cfg='LRN' --config_file='../configs/LRN_efficientnet-b4_lena.yaml' --postfix='example_lrn'
+# for DM-LRN decoder with efficientnet-b4 backbone
+python train_matterport.py --default_cfg='DM-LRN' --config_file='../configs/DM-LRN_efficientnet-b4_pepper.yaml' --postfix='example_dm_lrn'
+```
+
+## Evaluation
+We provide scripts for evaluation on Matterport3D. If you need to evaluate on NYUv2, look directly at the code, as it may change in the future.
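+
+For reference, both the training and evaluation scripts assemble their configuration from per-architecture defaults plus a YAML override. A minimal sketch of this mechanism, using the registry from `saic_depth_completion/config` (exactly how the scripts wire the `--default_cfg` and `--config_file` flags to these calls is an assumption):
+```python
+from saic_depth_completion.config import get_default_config
+
+# "LRN" and "DM-LRN" are the two registered default configs
+cfg = get_default_config("DM-LRN")
+# the experiment YAML is merged over the defaults (yacs CfgNode)
+cfg.merge_from_file("configs/dm_lrn/DM-LRN_efficientnet-b4_pepper.yaml")
+print(cfg.model.backbone.arch)  # "efficientnet-b4"
+```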
+The following commands perform evaluation on the Matterport3D test set:
+```.bash
+# for LRN decoder with efficientnet-b4 backbone
+python test_net.py --default_cfg='LRN' --config_file='../configs/LRN_efficientnet-b4_lena.yaml' --weights=
+# for DM-LRN decoder with efficientnet-b4 backbone
+python test_net.py --default_cfg='DM-LRN' --config_file='../configs/DM-LRN_efficientnet-b4_pepper.yaml' --weights=
+# to visualize the results, add the --save_dir argument
+python test_net.py --default_cfg='DM-LRN' --config_file='../configs/DM-LRN_efficientnet-b4_pepper.yaml' --weights= --save_dir=
+```
+
+## Model ZOO
+This repository includes all models mentioned in the original paper.
+
+| Backbone | Decoder
type | Encoder
input | Training loss | Link | Config |
+|----------|-----------|:-----:|:-------------:|:----------------:|:----------:|
+| efficientnet-b0 | LRN | RGBD | LogDepthL1Loss | [lrn_b0.pth][lrn_b0] | LRN_efficientnet-b0_suzy.yaml |
+| efficientnet-b1 | LRN | RGBD | LogDepthL1Loss | [lrn_b1.pth][lrn_b1] | LRN_efficientnet-b1_anabel.yaml |
+| efficientnet-b2 | LRN | RGBD | LogDepthL1Loss | [lrn_b2.pth][lrn_b2] | LRN_efficientnet-b2_irina.yaml |
+| efficientnet-b3 | LRN | RGBD | LogDepthL1Loss | [lrn_b3.pth][lrn_b3] | LRN_efficientnet-b3_sara.yaml |
+| efficientnet-b4 | LRN | RGBD | LogDepthL1Loss | [lrn_b4.pth][lrn_b4] | LRN_efficientnet-b4_lena.yaml |
+| efficientnet-b4 | LRN | RGBD | BerHu | [lrn_b4_berhu.pth][lrn_b4_berhu] | LRN_efficientnet-b4_helga.yaml |
+| efficientnet-b4 | LRN | RGBD+M | LogDepthL1Loss | [lrn_b4_mask.pth][lrn_b4_mask] | LRN_efficientnet-b4_simona.yaml |
+| efficientnet-b0 | DM-LRN | RGBD | LogDepthL1Loss | [dm-lrn_b0.pth][dm-lrn_b0] | DM-LRN_efficientnet-b0_camila.yaml |
+| efficientnet-b1 | DM-LRN | RGBD | LogDepthL1Loss | [dm-lrn_b1.pth][dm-lrn_b1] | DM-LRN_efficientnet-b1_pamela.yaml |
+| efficientnet-b2 | DM-LRN | RGBD | LogDepthL1Loss | [dm-lrn_b2.pth][dm-lrn_b2] | DM-LRN_efficientnet-b2_rosaline.yaml |
+| efficientnet-b3 | DM-LRN | RGBD | LogDepthL1Loss | [dm-lrn_b3.pth][dm-lrn_b3] | DM-LRN_efficientnet-b3_jenifer.yaml |
+| efficientnet-b4 | DM-LRN | RGBD | LogDepthL1Loss | [dm-lrn_b4.pth][dm-lrn_b4] | DM-LRN_efficientnet-b4_pepper.yaml |
+| efficientnet-b4 | DM-LRN | RGBD | BerHu | [dm-lrn_b4_berhu.pth][dm-lrn_b4_berhu] | DM-LRN_efficientnet-b4_amelia.yaml |
+
+[lrn_b0]: https://github.sec.samsung.net/d-senushkin/saic_depth_completion_public/releases/download/v1.0/lrn_b0.pth
+[lrn_b1]: https://github.sec.samsung.net/d-senushkin/saic_depth_completion_public/releases/download/v1.0/lrn_b1.pth
+[lrn_b2]: https://github.sec.samsung.net/d-senushkin/saic_depth_completion_public/releases/download/v1.0/lrn_b2.pth
+[lrn_b3]: https://github.sec.samsung.net/d-senushkin/saic_depth_completion_public/releases/download/v1.0/lrn_b3.pth
+[lrn_b4]: https://github.sec.samsung.net/d-senushkin/saic_depth_completion_public/releases/download/v1.0/lrn_b4.pth
+[lrn_b4_berhu]: https://github.sec.samsung.net/d-senushkin/saic_depth_completion_public/releases/download/v1.0/lrn_b4_berhu.pth
+[lrn_b4_mask]: https://github.sec.samsung.net/d-senushkin/saic_depth_completion_public/releases/download/v1.0/lrn_b4_mask.pth
+
+[dm-lrn_b0]: https://github.sec.samsung.net/d-senushkin/saic_depth_completion_public/releases/download/v1.0/dm-lrn_b0.pth
+[dm-lrn_b1]: https://github.sec.samsung.net/d-senushkin/saic_depth_completion_public/releases/download/v1.0/dm-lrn_b1.pth
+[dm-lrn_b2]: https://github.sec.samsung.net/d-senushkin/saic_depth_completion_public/releases/download/v1.0/dm-lrn_b2.pth
+[dm-lrn_b3]: https://github.sec.samsung.net/d-senushkin/saic_depth_completion_public/releases/download/v1.0/dm-lrn_b3.pth
+[dm-lrn_b4]: https://github.sec.samsung.net/d-senushkin/saic_depth_completion_public/releases/download/v1.0/dm-lrn_b4.pth
+[dm-lrn_b4_berhu]: https://github.sec.samsung.net/d-senushkin/saic_depth_completion_public/releases/download/v1.0/dm-lrn_b4_berhu.pth
+
+## License
+The code is released under the MPL 2.0 License. MPL is a copyleft license that is easy to comply with. You must make the source code for any of your changes available under MPL, but you can combine the MPL software with proprietary code, as long as you keep the MPL code in separate files.
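+In practice, this means keeping the Exhibit A notice from the LICENSE file at the top of MPL-covered sources, e.g. as a Python comment:
+```python
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+```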
+
+## Citation
+If you find this work useful for your research, please cite our paper:
+```
+@article{dmidc2020,
+  title={Decoder Modulation for Indoor Depth Completion},
+  author={Senushkin, Dmitry and Belikov, Ilia and Konushin, Anton},
+  journal={arXiv preprint arXiv:20??.????},
+  year={2020}
+}
+```
diff --git a/configs/dm_lrn/DM-LRN_efficientnet-b0_camila.yaml b/configs/dm_lrn/DM-LRN_efficientnet-b0_camila.yaml
new file mode 100644
index 0000000..065611a
--- /dev/null
+++ b/configs/dm_lrn/DM-LRN_efficientnet-b0_camila.yaml
@@ -0,0 +1,21 @@
+model:
+  arch: "DM-LRN"
+  max_channels: 256
+  modulation: "SPADE"
+  activation: ("LeakyReLU", [0.2, True] )
+  upsample: "bilinear"
+  use_crp: True
+  criterion: (("LogDepthL1Loss", 1.0), )
+  predict_log_depth: True
+  mask_encoder_ksize: 3
+
+
+  backbone:
+    arch: "efficientnet-b0"
+    imagenet: True
+    norm_layer: ""
+    multi_scale_output: True
+
+train:
+  batch_size: 32
+  lr: 0.0001
diff --git a/configs/dm_lrn/DM-LRN_efficientnet-b1_pamela.yaml b/configs/dm_lrn/DM-LRN_efficientnet-b1_pamela.yaml
new file mode 100644
index 0000000..21c351b
--- /dev/null
+++ b/configs/dm_lrn/DM-LRN_efficientnet-b1_pamela.yaml
@@ -0,0 +1,21 @@
+model:
+  arch: "DM-LRN"
+  max_channels: 256
+  modulation: "SPADE"
+  activation: ("LeakyReLU", [0.2, True] )
+  upsample: "bilinear"
+  use_crp: True
+  criterion: (("LogDepthL1Loss", 1.0), )
+  predict_log_depth: True
+  mask_encoder_ksize: 3
+
+
+  backbone:
+    arch: "efficientnet-b1"
+    imagenet: True
+    norm_layer: ""
+    multi_scale_output: True
+
+train:
+  batch_size: 32
+  lr: 0.0001
diff --git a/configs/dm_lrn/DM-LRN_efficientnet-b2_rosaline.yaml b/configs/dm_lrn/DM-LRN_efficientnet-b2_rosaline.yaml
new file mode 100644
index 0000000..ceca89c
--- /dev/null
+++ b/configs/dm_lrn/DM-LRN_efficientnet-b2_rosaline.yaml
@@ -0,0 +1,21 @@
+model:
+  arch: "DM-LRN"
+  max_channels: 256
+  modulation: "SPADE"
+  activation: ("LeakyReLU", [0.2, True] )
+  upsample: "bilinear"
+  use_crp: True
+  criterion: (("LogDepthL1Loss", 1.0), )
+  predict_log_depth: True
+  mask_encoder_ksize: 3
+
+
+  backbone:
+    arch: "efficientnet-b2"
+    imagenet: True
+    norm_layer: ""
+    multi_scale_output: True
+
+train:
+  batch_size: 32
+  lr: 0.0001
diff --git a/configs/dm_lrn/DM-LRN_efficientnet-b3_jenifer.yaml b/configs/dm_lrn/DM-LRN_efficientnet-b3_jenifer.yaml
new file mode 100644
index 0000000..218b9e1
--- /dev/null
+++ b/configs/dm_lrn/DM-LRN_efficientnet-b3_jenifer.yaml
@@ -0,0 +1,21 @@
+model:
+  arch: "DM-LRN"
+  max_channels: 256
+  modulation: "SPADE"
+  activation: ("LeakyReLU", [0.2, True] )
+  upsample: "bilinear"
+  use_crp: True
+  criterion: (("LogDepthL1Loss", 1.0), )
+  predict_log_depth: True
+  mask_encoder_ksize: 3
+
+
+  backbone:
+    arch: "efficientnet-b3"
+    imagenet: True
+    norm_layer: ""
+    multi_scale_output: True
+
+train:
+  batch_size: 32
+  lr: 0.0001
diff --git a/configs/dm_lrn/DM-LRN_efficientnet-b4_amelia.yaml b/configs/dm_lrn/DM-LRN_efficientnet-b4_amelia.yaml
new file mode 100644
index 0000000..4f147f8
--- /dev/null
+++ b/configs/dm_lrn/DM-LRN_efficientnet-b4_amelia.yaml
@@ -0,0 +1,21 @@
+model:
+  arch: "DM-LRN"
+  max_channels: 256
+  modulation: "SPADE"
+  activation: ("ReLU", [] )
+  upsample: "bilinear"
+  use_crp: True
+  criterion: (("BerHuLoss", 1.0, [0.5]), )
+  predict_log_depth: False
+  mask_encoder_ksize: 3
+
+
+  backbone:
+    arch: "efficientnet-b4"
+    imagenet: True
+    norm_layer: ""
+    multi_scale_output: True
+
+train:
+  batch_size: 32
+  lr: 0.0001
diff --git a/configs/dm_lrn/DM-LRN_efficientnet-b4_pepper.yaml b/configs/dm_lrn/DM-LRN_efficientnet-b4_pepper.yaml
new file mode 100644 index 0000000..2bf90c7 --- /dev/null +++ b/configs/dm_lrn/DM-LRN_efficientnet-b4_pepper.yaml @@ -0,0 +1,23 @@ +model: + arch: "DM-LRN" + max_channels: 256 + modulation: "SPADE" + activation: ("LeakyReLU", [0.2, True] ) + upsample: "bilinear" + use_crp: True + criterion: (("LogDepthL1Loss", 1.0), ) + predict_log_depth: True + mask_encoder_ksize: 3 + + + backbone: + arch: "efficientnet-b4" + imagenet: True + norm_layer: "" + multi_scale_output: True + +train: + batch_size: 32 + lr: 0.0001 +test: + batch_size: 8 diff --git a/configs/dm_lrn/DM-LRN_efficientnet-b5_tamara.yaml b/configs/dm_lrn/DM-LRN_efficientnet-b5_tamara.yaml new file mode 100644 index 0000000..1271069 --- /dev/null +++ b/configs/dm_lrn/DM-LRN_efficientnet-b5_tamara.yaml @@ -0,0 +1,23 @@ +model: + arch: "DM-LRN" + max_channels: 256 + modulation: "SPADE" + activation: ("LeakyReLU", [0.2, True] ) + upsample: "bilinear" + use_crp: True + criterion: (("LogDepthL1Loss", 1.0), ) + predict_log_depth: True + mask_encoder_ksize: 3 + + + backbone: + arch: "efficientnet-b5" + imagenet: True + norm_layer: "" + multi_scale_output: True + +train: + batch_size: 16 + lr: 0.0001 +test: + batch_size: 4 diff --git a/configs/lrn/LRN_efficientnet-b0_suzy.yaml b/configs/lrn/LRN_efficientnet-b0_suzy.yaml new file mode 100644 index 0000000..ff16e03 --- /dev/null +++ b/configs/lrn/LRN_efficientnet-b0_suzy.yaml @@ -0,0 +1,18 @@ +model: + arch: "LRN" + max_channels: 256 + activation: ("LeakyReLU", [0.2, True] ) + upsample: "bilinear" + use_crp: True + criterion: (("LogDepthL1Loss", 1.0), ) + predict_log_depth: True + + + backbone: + arch: "efficientnet-b0" + imagenet: True + multi_scale_output: True + +train: + batch_size: 32 + lr: 0.0001 diff --git a/configs/lrn/LRN_efficientnet-b1_anabel.yaml b/configs/lrn/LRN_efficientnet-b1_anabel.yaml new file mode 100644 index 0000000..9f286f4 --- /dev/null +++ b/configs/lrn/LRN_efficientnet-b1_anabel.yaml @@ -0,0 +1,18 @@ +model: + arch: "LRN" + max_channels: 256 + activation: ("LeakyReLU", [0.2, True] ) + upsample: "bilinear" + use_crp: True + criterion: (("LogDepthL1Loss", 1.0), ) + predict_log_depth: True + + + backbone: + arch: "efficientnet-b1" + imagenet: True + multi_scale_output: True + +train: + batch_size: 32 + lr: 0.0001 diff --git a/configs/lrn/LRN_efficientnet-b2_irina.yaml b/configs/lrn/LRN_efficientnet-b2_irina.yaml new file mode 100644 index 0000000..b35538e --- /dev/null +++ b/configs/lrn/LRN_efficientnet-b2_irina.yaml @@ -0,0 +1,18 @@ +model: + arch: "LRN" + max_channels: 256 + activation: ("LeakyReLU", [0.2, True] ) + upsample: "bilinear" + use_crp: True + criterion: (("LogDepthL1Loss", 1.0), ) + predict_log_depth: True + + + backbone: + arch: "efficientnet-b2" + imagenet: True + multi_scale_output: True + +train: + batch_size: 32 + lr: 0.0001 diff --git a/configs/lrn/LRN_efficientnet-b3_sara.yaml b/configs/lrn/LRN_efficientnet-b3_sara.yaml new file mode 100644 index 0000000..59ceb11 --- /dev/null +++ b/configs/lrn/LRN_efficientnet-b3_sara.yaml @@ -0,0 +1,18 @@ +model: + arch: "LRN" + max_channels: 256 + activation: ("LeakyReLU", [0.2, True] ) + upsample: "bilinear" + use_crp: True + criterion: (("LogDepthL1Loss", 1.0), ) + predict_log_depth: True + + + backbone: + arch: "efficientnet-b3" + imagenet: True + multi_scale_output: True + +train: + batch_size: 32 + lr: 0.0001 diff --git a/configs/lrn/LRN_efficientnet-b4_helga.yaml b/configs/lrn/LRN_efficientnet-b4_helga.yaml new file mode 100644 index 0000000..772f8ac --- /dev/null +++ 
b/configs/lrn/LRN_efficientnet-b4_helga.yaml @@ -0,0 +1,18 @@ +model: + arch: "LRN" + max_channels: 256 + activation: ("LeakyReLU", [0.2, True] ) + upsample: "bilinear" + use_crp: True + criterion: (("BerHuLoss", 1.0, [0.5]), ) + predict_log_depth: False + + + backbone: + arch: "efficientnet-b4" + imagenet: True + multi_scale_output: True + +train: + batch_size: 32 + lr: 0.0001 diff --git a/configs/lrn/LRN_efficientnet-b4_lena.yaml b/configs/lrn/LRN_efficientnet-b4_lena.yaml new file mode 100644 index 0000000..65c90f1 --- /dev/null +++ b/configs/lrn/LRN_efficientnet-b4_lena.yaml @@ -0,0 +1,18 @@ +model: + arch: "LRN" + max_channels: 256 + activation: ("LeakyReLU", [0.2, True] ) + upsample: "bilinear" + use_crp: True + criterion: (("LogDepthL1Loss", 1.0), ) + predict_log_depth: True + + + backbone: + arch: "efficientnet-b4" + imagenet: True + multi_scale_output: True + +train: + batch_size: 32 + lr: 0.0001 diff --git a/configs/lrn/LRN_efficientnet-b4_simona.yaml b/configs/lrn/LRN_efficientnet-b4_simona.yaml new file mode 100644 index 0000000..a174132 --- /dev/null +++ b/configs/lrn/LRN_efficientnet-b4_simona.yaml @@ -0,0 +1,19 @@ +model: + arch: "LRN" + max_channels: 256 + activation: ("LeakyReLU", [0.2, True] ) + upsample: "bilinear" + use_crp: True + criterion: (("LogDepthL1Loss", 1.0), ) + predict_log_depth: True + input_mask: True + + + backbone: + arch: "efficientnet-b4" + imagenet: True + multi_scale_output: True + +train: + batch_size: 32 + lr: 0.0001 diff --git a/configs/lrn/LRN_efficientnet-b5_tereza.yaml b/configs/lrn/LRN_efficientnet-b5_tereza.yaml new file mode 100644 index 0000000..98f1902 --- /dev/null +++ b/configs/lrn/LRN_efficientnet-b5_tereza.yaml @@ -0,0 +1,18 @@ +model: + arch: "LRN" + max_channels: 256 + activation: ("LeakyReLU", [0.2, True] ) + upsample: "bilinear" + use_crp: True + criterion: (("LogDepthL1Loss", 1.0), ) + predict_log_depth: True + + + backbone: + arch: "efficientnet-b5" + imagenet: True + multi_scale_output: True + +train: + batch_size: 16 + lr: 0.0001 diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000..a863b58 --- /dev/null +++ b/environment.yaml @@ -0,0 +1,32 @@ +name: depth-completion +channels: + - pytorch + - conda-forge + - defaults +dependencies: + - python=3.7.6 + - pytorch=1.4.* + - torchvision=0.5.* + - ninja + - numpy + - opencv + - pillow + - pip + - setuptools + - six + - tensorboard + - tqdm + - pip: + - argparse + - colorlog + - cycler + - decorator + - easydict + - efficientnet-pytorch + - future + - imageio + - matplotlib + - pyyaml + - scikit-image + - termcolor + - yacs \ No newline at end of file diff --git a/images/color_1.jpg b/images/color_1.jpg new file mode 100644 index 0000000..5070a45 Binary files /dev/null and b/images/color_1.jpg differ diff --git a/images/color_2.jpg b/images/color_2.jpg new file mode 100644 index 0000000..9d7c9b7 Binary files /dev/null and b/images/color_2.jpg differ diff --git a/images/gt_1.jpg b/images/gt_1.jpg new file mode 100644 index 0000000..b46ddb0 Binary files /dev/null and b/images/gt_1.jpg differ diff --git a/images/gt_2.jpg b/images/gt_2.jpg new file mode 100644 index 0000000..5170596 Binary files /dev/null and b/images/gt_2.jpg differ diff --git a/images/pred_1.jpg b/images/pred_1.jpg new file mode 100644 index 0000000..7b6c27b Binary files /dev/null and b/images/pred_1.jpg differ diff --git a/images/pred_2.jpg b/images/pred_2.jpg new file mode 100644 index 0000000..6ec4874 Binary files /dev/null and b/images/pred_2.jpg differ diff --git a/images/raw_1.jpg 
b/images/raw_1.jpg new file mode 100644 index 0000000..f188a35 Binary files /dev/null and b/images/raw_1.jpg differ diff --git a/images/raw_2.jpg b/images/raw_2.jpg new file mode 100644 index 0000000..1125227 Binary files /dev/null and b/images/raw_2.jpg differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..6d0fd78 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,14 @@ +torch>=1.4.0 +torchvision>=0.5.0 +yacs +numpy +scikit-image +Pillow +matplotlib +pyyaml +easydict +tqdm +efficientnet_pytorch +colorlog +tensorboard +opencv-python-headless \ No newline at end of file diff --git a/saic_depth_completion/__init__.py b/saic_depth_completion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/saic_depth_completion/config/__init__.py b/saic_depth_completion/config/__init__.py new file mode 100644 index 0000000..31f256d --- /dev/null +++ b/saic_depth_completion/config/__init__.py @@ -0,0 +1,12 @@ +from saic_depth_completion.utils.registry import Registry +from saic_depth_completion.config.lrn import _C as lrn_cfg +from saic_depth_completion.config.dm_lrn import _C as dm_lrn_cfg + +CONFIGS = Registry() + +CONFIGS["LRN"] = lrn_cfg +CONFIGS["DM-LRN"] = dm_lrn_cfg + + +def get_default_config(type): + return CONFIGS[type].clone() diff --git a/saic_depth_completion/config/dm_lrn.py b/saic_depth_completion/config/dm_lrn.py new file mode 100644 index 0000000..0b14885 --- /dev/null +++ b/saic_depth_completion/config/dm_lrn.py @@ -0,0 +1,47 @@ +from yacs.config import CfgNode as CN + +_C = CN() +_C.model = CN() +# global arch +_C.model.arch = 'DM-LRN' +# width of model +_C.model.max_channels = 256 +# modulation layer +_C.model.modulation = "SPADE" +# activation: (type: str, kwargs: dict) +_C.model.activation = ("LeakyReLU", [0.2, True]) +# upsample mode +_C.model.upsample = "bilinear" +# include CRP blocks or not +_C.model.use_crp = True +# loss fn: list of tuple +_C.model.criterion = [("LogDepthL1Loss", 1.0)] +_C.model.predict_log_depth = True +# mask encoder convolution's kernel size +_C.model.mask_encoder_ksize = 3 + + +# backbone +_C.model.backbone = CN() +# backbone arch +_C.model.backbone.arch = 'efficientnet-b4' +# pretraining +_C.model.backbone.imagenet = True +# batch norm or frozen batch norm +_C.model.backbone.norm_layer = "" +# return features from 4 scale or not +_C.model.backbone.multi_scale_output = True + +# train parameters +_C.train = CN() +# use standard scaler or not +_C.train.rgb_mean = [0.485, 0.456, 0.406] +_C.train.rgb_std = [0.229, 0.224, 0.225] +# standard scaler params for raw_depth +_C.train.depth_mean = 2.1489 +_C.train.depth_std = 1.4279 +_C.train.batch_size = 32 +_C.train.lr = 0.0001 + +_C.test = CN() +_C.test.batch_size = 4 \ No newline at end of file diff --git a/saic_depth_completion/config/lrn.py b/saic_depth_completion/config/lrn.py new file mode 100644 index 0000000..f86a89a --- /dev/null +++ b/saic_depth_completion/config/lrn.py @@ -0,0 +1,43 @@ +from yacs.config import CfgNode as CN + +_C = CN() +_C.model = CN() +# global arch +_C.model.arch = 'LRN' +# width of model +_C.model.max_channels = 256 +# activation: (type: str, kwargs: dict) +_C.model.activation = ("LeakyReLU", [0.2, True]) +# upsample mode +_C.model.upsample = "bilinear" +# include CRP blocks or not +_C.model.use_crp = True +# loss fn: list of tuple +_C.model.criterion = [("LogDepthL1Loss", 1.0)] +_C.model.predict_log_depth = True +_C.model.input_mask = False + +# backbone +_C.model.backbone = CN() +# backbone arch +_C.model.backbone.arch = 
'efficientnet-b0' +# pretraining +_C.model.backbone.imagenet = True +# batch norm or frozen batch norm +_C.model.backbone.norm_layer = "" +# return features from 4 scale or not +_C.model.backbone.multi_scale_output = True + +# train parameters +_C.train = CN() +# use standard scaler or not +_C.train.rgb_mean = [0.485, 0.456, 0.406] +_C.train.rgb_std = [0.229, 0.224, 0.225] +# standard scaler params for raw_depth +_C.train.depth_mean = 2.1489 +_C.train.depth_std = 1.4279 +_C.train.batch_size = 32 +_C.train.lr = 0.0001 + +_C.test = CN() +_C.test.batch_size = 16 diff --git a/saic_depth_completion/data/__init__.py b/saic_depth_completion/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/saic_depth_completion/data/collate.py b/saic_depth_completion/data/collate.py new file mode 100644 index 0000000..19bad0d --- /dev/null +++ b/saic_depth_completion/data/collate.py @@ -0,0 +1,15 @@ +import torch + +def default_collate(samples): + batch = dict() + for k in samples[0].keys(): + batch[k] = list() + + for sample in samples: + for k, v in sample.items(): + batch[k].append(v) + + for k, v in batch.items(): + batch[k] = torch.stack(v) + + return batch \ No newline at end of file diff --git a/saic_depth_completion/data/datasets/__init__.py b/saic_depth_completion/data/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/saic_depth_completion/data/datasets/completion_dataset.py b/saic_depth_completion/data/datasets/completion_dataset.py new file mode 100644 index 0000000..0ff7d75 --- /dev/null +++ b/saic_depth_completion/data/datasets/completion_dataset.py @@ -0,0 +1,81 @@ +import numpy as np +from skimage.filters import gaussian +import torch + + +def rgb2gray(rgb): + return np.dot(rgb[..., :3], [0.2989, 0.5870, 0.1140]) + + +def create_holes_mask(layer, granularity, percentile): + gaussian_layer = np.random.uniform(size=layer.shape[1:]) + gaussian_layer = gaussian(gaussian_layer, sigma=granularity) + threshold = np.percentile(gaussian_layer.reshape([-1]), 100 * (1 - percentile)) + gaussian_layer = torch.tensor(gaussian_layer).unsqueeze(0) + return gaussian_layer > threshold + + +def spatter(layer, mask, granularity=10, percentile=0.4): + holes_mask = create_holes_mask(layer, granularity, percentile) + + res = layer.clone() + mask = mask.clone() + res[holes_mask] = 0 + mask[holes_mask] = 0 + return res, mask + + +def deform(layer, mask, granularity=10, percentile=0.02): + holes_mask = create_holes_mask(layer, granularity, percentile) + + res = layer.clone() + mask = mask.clone() + v = res[(res > 1.0e-4) & holes_mask].mean() * 2 ** (np.random.uniform() * 2.0 - 1.0) + res[(res > 1.0e-4) & holes_mask] = v + mask[(res > 1.0e-4) & holes_mask] = 0 + + return res, mask + + +class CompletionDataset: + def __init__(self, + ds, + threshold=True, + granularity=8, + percentile_void=0.3, + percentile_deform=0.02): + self.ds = ds + self.threshold = threshold + self.granularity = granularity + self.percentile_deform = percentile_deform + self.percentile_void = percentile_void + + def __len__(self): + return len(self.ds) + + def __getitem__(self, index): + res = self.ds[index] + np.random.seed(index) + + if 'raw_depth' not in res: + res['raw_depth'] = res['depth'].clone() + res['raw_depth_mask'] = res['mask'].clone() + + if self.threshold: + maxd = res['raw_depth'][res['raw_depth_mask']].max() + mind = res['raw_depth'][res['raw_depth_mask']].min() + threshold = np.random.uniform() * (maxd - mind) + mind + mask = (res['raw_depth'] > threshold) + res['raw_depth'][mask] = 0 
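+        # The two augmentations below synthesize raw-sensor artifacts on the
+        # input depth: "deform" overwrites a smooth random blob of valid
+        # pixels with a constant value (erroneous measurements), while
+        # "spatter" zeroes a blob out (missing measurements); both mark the
+        # affected pixels as invalid in the mask.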
+ + res['raw_depth'], res['raw_depth_mask'] = deform(res['raw_depth'], + res['raw_depth_mask'], + granularity=self.granularity, + percentile=self.percentile_deform) + res['raw_depth'], res['raw_depth_mask'] = spatter(res['raw_depth'], + res['raw_depth_mask'], + granularity=self.granularity, + percentile=self.percentile_void) + res['gt_depth'] = res.pop('depth') + res['color'] = res.pop('image') + return res \ No newline at end of file diff --git a/saic_depth_completion/data/datasets/matterport.py b/saic_depth_completion/data/datasets/matterport.py new file mode 100644 index 0000000..99191c6 --- /dev/null +++ b/saic_depth_completion/data/datasets/matterport.py @@ -0,0 +1,96 @@ +import os +import torch + +import numpy as np +from PIL import Image + +# ROOT = '/Vol1/dbstore/datasets/depth_completion/Matterport3D/' +ROOT = "/Vol0/user/d.senushkin/datasets/matterport3d" + +class Matterport: + def __init__( + self, root=ROOT, split="train", transforms=None + ): + self.transforms = transforms + self.data_root = os.path.join(root, "data") + self.split_file = os.path.join(root, "splits", split + ".txt") + self.data_list = self._get_data_list(self.split_file) + self.color_name, self.depth_name, self.render_name = [], [], [] + self.normal_name = [] + + self._load_data() + + def _load_data(self): + for x in os.listdir(self.data_root): + scene = os.path.join(self.data_root, x) + raw_depth_scene = os.path.join(scene, 'undistorted_depth_images') + render_depth_scene = os.path.join(scene, 'render_depth') + + for y in os.listdir(raw_depth_scene): + valid, resize_count, one_scene_name, num_1, num_2, png = self._split_matterport_path(y) + if valid == False or png != 'png' or resize_count != 1: + continue + data_id = (x, one_scene_name, num_1, num_2) + if data_id not in self.data_list: + continue + raw_depth_f = os.path.join(raw_depth_scene, y) + render_depth_f = os.path.join(render_depth_scene, y.split('.')[0] + '_mesh_depth.png') + color_f = os.path.join( + scene,'undistorted_color_images', f'resize_{one_scene_name}_i{num_1}_{num_2}.jpg' + ) + est_normal_f = os.path.join( + scene, 'estimate_normal', f'resize_{one_scene_name}_d{num_1}_{num_2}_normal_est.png' + ) + + + self.depth_name.append(raw_depth_f) + self.render_name.append(render_depth_f) + self.color_name.append(color_f) + self.normal_name.append(est_normal_f) + + def _get_data_list(self, filename): + with open(filename, 'r') as f: + content = f.read().splitlines() + data_list = [] + for ele in content: + left, _, right = ele.split('/') + valid, resize_count, one_scene_name, num_1, num_2, png = self._split_matterport_path(right) + if valid == False: + print(f'Invalid data_id in datalist: {ele}') + data_list.append((left, one_scene_name, num_1, num_2)) + return set(data_list) + + def _split_matterport_path(self, path): + try: + left, png = path.split('.') + lefts = left.split('_') + resize_count = left.count('resize') + one_scene_name = lefts[resize_count] + num_1 = lefts[resize_count+1][-1] + num_2 = lefts[resize_count+2] + return True, resize_count, one_scene_name, num_1, num_2, png + except Exception as e: + print(e) + return False, None, None, None, None, None + + def __len__(self): + return len(self.depth_name) + + def __getitem__(self, index): + color = np.array(Image.open(self.color_name[index])).transpose([2, 0, 1]) / 255. + render_depth = np.array(Image.open(self.render_name[index])) / 4000. + depth = np.array(Image.open(self.depth_name[index])) / 4000. 
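+        # Matterport3D depth PNGs are uint16 with 0.25 mm per unit, so the
+        # division by 4000 above converts both depth maps to meters.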
+ + normals = np.array(Image.open(self.normal_name[index])).transpose([2, 0, 1]) + normals = (normals - 90.) / 180. + + mask = np.zeros_like(depth) + mask[np.where(depth > 0)] = 1 + + return { + 'color': torch.tensor(color, dtype=torch.float32), + 'raw_depth': torch.tensor(depth, dtype=torch.float32).unsqueeze(0), + 'mask': torch.tensor(mask, dtype=torch.float32).unsqueeze(0), + 'normals': torch.tensor(normals, dtype=torch.float32).unsqueeze(0), + 'gt_depth': torch.tensor(render_depth, dtype=torch.float32).unsqueeze(0), + } \ No newline at end of file diff --git a/saic_depth_completion/data/datasets/nyu_raw.py b/saic_depth_completion/data/datasets/nyu_raw.py new file mode 100644 index 0000000..65b3645 --- /dev/null +++ b/saic_depth_completion/data/datasets/nyu_raw.py @@ -0,0 +1,158 @@ +import os +import torch + +import numpy as np +from PIL import Image +import random + +ROOT = '/Vol1/dbstore/datasets/depth_completion/NYUv2_raw' + +class NYURaw(): + def __init__(self, split, dt=0.01, valid_split=0.05, + focal=None, image_aug=None, + depth_aug=None, geometry_aug=None, + n_scenes=None): + + super().__init__() + + self.split = split + + self.fx = 5.1885790117450188e+02 + self.fy = 5.1946961112127485e+02 + self.cx = 3.2558244941119034e+02 + self.cy = 2.5373616633400465e+02 + + self.crop = 8 + + self.focal = focal + + self.image_aug = image_aug + self.depth_aug = depth_aug + self.geometry_aug = geometry_aug + + self.train_path = os.path.join(ROOT, "train") + self.test_path = os.path.join(ROOT, "test") + + with open(os.path.join(self.train_path, 'time_diffs.pth'), 'rb') as fin: + train_time_diffs = torch.load(fin) + + with open(os.path.join(self.test_path, 'time_diffs.pth'), 'rb') as fin: + test_time_diffs = torch.load(fin) + + train_keys = [] + for key in train_time_diffs: + time_lapse = train_time_diffs[key] + if abs(time_lapse) < dt: + train_keys.append(key) + + test_keys = [] + for key in test_time_diffs: + time_lapse = test_time_diffs[key] + if abs(time_lapse) < dt: + test_keys.append(key) + + if n_scenes is not None: + new_train_keys = [] + for key in train_keys: + if int(key.split('_')[0]) < n_scenes: + new_train_keys.append(key) + train_keys = new_train_keys + + random.seed(0) + random.shuffle(train_keys) + n_valid = int(valid_split * len(train_keys)) + + valid_keys = train_keys[:n_valid] + train_keys = train_keys[n_valid:] + + random.shuffle(train_keys) + random.shuffle(valid_keys) + random.shuffle(test_keys) + + self.keys = { + 'train': train_keys, + 'val': valid_keys, + 'test': test_keys} + + def __len__(self): + return len(self.keys[self.split]) + + def __getitem__(self, index): + key = self.keys[self.split][index] + + if self.split == 'test': + path = self.test_path + else: + path = self.train_path + + image_path = os.path.join(path, 'images', key) + depth_path = os.path.join(path, 'depths', key) + + image = torch.load(image_path)['image'] + depth = torch.load(depth_path)['depth'].astype('float32') / (2.0 ** 16 - 1.0) * 10.0 + + image = image[self.crop:-self.crop, self.crop:-self.crop] + depth = depth[self.crop:-self.crop, self.crop:-self.crop] + + depth = np.expand_dims(depth, -1) + + if self.image_aug is not None: + if self.split in self.image_aug: + image = self.image_aug[self.split](image=image)['image'] + + + # borders are invalid pixels by default + mask = (depth > 1.0e-4) & (depth < 10.0 - 1.0e-4) + + mask[:50, :] = 0 + mask[:, :40] = 0 + mask[-10:, :] = 0 + mask[:, -40:] = 0 + + mask = mask.astype('float32') + center = torch.tensor([self.cx, self.cy]) + focal = 
torch.tensor([self.fx, self.fy]) + + if self.focal is not None: + scale = self.focal / self.fx + + interpolation = lambda x: torch.nn.functional.interpolate( + torch.tensor(x).permute(2, 0, 1).unsqueeze(0), + scale_factor=scale, + mode='bilinear', + align_corners=True)[0].permute(1, 2, 0).numpy() + + image = interpolation(image) + depth = interpolation(depth) + mask = interpolation(mask) + center = center * scale + focal = focal * scale + + + + if self.geometry_aug is not None: + if self.split in self.geometry_aug: + res = self.geometry_aug[self.split](image=image, depth=depth, mask=mask, keypoints=[center]) + image = res['image'] + depth = res['depth'] + mask = res['mask'] + center = res['keypoints'][0] + + if self.depth_aug is not None: + if self.split in self.depth_aug: + res = self.depth_aug[self.split](image=depth, mask=mask) + depth = res['image'] + mask = res['mask'] + + mask = (mask > 1.0 - 1.0e-4) + + sample = { + 'image': torch.tensor(image).permute(2, 0, 1).float(), + 'depth': torch.tensor(depth).permute(2, 0, 1).float(), + 'mask': torch.tensor(mask).permute(2, 0, 1).bool(), +# 'type': torch.tensor(0), # 0 stands for absolute depth +# 'focal': focal.float(), +# 'center': center.float() + } + + return sample diff --git a/saic_depth_completion/data/datasets/nyuv2_test.py b/saic_depth_completion/data/datasets/nyuv2_test.py new file mode 100644 index 0000000..ce16d74 --- /dev/null +++ b/saic_depth_completion/data/datasets/nyuv2_test.py @@ -0,0 +1,94 @@ +import os +import torch +from torch.nn import functional as F +import numpy as np +from PIL import Image + +import cv2 + +# ROOT = '/Vol1/dbstore/datasets/depth_completion/Matterport3D/' +ROOT = "/Vol0/user/d.senushkin/datasets/nyuv2_test" + +class NyuV2Test: + def __init__( + self, root=ROOT, split="1gr10pv1pd", transforms=None + ): + self.transforms = transforms + self.data_root = os.path.join(root, "data", "DC_dataset_NUYV2_"+split) + if "official" in split: + file = "official_test.txt" + else: + file = "test.txt" + + self.split_file = os.path.join(root, file) + self.data_list = self._get_data_list(self.split_file) + self.color_name, self.depth_name, self.render_name = [], [], [] + + self._load_data() + + def _load_data(self): + for x in os.listdir(self.data_root): + scene = os.path.join(self.data_root, x) + raw_depth_scene = os.path.join(scene, 'undistorted_depth_images') + render_depth_scene = os.path.join(scene, 'render_depth') + + for y in os.listdir(raw_depth_scene): + valid, resize_count, one_scene_name, num_1, num_2, png = self._split_matterport_path(y) + if valid == False or png != 'png' or resize_count != 1: + continue + data_id = (x, one_scene_name, num_1, num_2) + if data_id not in self.data_list: + continue + raw_depth_f = os.path.join(raw_depth_scene, y) + render_depth_f = os.path.join(render_depth_scene, y.split('.')[0] + '_mesh_depth.png') + color_f = os.path.join( + scene,'undistorted_color_images', f'resize_{one_scene_name}_i{num_1}_{num_2}.jpg' + ) + + + self.depth_name.append(raw_depth_f) + self.render_name.append(render_depth_f) + self.color_name.append(color_f) + + def _get_data_list(self, filename): + with open(filename, 'r') as f: + content = f.read().splitlines() + data_list = [] + for ele in content: + left, _, right = ele.split('/') + valid, resize_count, one_scene_name, num_1, num_2, png = self._split_matterport_path(right) + if valid == False: + print(f'Invalid data_id in datalist: {ele}') + data_list.append((left, one_scene_name, num_1, num_2)) + return set(data_list) + + def 
_split_matterport_path(self, path):
+        try:
+            left, png = path.split('.')
+            lefts = left.split('_')
+            resize_count = left.count('resize')
+            one_scene_name = lefts[resize_count]
+            num_1 = lefts[resize_count+1][-1]
+            num_2 = lefts[resize_count+2]
+            return True, resize_count, one_scene_name, num_1, num_2, png
+        except Exception as e:
+            print(e)
+            return False, None, None, None, None, None
+
+    def __len__(self):
+        return len(self.depth_name)
+
+    def __getitem__(self, index):
+        color = np.array(Image.open(self.color_name[index])).transpose([2, 0, 1]) / 255.
+        render_depth = np.array(Image.open(self.render_name[index])) / 4000.
+        depth = np.array(Image.open(self.depth_name[index])) / 4000.
+
+        mask = np.zeros_like(depth)
+        mask[np.where(depth > 0)] = 1
+
+        return {
+            'color': torch.tensor(color, dtype=torch.float32),
+            'raw_depth': torch.tensor(depth, dtype=torch.float32).unsqueeze(0),
+            'mask': torch.tensor(mask, dtype=torch.float32).unsqueeze(0),
+            'gt_depth': torch.tensor(render_depth, dtype=torch.float32).unsqueeze(0),
+        }
\ No newline at end of file
diff --git a/saic_depth_completion/engine/__init__.py b/saic_depth_completion/engine/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/saic_depth_completion/engine/inference.py b/saic_depth_completion/engine/inference.py
new file mode 100644
index 0000000..a81e51d
--- /dev/null
+++ b/saic_depth_completion/engine/inference.py
@@ -0,0 +1,51 @@
+import os
+import time
+import datetime
+import torch
+from tqdm import tqdm
+
+import matplotlib.pyplot as plt
+
+from saic_depth_completion.utils.meter import AggregatedMeter
+from saic_depth_completion.utils.meter import Statistics as LossMeter
+from saic_depth_completion.utils import visualize
+
+
+def inference(
+        model, test_loaders, metrics, save_dir="", logger=None
+):
+
+    model.eval()
+    metrics_meter = AggregatedMeter(metrics, maxlen=20)
+    for subset, loader in test_loaders.items():
+        idx = 0
+        logger.info(
+            "Inference: subset -- {}. Total number of batches: {}.".format(subset, len(loader))
+        )
+
+        metrics_meter.reset()
+        # loop over dataset
+        for batch in tqdm(loader):
+            batch = model.preprocess(batch)
+            # run the forward pass under no_grad as well: inference never needs
+            # autograd state, and this keeps activation memory bounded
+            with torch.no_grad():
+                pred = model(batch)
+                post_pred = model.postprocess(pred)
+                if save_dir:
+                    B = batch["color"].shape[0]
+                    for it in range(B):
+                        fig = visualize.figure(
+                            batch["color"][it], batch["raw_depth"][it],
+                            batch["mask"][it], batch["gt_depth"][it],
+                            post_pred[it], close=True
+                        )
+                        fig.savefig(
+                            os.path.join(save_dir, "result_{}.png".format(idx)), dpi=fig.dpi
+                        )
+
+                        idx += 1
+
+                metrics_meter.update(post_pred, batch["gt_depth"])
+
+        state = "Inference: subset -- {} | ".format(subset)
+        logger.info(state + metrics_meter.suffix)
\ No newline at end of file
diff --git a/saic_depth_completion/engine/train.py b/saic_depth_completion/engine/train.py
new file mode 100644
index 0000000..638d834
--- /dev/null
+++ b/saic_depth_completion/engine/train.py
@@ -0,0 +1,84 @@
+import time
+import datetime
+import torch
+
+from saic_depth_completion.utils.meter import AggregatedMeter
+from saic_depth_completion.utils.meter import Statistics as LossMeter
+from saic_depth_completion.engine.val import validate
+
+def train(
+        model, trainloader, optimizer, val_loaders={}, scheduler=None, snapshoter=None, logger=None,
+        epochs=100, init_epoch=0, logging_period=10, metrics={}, tensorboard=None, tracker=None
+):
+
+    # move model to train mode
+    model.train()
+    logger.info(
+        "Total number of params: {}".format(model.count_parameters())
+    )
+    loss_meter = LossMeter(maxlen=20)
+    metrics_meter = AggregatedMeter(metrics, maxlen=20)
+    logger.info(
+        "Start training at epoch {}. Total number of epochs: {}.".format(init_epoch, epochs)
+    )
+
+    num_batches = len(trainloader)
+
+    start_time_stamp = time.time()
+    for epoch in range(init_epoch, epochs):
+        loss_meter.reset()
+        metrics_meter.reset()
+        # loop over dataset
+        for it, batch in enumerate(trainloader):
+            batch = model.preprocess(batch)
+            pred = model(batch)
+            loss = model.criterion(pred, batch["gt_depth"])
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            loss_meter.update(loss.item(), 1)
+
+            if scheduler is not None:
+                scheduler.step()
+
+            with torch.no_grad():
+                post_pred = model.postprocess(pred)
+                metrics_meter.update(post_pred, batch["gt_depth"])
+
+            if (epoch * num_batches + it) % logging_period == 0:
+                state = "ep: {}, it {}/{} -- loss {:.4f}({:.4f}) | ".format(
+                    epoch, it, num_batches, loss_meter.median, loss_meter.global_avg
+                )
+                logger.info(state + metrics_meter.suffix)
+
+        state = "ep: {}, it {}/{} -- loss {:.4f}({:.4f}) | ".format(
+            epoch, it, num_batches, loss_meter.median, loss_meter.global_avg
+        )
+
+        logger.info(state + metrics_meter.suffix)
+
+        if tensorboard is not None:
+            tensorboard.update(
+                {k: v.global_avg for k, v in metrics_meter.meters.items()}, tag="train", epoch=epoch
+            )
+            tensorboard.add_figures(batch, post_pred, epoch=epoch)
+
+        if snapshoter is not None and epoch % snapshoter.period == 0:
+            snapshoter.save('snapshot_{}'.format(epoch))
+
+        # run validation at the end of each epoch
+        validate(
+            model, val_loaders, metrics, epoch=epoch, logger=logger,
+            tensorboard=tensorboard, tracker=tracker
+        )
+
+    if snapshoter is not None:
+        snapshoter.save('snapshot_final')
+
+    total_time = str(datetime.timedelta(seconds=time.time() - start_time_stamp))
+
+    logger.info(
+        "Training finished! Total spent time: {}.".format(total_time)
+    )
\ No newline at end of file
diff --git a/saic_depth_completion/engine/val.py b/saic_depth_completion/engine/val.py
new file mode 100644
index 0000000..c525290
--- /dev/null
+++ b/saic_depth_completion/engine/val.py
@@ -0,0 +1,41 @@
+import time
+import datetime
+import torch
+from tqdm import tqdm
+
+from saic_depth_completion.utils.meter import AggregatedMeter
+from saic_depth_completion.utils.meter import Statistics as LossMeter
+
+
+def validate(
+        model, val_loaders, metrics, epoch=0, logger=None, tensorboard=None, tracker=None
+):
+
+    model.eval()
+    metrics_meter = AggregatedMeter(metrics, maxlen=20)
+    for subset, loader in val_loaders.items():
+        logger.info(
+            "Validate: ep: {}, subset -- {}. Total number of batches: {}.".format(epoch, subset, len(loader))
+        )
+
+        metrics_meter.reset()
+        # loop over dataset
+        for batch in tqdm(loader):
+            batch = model.preprocess(batch)
+            # as in inference, keep the forward pass inside no_grad
+            with torch.no_grad():
+                pred = model(batch)
+                post_pred = model.postprocess(pred)
+                metrics_meter.update(post_pred, batch["gt_depth"])
+
+        state = "Validate: ep: {}, subset -- {} | ".format(epoch, subset)
+        logger.info(state + metrics_meter.suffix)
+
+        metric_state = {k: v.global_avg for k, v in metrics_meter.meters.items()}
+
+        if tensorboard is not None:
+            tensorboard.update(metric_state, tag=subset, epoch=epoch)
+            tensorboard.add_figures(batch, post_pred, tag=subset, epoch=epoch)
+
+        if tracker is not None:
+            tracker.update(subset, metric_state)
\ No newline at end of file
diff --git a/saic_depth_completion/metrics/__init__.py b/saic_depth_completion/metrics/__init__.py
new file mode 100644
index 0000000..ba1d1ab
--- /dev/null
+++ b/saic_depth_completion/metrics/__init__.py
@@ -0,0 +1,15 @@
+from .relative import *
+from .absolute import *
+
+from saic_depth_completion.utils.registry import Registry
+
+LOSSES = Registry()
+
+LOSSES["DepthL2Loss"] = DepthL2Loss
+# LOSSES["DepthLogL2Loss"] = DepthLogL2Loss
+LOSSES["LogDepthL1Loss"] = LogDepthL1Loss
+LOSSES["DepthL1Loss"] = DepthL1Loss
+# LOSSES["DepthLogL1Loss"] = DepthLogL1Loss
+LOSSES["SSIM"] = SSIM
+LOSSES["BerHuLoss"] = BerHuLoss
+
diff --git a/saic_depth_completion/metrics/absolute.py b/saic_depth_completion/metrics/absolute.py
new file mode 100644
index 0000000..42905fb
--- /dev/null
+++ b/saic_depth_completion/metrics/absolute.py
@@ -0,0 +1,70 @@
+import torch
+from torch import nn
+
+###### LOSSES #######
+
+class BerHuLoss(nn.Module):
+    def __init__(self, scale=0.5, eps=1e-5):
+        super(BerHuLoss, self).__init__()
+        self.scale = scale
+        self.eps = eps
+
+    def forward(self, pred, gt):
+        img1 = torch.zeros_like(pred)
+        img2 = torch.zeros_like(gt)
+
+        img1 = img1.copy_(pred)
+        img2 = img2.copy_(gt)
+
+        img1 = img1[img2 > self.eps]
+        img2 = img2[img2 > self.eps]
+
+        # BerHu: L1 below the cutoff c = scale * max|pred - gt|, quadratic
+        # (d^2 + c^2) / (2c) above it
+        diff = torch.abs(img1 - img2)
+        threshold = self.scale * torch.max(diff).detach()
+        mask = diff > threshold
+        diff[mask] = ((img1[mask]-img2[mask])**2 + threshold**2) / (2*threshold + self.eps)
+        return diff.sum() / diff.numel()
+
+
+class LogDepthL1Loss(nn.Module):
+    def __init__(self, eps=1e-5):
+        super(LogDepthL1Loss, self).__init__()
+        self.eps = eps
+
+    def forward(self, pred, gt):
+        mask = gt > self.eps
+        diff = torch.abs(torch.log(gt[mask]) - pred[mask])
+        return diff.mean()
+
+###### METRICS #######
+
+class DepthL1Loss(nn.Module):
+    def __init__(self, eps=1e-5):
+        super(DepthL1Loss, self).__init__()
+        self.eps = eps
+
+    def forward(self, pred, gt):
+        img1 = torch.zeros_like(pred)
+        img2 = torch.zeros_like(gt)
+
+        img1 = img1.copy_(pred)
+        img2 = img2.copy_(gt)
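+
+        # the in-place masking below runs on these copies, so the caller's
+        # tensors are left untouched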
+        mask = gt > self.eps
+        img1[~mask] = 0.
+        img2[~mask] = 0.
+        return nn.L1Loss(reduction="sum")(img1, img2), pred.numel()
+
+class DepthL2Loss(nn.Module):
+    def __init__(self, eps=1e-5):
+        super(DepthL2Loss, self).__init__()
+        self.eps = eps
+
+    def forward(self, pred, gt):
+        img1 = torch.zeros_like(pred)
+        img2 = torch.zeros_like(gt)
+
+        img1 = img1.copy_(pred)
+        img2 = img2.copy_(gt)
+
+        mask = gt > self.eps
+        img1[~mask] = 0.
+        img2[~mask] = 0.
+        return nn.MSELoss(reduction="sum")(img1, img2), pred.numel()
diff --git a/saic_depth_completion/metrics/relative.py b/saic_depth_completion/metrics/relative.py
new file mode 100644
index 0000000..489850d
--- /dev/null
+++ b/saic_depth_completion/metrics/relative.py
@@ -0,0 +1,105 @@
+from functools import partial
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+from math import exp
+
+###### METRICS #######
+
+class DepthRel(nn.Module):
+    def __init__(self, eps=1e-5):
+        super(DepthRel, self).__init__()
+        self.eps = eps
+
+    def forward(self, pred, gt):
+        mask = gt > self.eps
+        diff = torch.abs(gt[mask] - pred[mask]) / gt[mask]
+        return diff.median()
+
+class Miss(nn.Module):
+    def __init__(self, thresh, eps=1e-5):
+        super(Miss, self).__init__()
+        self.thresh = thresh
+        self.eps = eps
+
+    def forward(self, pred, gt):
+        mask = gt > self.eps
+
+        pred_over_gt, gt_over_pred = pred[mask] / gt[mask], gt[mask] / pred[mask]
+        miss_map = torch.max(pred_over_gt, gt_over_pred)
+        # delta-style threshold metric: count the pixels whose ratio error is
+        # below the threshold, together with the number of valid pixels
+        hit_rate = torch.sum(miss_map < self.thresh).float()
+
+        return hit_rate, miss_map.numel()
+
+
+class SSIM(torch.nn.Module):
+    def __init__(self, window_size=11, size_average=True, eps=1e-5):
+        super(SSIM, self).__init__()
+        self.eps = eps
+        self.window_size = window_size
+        self.size_average = size_average
+        self.channel = 1
+        self.window = self.create_window(window_size, self.channel)
+
+    def gaussian(self, window_size, sigma):
+        gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
+        return gauss / gauss.sum()
+
+    def create_window(self, window_size, channel):
+        _1D_window = self.gaussian(window_size, 1.5).unsqueeze(1)
+        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+        window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
+        return window
+
+    def _ssim(self, img1, img2, window, window_size, channel, size_average=True):
+        mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
+        mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
+
+        mu1_sq = mu1.pow(2)
+        mu2_sq = mu2.pow(2)
+        mu1_mu2 = mu1 * mu2
+
+        sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
+        sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
+        sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
+
+        C1 = 0.01 ** 2
+        C2 = 0.03 ** 2
+
+        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+
+        if size_average:
+            return ssim_map.mean()
+        else:
+            return ssim_map.mean(1).mean(1).mean(1)
+
+    def forward(self, pred, gt):
+
+        img1 = torch.zeros_like(pred)
+        img2 = torch.zeros_like(gt)
+
+        img1 = img1.copy_(pred)
+        img2 = img2.copy_(gt)
+
+        img2[img2 < self.eps] = 0
+        img1[img2 < self.eps] = 0
+
+        (_, channel, _, _) = img1.size()
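+
+        # the Gaussian window is cached on the module: it is only rebuilt below
+        # when the channel count or dtype of the input changes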
+        if channel == self.channel and self.window.data.type() == img1.data.type():
+            window = self.window
+        else:
+            window = self.create_window(self.window_size, channel)
+
+            if img1.is_cuda:
+                window = window.cuda(img1.get_device())
+            window = window.type_as(img1)
+
+            self.window = window
+            self.channel = channel
+
+        return self._ssim(img1, img2, window, self.window_size, channel, self.size_average)
diff --git a/saic_depth_completion/modeling/__init__.py b/saic_depth_completion/modeling/__init__.py
new file mode 100644
index 0000000..ce576a2
--- /dev/null
+++ b/saic_depth_completion/modeling/__init__.py
@@ -0,0 +1,3 @@
+from .lrn import LRN
+from .dm_lrn import DM_LRN
+
diff --git a/saic_depth_completion/modeling/backbone/__init__.py b/saic_depth_completion/modeling/backbone/__init__.py
new file mode 100644
index 0000000..f6c421d
--- /dev/null
+++ b/saic_depth_completion/modeling/backbone/__init__.py
@@ -0,0 +1 @@
+from saic_depth_completion.modeling.backbone.build import build_backbone
\ No newline at end of file
diff --git a/saic_depth_completion/modeling/backbone/build.py b/saic_depth_completion/modeling/backbone/build.py
new file mode 100644
index 0000000..b51662c
--- /dev/null
+++ b/saic_depth_completion/modeling/backbone/build.py
@@ -0,0 +1,57 @@
+from saic_depth_completion.modeling.backbone.resnet import ResNet
+from saic_depth_completion.modeling.backbone.hrnet import HRNet
+from saic_depth_completion.modeling.backbone.efficientnet import EfficientNet
+from saic_depth_completion.utils import registry
+from saic_depth_completion.utils.model_zoo import (_load_state_dict_hrnet,
+                                                   _load_state_dict_resnet,
+                                                   _load_state_dict_efficientnet)
+
+
+@registry.BACKBONES.register("resnet18")
+@registry.BACKBONES.register("resnet34")
+@registry.BACKBONES.register("resnet50")
+@registry.BACKBONES.register("resnet101")
+@registry.BACKBONES.register("resnet152")
+def build_resnet(cfg):
+    resnet = ResNet(cfg)
+    if cfg.imagenet is True:
+        state_dict = _load_state_dict_resnet(cfg.arch)
+        resnet.load_state_dict(state_dict, strict=False)
+    return resnet
+
+@registry.BACKBONES.register("hrnet_w18")
+@registry.BACKBONES.register("hrnet_w18_small_v1")
+@registry.BACKBONES.register("hrnet_w18_small_v2")
+@registry.BACKBONES.register("hrnet_w30")
+@registry.BACKBONES.register("hrnet_w32")
+@registry.BACKBONES.register("hrnet_w40")
+@registry.BACKBONES.register("hrnet_w44")
+@registry.BACKBONES.register("hrnet_w48")
+@registry.BACKBONES.register("hrnet_w64")
+def build_hrnet(cfg):
+    hrnet = HRNet(cfg)
+    if cfg.imagenet is True:
+        state_dict = _load_state_dict_hrnet(cfg.arch)
+        hrnet.load_state_dict(state_dict, strict=False)
+    return hrnet
+
+
+@registry.BACKBONES.register("efficientnet-b0")
+@registry.BACKBONES.register("efficientnet-b1")
+@registry.BACKBONES.register("efficientnet-b2")
+@registry.BACKBONES.register("efficientnet-b3")
+@registry.BACKBONES.register("efficientnet-b4")
+@registry.BACKBONES.register("efficientnet-b5")
+@registry.BACKBONES.register("efficientnet-b6")
+@registry.BACKBONES.register("efficientnet-b7")
+def build_efficientnet(cfg):
+    efficientnet = EfficientNet(cfg)
+    if cfg.imagenet is True:
+        state_dict = _load_state_dict_efficientnet(cfg.arch)
+        efficientnet.load_state_dict(state_dict, strict=False)
+    return efficientnet
+
+
+def build_backbone(cfg):
+    return registry.BACKBONES[cfg.arch](cfg)
+
diff --git a/saic_depth_completion/modeling/backbone/efficientnet.py b/saic_depth_completion/modeling/backbone/efficientnet.py
new file mode 100644
index 0000000..80d5c2a
--- /dev/null
+++ b/saic_depth_completion/modeling/backbone/efficientnet.py
@@ -0,0 +1,78 @@
+import sys
+import torch.nn as nn
+from efficientnet_pytorch import EfficientNet as _EfficientNet
+from efficientnet_pytorch.utils import url_map, get_model_params
+
+
+from collections import namedtuple
+
+StageSpec = namedtuple("StageSpec", ["num_channels", "stage_stamp"],)
+
+efficientnet_b0 = tuple(StageSpec(num_channels=nc, stage_stamp=ss)
+    for (nc, ss) in ((24, 3), (40, 4), (112, 9), (320, 16))
+)
+efficientnet_b1 = tuple(StageSpec(num_channels=nc, stage_stamp=ss)
+    for (nc, ss) in ((24, 5), (40, 8), (112, 16), (320, 23))
+)
+efficientnet_b2 = tuple(StageSpec(num_channels=nc, stage_stamp=ss)
+    for (nc, ss) in ((24, 5), (48, 8), (120, 16), (352, 23))
+)
+efficientnet_b3 = tuple(StageSpec(num_channels=nc, stage_stamp=ss)
+    for (nc, ss) in ((32, 5), (48, 8), (136, 18), (384, 26))
+)
+efficientnet_b4 = tuple(StageSpec(num_channels=nc, stage_stamp=ss)
+    for (nc, ss) in ((32, 6), (56, 10), (160, 22), (448, 32))
+)
+efficientnet_b5 = tuple(StageSpec(num_channels=nc, stage_stamp=ss)
+    for (nc, ss) in ((40, 8), (64, 13), (176, 27), (512, 39))
+)
+efficientnet_b6 = tuple(StageSpec(num_channels=nc, stage_stamp=ss)
+    for (nc, ss) in ((40, 9), (72, 15), (200, 31), (576, 45))
+)
+efficientnet_b7 = tuple(StageSpec(num_channels=nc, stage_stamp=ss)
+    for (nc, ss) in ((48, 11), (80, 18), (224, 38), (640, 55))
+)
+
+class EfficientNet(_EfficientNet):
+    def __init__(self, model_cfg):
+
+        blocks_args, global_params = get_model_params(model_cfg.arch, dict(image_size=None))
+        super().__init__(blocks_args, global_params)
+
+        self.multi_scale_output = model_cfg.multi_scale_output
+        self.stage_specs = sys.modules[__name__].__getattribute__(model_cfg.arch.replace("-", "_"))
+        self.num_blocks = len(self._blocks)
+
+        del self._fc, self._conv_head, self._bn1, self._avg_pooling, self._dropout
+
+    @property
+    def feature_channels(self):
+        if self.multi_scale_output:
+            return tuple([x.num_channels for x in self.stage_specs])
+        return self.stage_specs[-1].num_channels
+
+
+    def forward(self, x):
+
+        x = self._swish(self._bn0(self._conv_stem(x)))
+
+        block_idx = 0.
+        features = []
+        for stage in [
+            self._blocks[:self.stage_specs[0].stage_stamp],
+            self._blocks[self.stage_specs[0].stage_stamp:self.stage_specs[1].stage_stamp],
+            self._blocks[self.stage_specs[1].stage_stamp:self.stage_specs[2].stage_stamp],
+            self._blocks[self.stage_specs[2].stage_stamp:],
+        ]:
+            for block in stage:
+                x = block(
+                    x, self._global_params.drop_connect_rate * block_idx / self.num_blocks
+                )
+                block_idx += 1.
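+                # the drop-connect rate passed to each block above ramps up
+                # linearly with depth (block_idx / num_blocks), as in the
+                # reference EfficientNet implementation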
+
+
+            features.append(x)
+
+        if self.multi_scale_output:
+            return tuple(features)
+        return tuple([x])
diff --git a/saic_depth_completion/modeling/backbone/hrnet.py b/saic_depth_completion/modeling/backbone/hrnet.py
new file mode 100644
index 0000000..a4e17ef
--- /dev/null
+++ b/saic_depth_completion/modeling/backbone/hrnet.py
@@ -0,0 +1,423 @@
+import sys
+
+import torch
+from torch import nn
+from torch.nn import Conv2d
+from collections import namedtuple
+
+from saic_depth_completion.modeling.backbone import res_blocks
+from saic_depth_completion import ops
+
+StageSpec = namedtuple(
+    "StageSpec",
+    [
+        "num_channels",   # per-branch output channels (tuple)
+        "num_blocks",     # residual blocks per branch (tuple)
+        "num_modules",    # HighResolutionModules stacked in the stage
+        "num_branches",   # parallel branches in the stage
+        "block"           # residual block type
+    ],
+)
+
+hrnet_w18 = tuple(
+    StageSpec(num_channels=nc, num_blocks=nbl, num_modules=nm, num_branches=nbr, block=b)
+    for (nc, nbl, nm, nbr, b) in (
+        ((64), (4), 1, 1, "Bottleneck"),
+        ((18, 36), (4, 4), 1, 2, "BasicBlock"),
+        ((18, 36, 72), (4, 4, 4), 4, 3, "BasicBlock"),
+        ((18, 36, 72, 144), (4, 4, 4, 4), 3, 4, "BasicBlock"),
+    )
+)
+
+hrnet_w18_small_v1 = tuple(
+    StageSpec(num_channels=nc, num_blocks=nbl, num_modules=nm, num_branches=nbr, block=b)
+    for (nc, nbl, nm, nbr, b) in (
+        ((32), (1), 1, 1, "Bottleneck"),
+        ((16, 32), (2, 2), 1, 2, "BasicBlock"),
+        ((16, 32, 64), (2, 2, 2), 1, 3, "BasicBlock"),
+        ((16, 32, 64, 128), (2, 2, 2, 2), 1, 4, "BasicBlock"),
+    )
+)
+hrnet_w18_small_v2 = tuple(
+    StageSpec(num_channels=nc, num_blocks=nbl, num_modules=nm, num_branches=nbr, block=b)
+    for (nc, nbl, nm, nbr, b) in (
+        ((64), (2), 1, 1, "Bottleneck"),
+        ((18, 36), (2, 2), 1, 2, "BasicBlock"),
+        ((18, 36, 72), (2, 2, 2), 3, 3, "BasicBlock"),
+        ((18, 36, 72, 144), (2, 2, 2, 2), 2, 4, "BasicBlock"),
+    )
+)
+
+hrnet_w30 = tuple(
+    StageSpec(num_channels=nc, num_blocks=nbl, num_modules=nm, num_branches=nbr, block=b)
+    for (nc, nbl, nm, nbr, b) in (
+        ((64), (4), 1, 1, "Bottleneck"),
+        ((30, 60), (4, 4), 1, 2, "BasicBlock"),
+        ((30, 60, 120), (4, 4, 4), 4, 3, "BasicBlock"),
+        ((30, 60, 120, 240), (4, 4, 4, 4), 3, 4, "BasicBlock"),
+    )
+)
+
+hrnet_w32 = tuple(
+    StageSpec(num_channels=nc, num_blocks=nbl, num_modules=nm, num_branches=nbr, block=b)
+    for (nc, nbl, nm, nbr, b) in (
+        ((64), (4), 1, 1, "Bottleneck"),
+        ((32, 64), (4, 4), 1, 2, "BasicBlock"),
+        ((32, 64, 128), (4, 4, 4), 4, 3, "BasicBlock"),
+        ((32, 64, 128, 256), (4, 4, 4, 4), 3, 4, "BasicBlock"),
+    )
+)
+
+hrnet_w40 = tuple(
+    StageSpec(num_channels=nc, num_blocks=nbl, num_modules=nm, num_branches=nbr, block=b)
+    for (nc, nbl, nm, nbr, b) in (
+        ((64), (4), 1, 1, "Bottleneck"),
+        ((40, 80), (4, 4), 1, 2, "BasicBlock"),
+        ((40, 80, 160), (4, 4, 4), 4, 3, "BasicBlock"),
+        ((40, 80, 160, 320), (4, 4, 4, 4), 3, 4, "BasicBlock"),
+    )
+)
+
+hrnet_w44 = tuple(
+    StageSpec(num_channels=nc, num_blocks=nbl, num_modules=nm, num_branches=nbr, block=b)
+    for (nc, nbl, nm, nbr, b) in (
+        ((64), (4), 1, 1, "Bottleneck"),
+        ((44, 88), (4, 4), 1, 2, "BasicBlock"),
+        ((44, 88, 176), (4, 4, 4), 4, 3, "BasicBlock"),
+        ((44, 88, 176, 352), (4, 4, 4, 4), 3, 4, "BasicBlock"),
+    )
+)
+
+hrnet_w48 = tuple(
+    StageSpec(num_channels=nc, num_blocks=nbl, num_modules=nm, num_branches=nbr, block=b)
+    for (nc, nbl, nm, nbr, b) in (
+        ((64), (4), 1, 1, "Bottleneck"),
+        ((48, 96), (4, 4), 1, 2, "BasicBlock"),
+        ((48, 96, 192), (4, 4, 4), 4, 3, "BasicBlock"),
+        ((48, 96, 192, 384), (4, 4, 4, 4), 3, 4, "BasicBlock"),
+    )
+)
+
+hrnet_w64 = tuple(
+    StageSpec(num_channels=nc, num_blocks=nbl, num_modules=nm, num_branches=nbr, block=b)
+    for (nc, nbl, nm, nbr, b) in (
+        ((64), (4), 1, 1, "Bottleneck"),
+        ((64, 128), (4, 4), 1, 2, "BasicBlock"),
+        ((64, 128, 256), (4, 4, 4), 4, 3, "BasicBlock"),
+        ((64, 128, 256, 512), (4, 4, 4, 4), 3, 4, "BasicBlock"),
+    )
+)
+
+
+
+
+class HighResolutionModule(nn.Module):
+
+    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
+                 num_channels, fuse_method, norm_layer, multi_scale_output=True):
+        super(HighResolutionModule, self).__init__()
+        self._check_branches(
+            num_branches, blocks, num_blocks, num_inchannels, num_channels)
+
+        self.num_inchannels = num_inchannels
+        self.fuse_method = fuse_method
+        self.num_branches = num_branches
+        self._norm_layer = norm_layer
+
+        self.multi_scale_output = multi_scale_output
+
+        self.branches = self._make_branches(
+            num_branches, blocks, num_blocks, num_channels)
+        self.fuse_layers = self._make_fuse_layers()
+        self.relu = nn.ReLU(True)
+
+    def _check_branches(self, num_branches, blocks, num_blocks,
+                        num_inchannels, num_channels):
+        if num_branches != len(num_blocks):
+            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
+                num_branches, len(num_blocks))
+            raise ValueError(error_msg)
+
+        if num_branches != len(num_channels):
+            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
+                num_branches, len(num_channels))
+            raise ValueError(error_msg)
+
+        if num_branches != len(num_inchannels):
+            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
+                num_branches, len(num_inchannels))
+            raise ValueError(error_msg)
+
+    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
+                         stride=1):
+        layers = []
+        layers.append(block(self.num_inchannels[branch_index],
+                            num_channels[branch_index], stride, norm_layer=self._norm_layer))
+        self.num_inchannels[branch_index] = \
+            num_channels[branch_index] * block.expansion
+        for i in range(1, num_blocks[branch_index]):
+            layers.append(block(self.num_inchannels[branch_index],
+                                num_channels[branch_index], norm_layer=self._norm_layer))
+
+        return nn.Sequential(*layers)
+
+    def _make_branches(self, num_branches, block, num_blocks, num_channels):
+        branches = []
+
+        for i in range(num_branches):
+            branches.append(
+                self._make_one_branch(i, block, num_blocks, num_channels))
+
+        return nn.ModuleList(branches)
+
+    def _make_fuse_layers(self):
+        if self.num_branches == 1:
+            return None
+
+        num_branches = self.num_branches
+        num_inchannels = self.num_inchannels
+        fuse_layers = []
+        for i in range(num_branches if self.multi_scale_output else 1):
+            fuse_layer = []
+            for j in range(num_branches):
+                if j > i:
+                    fuse_layer.append(nn.Sequential(
+                        Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False),
+                        self._norm_layer(num_inchannels[i]),
+                        nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
+                elif j == i:
+                    fuse_layer.append(None)
+                else:
+                    conv3x3s = []
+                    for k in range(i-j):
+                        if k == i - j - 1:
+                            num_outchannels_conv3x3 = num_inchannels[i]
+                            conv3x3s.append(nn.Sequential(
+                                Conv2d(num_inchannels[j], num_outchannels_conv3x3,
+                                       3, 2, 1, bias=False),
+                                self._norm_layer(num_outchannels_conv3x3)))
+                        else:
+                            num_outchannels_conv3x3 = num_inchannels[j]
+                            conv3x3s.append(nn.Sequential(
+                                Conv2d(num_inchannels[j], num_outchannels_conv3x3,
+                                       3, 2, 1, bias=False),
+                                self._norm_layer(num_outchannels_conv3x3),
+                                nn.ReLU(True)))
+                    fuse_layer.append(nn.Sequential(*conv3x3s))
+            fuse_layers.append(nn.ModuleList(fuse_layer))
+        return nn.ModuleList(fuse_layers)
+
+    def get_num_inchannels(self):
+        return self.num_inchannels
+
+    def forward(self, x):
+        if self.num_branches == 1:
+            return [self.branches[0](x[0])]
+
+        for i in range(self.num_branches):
+            x[i] = self.branches[i](x[i])
+
+        x_fuse = []
+        for i in range(len(self.fuse_layers)):
+            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
+            for j in range(1, self.num_branches):
+                if i == j:
+                    y = y + x[j]
+                else:
+                    y = y + self.fuse_layers[i][j](x[j])
+            x_fuse.append(self.relu(y))
+
+        return x_fuse
+
+
+class HRNet(nn.Module):
+    def __init__(self, model_cfg, **kwargs):
+        super(HRNet, self).__init__()
+
+
+        self.fuse_method = "SUM"
+        self.stage_specs = sys.modules[__name__].__getattribute__(model_cfg.arch)
+        self._norm_layer = ops.NORM_LAYERS[model_cfg.norm_layer]
+        self.multiscale = model_cfg.multi_scale_output
+
+        self.inplanes = 64
+
+        # stem net
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
+                               bias=False)
+        self.bn1 = self._norm_layer(64)
+        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
+                               bias=False)
+        self.bn2 = self._norm_layer(64)
+        self.relu = nn.ReLU(inplace=True)
+
+
+        self.stage1_cfg = self.stage_specs[0]
+        num_channels = self.stage1_cfg.num_channels
+        block = getattr(res_blocks, self.stage1_cfg.block)
+        num_blocks = self.stage1_cfg.num_blocks
+        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
+
+        self.stage2_cfg = self.stage_specs[1]
+        num_channels = self.stage2_cfg.num_channels
+        block = getattr(res_blocks, self.stage2_cfg.block)
+        num_channels = [
+            num_channels[i] * block.expansion for i in range(len(num_channels))]
+        self.transition1 = self._make_transition_layer([256], num_channels)
+        self.stage2, pre_stage_channels = self._make_stage(
+            self.stage2_cfg, num_channels)
+
+        self.stage3_cfg = self.stage_specs[2]
+        num_channels = self.stage3_cfg.num_channels
+        block = getattr(res_blocks, self.stage3_cfg.block)
+        num_channels = [
+            num_channels[i] * block.expansion for i in range(len(num_channels))]
+        self.transition2 = self._make_transition_layer(
+            pre_stage_channels, num_channels)
+        self.stage3, pre_stage_channels = self._make_stage(
+            self.stage3_cfg, num_channels)
+
+        self.stage4_cfg = self.stage_specs[3]
+        num_channels = self.stage4_cfg.num_channels
+        block = getattr(res_blocks, self.stage4_cfg.block)
+        num_channels = [
+            num_channels[i] * block.expansion for i in range(len(num_channels))]
+        self.transition3 = self._make_transition_layer(
+            pre_stage_channels, num_channels)
+        self.stage4, pre_stage_channels = self._make_stage(
+            self.stage4_cfg, num_channels, multi_scale_output=self.multiscale)
+        self.num_channels = pre_stage_channels
+
+    def _make_transition_layer(
+            self, num_channels_pre_layer, num_channels_cur_layer):
+        num_branches_cur = len(num_channels_cur_layer)
+        num_branches_pre = len(num_channels_pre_layer)
+
+        transition_layers = []
+        for i in range(num_branches_cur):
+            if i < num_branches_pre:
+                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+                    transition_layers.append(nn.Sequential(
+                        Conv2d(num_channels_pre_layer[i],
+                               num_channels_cur_layer[i],
+                               3,
+                               1,
+                               1,
+                               bias=False),
+                        self._norm_layer(num_channels_cur_layer[i]),
+                        nn.ReLU(inplace=True)))
+                else:
+                    # identity transition; the None placeholder keeps indexing
+                    # simple but is awkward for TorchScript export
+                    transition_layers.append(None)
+            else:
+                conv3x3s = []
+                for j in range(i+1-num_branches_pre):
+                    inchannels = num_channels_pre_layer[-1]
+                    outchannels = num_channels_cur_layer[i] \
+                        if j == i-num_branches_pre else inchannels
+                    conv3x3s.append(nn.Sequential(
+                        Conv2d(
+                            inchannels, outchannels, 3, 2, 1, bias=False),
+                        self._norm_layer(outchannels),
+                        nn.ReLU(inplace=True)))
+
+                transition_layers.append(nn.Sequential(*conv3x3s))
+
+        return nn.ModuleList(transition_layers)
+
+    def _make_layer(self, block, inplanes, planes, num_blocks, stride=1):
+
+        downsample = None
+
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                res_blocks.conv1x1(inplanes, planes * block.expansion, stride),
+                self._norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(inplanes, planes, stride, downsample, norm_layer=self._norm_layer))
+        inplanes = planes * block.expansion
+        for _ in range(1, num_blocks):
+            layers.append(block(inplanes, planes, norm_layer=self._norm_layer))
+
+        return nn.Sequential(*layers)
+
+
+    def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True):
+        num_modules = layer_config.num_modules
+        num_branches = layer_config.num_branches
+        num_blocks = layer_config.num_blocks
+        num_channels = layer_config.num_channels
+        block = getattr(res_blocks, layer_config.block)
+
+        # All original configs have 'FUSE_METHOD' = 'SUM'
+        fuse_method = self.fuse_method
+
+        modules = []
+        for i in range(num_modules):
+            # multi_scale_output is only used by the last module
+            if not multi_scale_output and i == num_modules - 1:
+                reset_multi_scale_output = False
+            else:
+                reset_multi_scale_output = True
+
+            modules.append(
+                HighResolutionModule(num_branches,
+                                     block,
+                                     num_blocks,
+                                     num_inchannels,
+                                     num_channels,
+                                     fuse_method,
+                                     self._norm_layer,
+                                     reset_multi_scale_output)
+            )
+            num_inchannels = modules[-1].get_num_inchannels()
+
+        return nn.Sequential(*modules), num_inchannels
+
+    @property
+    def feature_channels(self):
+        if self.multiscale:
+            return self.stage_specs[-1].num_channels
+        else:
+            return self.stage_specs[-1].num_channels[-1]
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.relu(x)
+
+        x = self.layer1(x)
+
+        x_list = []
+        for i in range(self.stage2_cfg.num_branches):
+            if self.transition1[i] is not None:
+                x_list.append(self.transition1[i](x))
+            else:
+                x_list.append(x)
+        y_list = self.stage2(x_list)
+
+        x_list = []
+        for i in range(self.stage3_cfg.num_branches):
+            if self.transition2[i] is not None:
+                x_list.append(self.transition2[i](y_list[-1]))
+            else:
+                x_list.append(y_list[i])
+        y_list = self.stage3(x_list)
+
+        x_list = []
+        for i in range(self.stage4_cfg.num_branches):
+            if self.transition3[i] is not None:
+                x_list.append(self.transition3[i](y_list[-1]))
+            else:
+                x_list.append(y_list[i])
+        y_list = self.stage4(x_list)
+
+        return tuple(y_list)
diff --git a/saic_depth_completion/modeling/backbone/res_blocks.py b/saic_depth_completion/modeling/backbone/res_blocks.py
new file mode 100644
index 0000000..97eec94
--- /dev/null
+++ b/saic_depth_completion/modeling/backbone/res_blocks.py
@@ -0,0 +1,101 @@
+import torch
+import torch.nn as nn
+from collections import namedtuple
+import sys
+
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+    __constants__ = ['downsample']
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(BasicBlock, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        if groups != 1 or base_width != 64:
+            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = norm_layer(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+    __constants__ = ['downsample']
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(Bottleneck, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        width = int(planes * (base_width / 64.)) * groups
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv1x1(inplanes, width)
+        self.bn1 = norm_layer(width)
+        self.conv2 = conv3x3(width, width, stride, groups, dilation)
+        self.bn2 = norm_layer(width)
+        self.conv3 = conv1x1(width, planes * self.expansion)
+        self.bn3 = norm_layer(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
\ No newline at end of file
diff --git a/saic_depth_completion/modeling/backbone/resnet.py b/saic_depth_completion/modeling/backbone/resnet.py
new file mode 100644
index 0000000..22ae5eb
--- /dev/null
+++ b/saic_depth_completion/modeling/backbone/resnet.py
@@ -0,0 +1,142 @@
+import sys
+from collections import namedtuple
+
+from torch import nn
+
+from saic_depth_completion.modeling.backbone import res_blocks
+from saic_depth_completion import ops
+
+StageSpec = namedtuple("StageSpec", ["block_count", "block"],)
+
+resnet18 = tuple(StageSpec(block_count=c, block=b)
+    for (c, b) in ((2, "BasicBlock"),(2, "BasicBlock"),(2, "BasicBlock"),(2, "BasicBlock"))
+)
+
+resnet34 = tuple(StageSpec(block_count=c, block=b)
+    for (c, b) in ((3, "BasicBlock"),(4, "BasicBlock"),(6, "BasicBlock"),(3, "BasicBlock"))
+)
+
+resnet50 = tuple(StageSpec(block_count=c, block=b)
+    for (c, b) in ((3, "Bottleneck"),(4, "Bottleneck"),(6, "Bottleneck"),(3, "Bottleneck"))
+)
+
+resnet101 = tuple(StageSpec(block_count=c, block=b)
+    for (c, b) in ((3, "Bottleneck"),(4, "Bottleneck"),(23, "Bottleneck"),(3, "Bottleneck"))
+)
+
+resnet152 = tuple(StageSpec(block_count=c, block=b)
+    for (c, b) in ((3, "Bottleneck"),(8, "Bottleneck"),(36, "Bottleneck"),(3, "Bottleneck"))
+)
+
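+# Illustrative: with cfg.arch == "resnet50" the spec above expands to four
+# stages of Bottleneck blocks with counts (3, 4, 6, 3); ResNet below builds
+# layer1..layer4 from it.
+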
+class ResNet(nn.Module):
+
+    def __init__(self, model_cfg, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None):
+
+        super(ResNet, self).__init__()
+
+        self.stage_specs = sys.modules[__name__].__getattribute__(model_cfg.arch)
+        self.block = getattr(res_blocks, self.stage_specs[0].block)
+        self._norm_layer = ops.NORM_LAYERS[model_cfg.norm_layer]
+        self.multiscale = model_cfg.multi_scale_output
+        self.base_channel = 64 * self.block.expansion
+        self.input_channels = 3
+
+
+        self.inplanes = 64
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            # each element in the tuple indicates if we should replace
+            # the 2x2 stride with a dilated convolution instead
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+        self.conv1 = nn.Conv2d(self.input_channels, self.inplanes, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = self._norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        # 1/4
+        self.layer1 = self._make_layer(self.block, 64, self.stage_specs[0].block_count)
+        # 1/8
+        self.layer2 = self._make_layer(self.block, 128, self.stage_specs[1].block_count, stride=2,
+                                       dilate=replace_stride_with_dilation[0])
+        # 1/16
+        self.layer3 = self._make_layer(self.block, 256, self.stage_specs[2].block_count, stride=2,
+                                       dilate=replace_stride_with_dilation[1])
+        # 1/32
+        self.layer4 = self._make_layer(self.block, 512, self.stage_specs[3].block_count, stride=2,
+                                       dilate=replace_stride_with_dilation[2])
+
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Zero-initialize the last BN in each residual branch, so that the
+        # residual branch starts with zeros and each residual block behaves
+        # like an identity. This improves the model by 0.2~0.3% according to
+        # https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, res_blocks.Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)
+                elif isinstance(m, res_blocks.BasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                res_blocks.conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+                            self.base_width, previous_dilation, norm_layer))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, groups=self.groups,
+                                base_width=self.base_width, dilation=self.dilation,
+                                norm_layer=norm_layer))
+
+        return nn.Sequential(*layers)
+
+    @property
+    def feature_channels(self):
+        if self.multiscale:
+            return self.base_channel, self.base_channel*2, \
+                   self.base_channel*4, self.base_channel*8
+        return self.base_channel*8
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        res = []
+        x = self.layer1(x)
+        res += [x]
+        x = self.layer2(x)
+        res += [x]
+        x = self.layer3(x)
+        res += [x]
+        x = self.layer4(x)
+        res += [x]
+
+        if self.multiscale:
+            return tuple(res)
+        return tuple([x])
diff --git a/saic_depth_completion/modeling/blocks.py b/saic_depth_completion/modeling/blocks.py
new file mode 100644
index 0000000..15328ad
--- /dev/null
+++ b/saic_depth_completion/modeling/blocks.py
@@ -0,0 +1,158 @@
+from functools import partial
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from saic_depth_completion import ops
+from saic_depth_completion.modeling.backbone.res_blocks import Bottleneck
+
+
+class CRPBlock(nn.Module):
+    def conv1x1(self, in_planes, out_planes, stride=1, bias=False):
+        return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
+                         padding=0, bias=bias)
+
+    def __init__(
+            self, in_planes, out_planes, n_stages=4
+    ):
+        super(CRPBlock, self).__init__()
+        for i in range(n_stages):
+            setattr(
+                self, '{}_{}'.format(i + 1, 'crp'),
+                self.conv1x1(
+                    in_planes if (i == 0) else out_planes,
+                    out_planes, stride=1, bias=False
+                )
+            )
+        self.stride = 1
+        self.n_stages = n_stages
+        self.maxpool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
+
+    def forward(self, x):
+        top = x
+        for i in range(self.n_stages):
+            top = self.maxpool(top)
+            top = getattr(self, '{}_{}'.format(i + 1, 'crp'))(top)
+            x = top + x
+        return x
+
+class FusionBlock(nn.Module):
+    def __init__(
+            self, hidden_dim, small_planes, activation=("ReLU", []), upsample="bilinear",
+    ):
+        super(FusionBlock, self).__init__()
+        self.act = ops.ACTIVATION_LAYERS[activation[0]](*activation[1])
+        self.upsample = upsample
+        self.conv1 = nn.Conv2d(hidden_dim, hidden_dim, 1, bias=True)
+        self.conv2 = nn.Conv2d(small_planes, hidden_dim, 1, bias=True)
+
+    def forward(self, input1, input2):
+        x1 = self.conv1(input1)
+        x2 = F.interpolate(
+            self.conv2(input2), size=x1.size()[-2:], mode=self.upsample, align_corners=True
+        )
+        return self.act(x1 + x2)
+
+
+class MaskEncoder(nn.Module):
+    def __init__(
+            self, out_ch, scale, kernel_size=3, activation=("ReLU", []),
+            round=False, upsample="bilinear",
+    ):
+        super(MaskEncoder, self).__init__()
+        self.scale = scale
+        self.upsample = upsample
+        self.round = round
+        self.convs = nn.ModuleList([
+            nn.Conv2d(1, out_ch // 4, kernel_size, padding=(kernel_size-1)//2),
+            nn.Conv2d(out_ch // 4, out_ch // 2, kernel_size, padding=(kernel_size-1)//2),
+            nn.Conv2d(out_ch // 2, out_ch, kernel_size, padding=(kernel_size-1)//2)
+        ])
+        self.acts = nn.ModuleList([
+            ops.ACTIVATION_LAYERS[activation[0]](*activation[1]),
+            ops.ACTIVATION_LAYERS[activation[0]](*activation[1]),
+            ops.ACTIVATION_LAYERS[activation[0]](*activation[1]),
+        ])
+
+    def forward(self, mask):
+
+        mask = F.interpolate(
+            mask, scale_factor=1./self.scale, mode=self.upsample
+        )
+        if self.round:
+            mask = torch.round(mask).float()
+
+        x = mask
+        for conv, act in zip(self.convs, self.acts):
+            x = conv(x)
+            x = act(x)
+        return x
+
+class SharedEncoder(nn.Module):
+    def __init__(
+            self, out_channels, scales, in_channels=1, kernel_size=3, upsample="bilinear", activation=("ReLU", [])
+    ):
+        super(SharedEncoder, self).__init__()
+        self.scales = scales
+        self.upsample = upsample
+        self.feature_extractor = nn.Sequential(*[
+            nn.Conv2d(in_channels, 32, kernel_size, padding=(kernel_size - 1) // 2),
+            ops.ACTIVATION_LAYERS[activation[0]](*activation[1]),
+            nn.Conv2d(32, 64, kernel_size, padding=(kernel_size - 1) // 2),
+            ops.ACTIVATION_LAYERS[activation[0]](*activation[1])
+        ])
+
+        self.predictors = []
+        for oup in out_channels:
+            self.predictors.append(
+                nn.Sequential(*[
+                    nn.Conv2d(64, oup, kernel_size=3, padding=0),
+                    ops.ACTIVATION_LAYERS[activation[0]](*activation[1])
+                ])
+            )
+        self.predictors = nn.ModuleList(self.predictors)
+
+    def forward(self, x):
+        features = self.feature_extractor(x)
+        res = []
+        for it, scale in enumerate(self.scales):
+            features_scaled = F.interpolate(features, scale_factor=1./scale, mode=self.upsample)
+            res.append(
+                self.predictors[it](features_scaled)
+            )
+        return tuple(res)
+
+
+class AdaptiveBlock(nn.Module):
+    def __init__(
+            self, x_in_ch, x_out_ch, y_ch, modulation="spade", activation=("ReLU", []), upsample='bilinear'
+    ):
+        super(AdaptiveBlock, self).__init__()
+
+        x_hidden_ch = min(x_in_ch, x_out_ch)
+        self.learned_res = x_in_ch != x_out_ch
+
+        if self.learned_res:
+            self.residual = nn.Conv2d(x_in_ch, x_out_ch, kernel_size=1, bias=False)
+
+        self.modulation1 = ops.MODULATION_LAYERS[modulation](x_ch=x_in_ch, y_ch=y_ch, upsample=upsample)
+        self.act1 = ops.ACTIVATION_LAYERS[activation[0]](*activation[1])
+        self.conv1 = nn.Conv2d(x_in_ch, x_hidden_ch, kernel_size=3, padding=1, bias=True)
+        self.modulation2 = ops.MODULATION_LAYERS[modulation](x_ch=x_hidden_ch, y_ch=y_ch, upsample=upsample)
+        self.act2 = ops.ACTIVATION_LAYERS[activation[0]](*activation[1])
+        self.conv2 = nn.Conv2d(x_hidden_ch, x_out_ch, kernel_size=3, padding=1, bias=True)
+
+    def forward(self, x, skip):
+        if self.learned_res:
+            res = self.residual(x)
+        else:
+            res = x
+
+        x = self.modulation1(x, skip)
+        x = self.act1(x)
+        x = self.conv1(x)
+        x = self.modulation2(x, skip)
+        x = self.act2(x)
+        x = self.conv2(x)
+
+        return x + res
\ No newline at end of file
diff --git a/saic_depth_completion/modeling/dm_lrn.py b/saic_depth_completion/modeling/dm_lrn.py
new file mode 100644
index 0000000..8764102
--- /dev/null
+++ b/saic_depth_completion/modeling/dm_lrn.py
@@ -0,0 +1,163 @@
+from functools import partial
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+from saic_depth_completion.modeling.backbone import build_backbone
+from saic_depth_completion.modeling.blocks import AdaptiveBlock, MaskEncoder, FusionBlock, CRPBlock, SharedEncoder
+from saic_depth_completion.utils import registry
+from saic_depth_completion import ops
+from saic_depth_completion.metrics import LOSSES
+
+
+
+@registry.MODELS.register("DM-LRN")
+class DM_LRN(nn.Module):
+    def __init__(self, model_cfg):
+        super(DM_LRN, self).__init__()
+        self.stem = nn.Sequential(
+            nn.Conv2d(in_channels=4, out_channels=3, kernel_size=7, padding=3),
+            nn.BatchNorm2d(3),
+            nn.ReLU(inplace=True)
+        )
+        self.backbone = build_backbone(model_cfg.backbone)
+        self.feature_channels = self.backbone.feature_channels
+
+        self.predict_log_depth = model_cfg.predict_log_depth
+        self.losses = model_cfg.criterion
+        self.activation = model_cfg.activation
+        self.modulation = model_cfg.modulation
+        self.channels = model_cfg.max_channels
+        self.upsample = model_cfg.upsample
+        self.use_crp = model_cfg.use_crp
+        self.mask_encoder_ksize = model_cfg.mask_encoder_ksize
+
+        self.modulation32 = AdaptiveBlock(
+            self.channels, self.channels, self.channels,
+            modulation=self.modulation, activation=self.activation,
+            upsample=self.upsample
+        )
+        self.modulation16 = AdaptiveBlock(
+            self.channels // 2, self.channels // 2, self.channels // 2,
+            modulation=self.modulation, activation=self.activation,
+            upsample=self.upsample
+        )
+        self.modulation8 = AdaptiveBlock(
+            self.channels // 4, self.channels // 4, self.channels // 4,
+            modulation=self.modulation, activation=self.activation,
+            upsample=self.upsample
+        )
+        self.modulation4 = AdaptiveBlock(
+            self.channels // 8, self.channels // 8, self.channels // 8,
+            modulation=self.modulation, activation=self.activation,
+            upsample=self.upsample
+        )
+
+        self.modulation4_1 = AdaptiveBlock(
+            self.channels // 8, self.channels // 16, self.channels // 8,
+            modulation=self.modulation, activation=self.activation,
+            upsample=self.upsample
+        )
+        self.modulation4_2 = AdaptiveBlock(
+            self.channels // 16, self.channels // 16, self.channels // 16,
+            modulation=self.modulation, activation=self.activation,
+            upsample=self.upsample
+        )
+
+
+
+        self.mask_encoder = SharedEncoder(
+            out_channels=(
+                self.channels, self.channels // 2, self.channels // 4,
+                self.channels // 8, self.channels // 8, self.channels // 16
+            ),
+            scales=(32, 16, 8, 4, 2, 1),
+            upsample=self.upsample,
+            activation=self.activation,
+            kernel_size=self.mask_encoder_ksize
+        )
+
+
+        self.fusion_32x16 = FusionBlock(self.channels // 2, self.channels, upsample=self.upsample)
+        self.fusion_16x8 = FusionBlock(self.channels // 4, self.channels // 2, upsample=self.upsample)
+        self.fusion_8x4 = FusionBlock(self.channels // 8, self.channels // 4, upsample=self.upsample)
+
+        self.adapt1 = nn.Conv2d(self.feature_channels[-1], self.channels, 1, bias=False)
+        self.adapt2 = nn.Conv2d(self.feature_channels[-2], self.channels // 2, 1, bias=False)
+        self.adapt3 = nn.Conv2d(self.feature_channels[-3], self.channels // 4, 1, bias=False)
+        self.adapt4 = nn.Conv2d(self.feature_channels[-4], self.channels // 8, 1, bias=False)
+
+        if self.use_crp:
+            self.crp1 = CRPBlock(self.channels, self.channels)
+            self.crp2 = CRPBlock(self.channels // 2, self.channels // 2)
+            self.crp3 = CRPBlock(self.channels // 4, self.channels // 4)
+            self.crp4 = CRPBlock(self.channels // 8, self.channels // 8)
+
+
+        self.predictor = nn.Sequential(*[
+            nn.Conv2d(self.channels // 16, self.channels // 16, 1, padding=0, groups=self.channels // 16),
+            nn.Conv2d(self.channels // 16, 1, 3, padding=1)
+        ])
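+        # when the network regresses metric depth directly (rather than
+        # log-depth), a final activation keeps the prediction non-negative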
+        if not self.predict_log_depth:
+            self.act = ops.ACTIVATION_LAYERS[self.activation[0]](*self.activation[1])
+
+    def criterion(self, pred, gt):
+        total = 0
+        # each spec is (loss_name, weight) or (loss_name, weight, ctor_args)
+        for spec in self.losses:
+            if len(spec) == 3:
+                loss_fn = LOSSES[spec[0]](*spec[2])
+            else:
+                loss_fn = LOSSES[spec[0]]()
+            total += spec[1] * loss_fn(pred, gt)
+        return total
+
+    def postprocess(self, pred):
+        if self.predict_log_depth:
+            return pred.exp()
+
+        return pred
+
+    def forward(self, batch):
+
+        color, raw_depth, mask = batch["color"], batch["raw_depth"], batch["mask"]
+
+        x = torch.cat([color, raw_depth], dim=1)
+        # shift the {0, 1} validity mask to {1, 2} before encoding
+        mask = mask + 1.0
+        x = self.stem(x)
+
+        features = self.backbone(x)[::-1]
+        if self.use_crp:
+            f1 = self.crp1(self.adapt1(features[0]))
+        else:
+            f1 = self.adapt1(features[0])
+        f2 = self.adapt2(features[1])
+        f3 = self.adapt3(features[2])
+        f4 = self.adapt4(features[3])
+
+        mask_features = self.mask_encoder(mask)
+
+        x = self.modulation32(f1, mask_features[0])
+        x = self.fusion_32x16(f2, x)
+        x = self.crp2(x) if self.use_crp else x
+
+        x = self.modulation16(x, mask_features[1])
+        x = self.fusion_16x8(f3, x)
+        x = self.crp3(x) if self.use_crp else x
+
+        x = self.modulation8(x, mask_features[2])
+        x = self.fusion_8x4(f4, x)
+        x = self.crp4(x) if self.use_crp else x
+
+        x = self.modulation4(x, mask_features[3])
+
+        x = F.interpolate(x, scale_factor=2, mode=self.upsample)
+        x = self.modulation4_1(x, mask_features[4])
+        x = F.interpolate(x, scale_factor=2, mode=self.upsample)
+        x = self.modulation4_2(x, mask_features[5])
+
+        if not self.predict_log_depth:
+            return self.act(self.predictor(x))
+
+        return self.predictor(x)
+
diff --git a/saic_depth_completion/modeling/lrn.py b/saic_depth_completion/modeling/lrn.py
new file mode 100644
index 0000000..199cb62
--- /dev/null
+++ b/saic_depth_completion/modeling/lrn.py
@@ -0,0 +1,123 @@
+from functools import partial
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+from saic_depth_completion.modeling.backbone import build_backbone
+from saic_depth_completion.modeling.blocks import AdaptiveBlock, MaskEncoder, FusionBlock, CRPBlock
+from saic_depth_completion.utils import registry
+from saic_depth_completion import ops
+from saic_depth_completion.metrics import LOSSES
+
+
+
+@registry.MODELS.register("LRN")
+class LRN(nn.Module):
+    def __init__(self, model_cfg):
+        super(LRN, self).__init__()
+
+        self.predict_log_depth = model_cfg.predict_log_depth
+        self.losses = model_cfg.criterion
+        self.activation = model_cfg.activation
+        self.channels = model_cfg.max_channels
+        self.upsample = model_cfg.upsample
+        self.use_crp = model_cfg.use_crp
+        self.input_mask = model_cfg.input_mask
+
+        in_ch = 4 if not self.input_mask else 5
+        self.stem = nn.Sequential(
+            nn.Conv2d(in_channels=in_ch, out_channels=3, kernel_size=7, padding=3),
+            nn.BatchNorm2d(3),
+            nn.ReLU(inplace=True)
+        )
+        self.backbone = build_backbone(model_cfg.backbone)
+        self.feature_channels = self.backbone.feature_channels
+
+
+        self.fusion_32x16 = FusionBlock(self.channels // 2, self.channels, upsample=self.upsample)
+        self.fusion_16x8 = FusionBlock(self.channels // 4, self.channels // 2, upsample=self.upsample)
+        self.fusion_8x4 = FusionBlock(self.channels // 8, self.channels // 4, upsample=self.upsample)
+
+        self.adapt1 = nn.Conv2d(self.feature_channels[-1], self.channels, 1, bias=False)
+        self.adapt2 = nn.Conv2d(self.feature_channels[-2], self.channels // 2, 1, bias=False)
+        self.adapt3 = nn.Conv2d(self.feature_channels[-3], self.channels // 4, 1, bias=False)
+        self.adapt4 = nn.Conv2d(self.feature_channels[-4], self.channels // 8, 1, bias=False)
+
+        if self.use_crp:
+            self.crp1 = CRPBlock(self.channels, self.channels)
+            self.crp2 = CRPBlock(self.channels // 2, self.channels // 2)
+            self.crp3 = CRPBlock(self.channels // 4, self.channels // 4)
+            self.crp4 = CRPBlock(self.channels // 8, self.channels // 8)
+
+
+        self.convs = nn.ModuleList([
+            nn.Conv2d(self.channels // 8, self.channels // 8, 3, padding=1),
+            nn.Conv2d(self.channels // 8, self.channels // 16, 3, padding=1),
+            nn.Conv2d(self.channels // 16, self.channels // 16, 3, padding=1),
+            nn.Conv2d(self.channels // 16, self.channels // 32, 3, padding=1),
+        ])
+        self.acts = nn.ModuleList([
+            ops.ACTIVATION_LAYERS[self.activation[0]](*self.activation[1]),
+            ops.ACTIVATION_LAYERS[self.activation[0]](*self.activation[1]),
+            ops.ACTIVATION_LAYERS[self.activation[0]](*self.activation[1]),
+            ops.ACTIVATION_LAYERS[self.activation[0]](*self.activation[1]),
+        ])
+
+        self.predictor = nn.Conv2d(self.channels // 32, 1, 3, padding=1)
+
+    def criterion(self, pred, gt):
+        total = 0
+        for spec in self.losses:
+            if len(spec) == 3:
+                loss_fn = LOSSES[spec[0]](*spec[2])
+            else:
+                loss_fn = LOSSES[spec[0]]()
+            total += spec[1] * loss_fn(pred, gt)
+        return total
+
+    def postprocess(self, pred):
+        if self.predict_log_depth:
+            return pred.exp()
+        return pred
+
+    def forward(self, batch):
+
+        color, raw_depth, mask = batch["color"], batch["raw_depth"], batch["mask"]
+
+        if self.input_mask:
+            x = torch.cat([color, raw_depth, mask], dim=1)
+        else:
+            x = torch.cat([color, raw_depth], dim=1)
+
+        x = self.stem(x)
+
+        features = self.backbone(x)[::-1]
+        if self.use_crp:
+            f1 = self.crp1(self.adapt1(features[0]))
+        else:
+            f1 = self.adapt1(features[0])
+        f2 = self.adapt2(features[1])
+        f3 = self.adapt3(features[2])
+        f4 = self.adapt4(features[3])
+
+        x = self.fusion_32x16(f2, f1)
+        x = self.crp2(x) if self.use_crp else x
+
+        x = self.fusion_16x8(f3, x)
+        x = self.crp3(x) if self.use_crp else x
+
+        x = self.fusion_8x4(f4, x)
+        x = self.crp4(x) if self.use_crp else x
+
+
+        x = F.interpolate(x, scale_factor=2, mode=self.upsample)
+        x = self.acts[0](self.convs[0](x))
+        x = self.acts[1](self.convs[1](x))
+
+        x = F.interpolate(x, scale_factor=2, mode=self.upsample)
+        x = self.acts[2](self.convs[2](x))
+        x = self.acts[3](self.convs[3](x))
+
+        return self.predictor(x)
\ No newline at end of file
diff --git a/saic_depth_completion/modeling/meta.py b/saic_depth_completion/modeling/meta.py
new file mode 100644
index 0000000..5ec12ab
--- /dev/null
+++ b/saic_depth_completion/modeling/meta.py
@@ -0,0 +1,43 @@
+import torch
+
+import torch.nn.functional as F
+
+from saic_depth_completion.utils import registry
+
+class MetaModel(torch.nn.Module):
+    def __init__(self, cfg, device):
+        super(MetaModel, self).__init__()
+        self.model = registry.MODELS[cfg.model.arch](cfg.model)
+        self.model.to(device)
+        self.device = device
+
+        self.rgb_mean = cfg.train.rgb_mean
+        self.rgb_std = cfg.train.rgb_std
+
+        self.depth_mean = cfg.train.depth_mean
+        self.depth_std = cfg.train.depth_std
+
+    def forward(self, batch):
+        return self.model(batch)
+
+    def preprocess(self, batch):
+
+        batch["color"] = batch["color"] - torch.tensor(self.rgb_mean).view(1, 3, 1, 1)
+        batch["color"] = batch["color"] / torch.tensor(self.rgb_std).view(1, 3, 1, 1)
+
+        # normalize depth only at valid (non-empty) pixels, keeping holes at 0
+        mask = batch["raw_depth"] != 0
+        batch["raw_depth"][mask] = batch["raw_depth"][mask] - self.depth_mean
+        batch["raw_depth"][mask] = batch["raw_depth"][mask] / self.depth_std
+
+        for k, v in batch.items():
+            batch[k] = v.to(self.device)
+
+        return batch
+
+    def postprocess(self, input):
+        return self.model.postprocess(input)
+
+    def criterion(self, pred, gt):
+        return self.model.criterion(pred, gt)
+
+    def count_parameters(self):
+        return sum(p.numel() for p in self.model.parameters() if p.requires_grad)
\ No newline at end of file
diff --git a/saic_depth_completion/ops/__init__.py b/saic_depth_completion/ops/__init__.py
new file mode 100644
index 0000000..c0b30a9
--- /dev/null
+++ b/saic_depth_completion/ops/__init__.py
@@ -0,0 +1,23 @@
+from functools import partial
+
+import torch
+
+from .batch_norm import FrozenBatchNorm2d
+from .spade import SPADE, SelfSPADE
+from .sean import SEAN
+
+from saic_depth_completion.utils.registry import Registry
+
+MODULATION_LAYERS = Registry()
+NORM_LAYERS = Registry()
+ACTIVATION_LAYERS = Registry()
+
+ACTIVATION_LAYERS["ReLU"] = torch.nn.ReLU
+ACTIVATION_LAYERS["LeakyReLU"] = torch.nn.LeakyReLU
+
+MODULATION_LAYERS["SPADE"] = SPADE
+MODULATION_LAYERS["SelfSPADE"] = SelfSPADE
+MODULATION_LAYERS["SEAN"] = SEAN
+
+NORM_LAYERS["BatchNorm2d"] = torch.nn.BatchNorm2d
+NORM_LAYERS["FrozenBatchNorm2d"] = FrozenBatchNorm2d
\ No newline at end of file
diff --git a/saic_depth_completion/ops/batch_norm.py b/saic_depth_completion/ops/batch_norm.py
new file mode 100644
index 0000000..40af004
--- /dev/null
+++ b/saic_depth_completion/ops/batch_norm.py
@@ -0,0 +1,39 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class FrozenBatchNorm2d(nn.Module):
+    """
+    BatchNorm2d where the batch statistics and the affine parameters
+    are fixed
+    """
+
+    def __init__(self, channels, eps=1e-5):
+        super(FrozenBatchNorm2d, self).__init__()
+        self.eps = eps
+        self.register_buffer("weight", torch.ones(channels))
+        self.register_buffer("bias", torch.zeros(channels))
+        self.register_buffer("running_mean", torch.zeros(channels))
+        self.register_buffer("running_var", torch.ones(channels) - eps)
+
+    def forward(self, x):
+        if x.requires_grad:
+            # When gradients are needed, F.batch_norm will use extra memory
+            # because its backward op computes gradients for weight/bias as well.
+            scale = self.weight * (self.running_var + self.eps).rsqrt()
+            bias = self.bias - self.running_mean * scale
+            scale = scale.reshape(1, -1, 1, 1)
+            bias = bias.reshape(1, -1, 1, 1)
+            return x * scale + bias
+        else:
+            # When gradients are not needed, F.batch_norm is a single fused op
+            # and provides more optimization opportunities.
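+            # The result is identical to the affine form above; the frozen
+            # statistics are simply applied inside the fused kernel.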
+            return F.batch_norm(
+                x,
+                self.running_mean,
+                self.running_var,
+                self.weight,
+                self.bias,
+                training=False,
+                eps=self.eps,
+            )
\ No newline at end of file
diff --git a/saic_depth_completion/ops/sean.py b/saic_depth_completion/ops/sean.py
new file mode 100644
index 0000000..ecff22f
--- /dev/null
+++ b/saic_depth_completion/ops/sean.py
@@ -0,0 +1,55 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+class SEAN(nn.Module):
+
+    def __init__(self, x_ch, y_ch, kernel_size=3, upsample='nearest'):
+        super(SEAN, self).__init__()
+        assert upsample in ['nearest', 'bilinear']
+        self.upsample = upsample
+
+        padding = (kernel_size - 1) // 2
+
+        self.gamma_y = nn.Conv2d(
+            y_ch, x_ch, kernel_size=kernel_size, padding=padding, bias=False
+        )
+        self.beta_y = nn.Conv2d(
+            y_ch, x_ch, kernel_size=kernel_size, padding=padding, bias=False
+        )
+
+        self.gamma_x = nn.Conv2d(
+            x_ch, x_ch, kernel_size=kernel_size, padding=padding, bias=False
+        )
+        self.beta_x = nn.Conv2d(
+            x_ch, x_ch, kernel_size=kernel_size, padding=padding, bias=False
+        )
+
+        # learnable blending weights; nn.Parameter registers them with the
+        # module (a bare requires_grad tensor moved to CUDA is a non-leaf that
+        # the optimizer would never update, and it breaks CPU execution)
+        self.w_gamma = nn.Parameter(torch.tensor(1.))
+        self.w_beta = nn.Parameter(torch.tensor(1.))
+
+
+        # we assume a distribution at each spatial cell of the tensor,
+        # so the statistics are computed over the batch only
+        self.bn = nn.BatchNorm2d(x_ch, affine=False)
+
+    def forward(self, x, y):
+
+        y = F.interpolate(y, size=x.size()[-2:], mode=self.upsample)
+
+        x_normalized = self.bn(x)
+
+        # no ReLU here: gamma and beta must be able to subtract from the signal
+        gamma_y = self.gamma_y(y)
+        beta_y = self.beta_y(y)
+
+        gamma_x = self.gamma_x(x)
+        beta_x = self.beta_x(x)
+
+
+        gamma = (1 - self.w_gamma) * gamma_x + self.w_gamma * gamma_y
+        beta = (1 - self.w_beta) * beta_x + self.w_beta * beta_y
+
+        return (gamma) * x_normalized + beta
\ No newline at end of file
diff --git a/saic_depth_completion/ops/spade.py b/saic_depth_completion/ops/spade.py
new file mode 100644
index 0000000..45a36e3
--- /dev/null
+++ b/saic_depth_completion/ops/spade.py
@@ -0,0 +1,77 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+class SPADE(nn.Module):
+
+    def __init__(self, x_ch, y_ch, kernel_size=3, upsample='nearest'):
+        super(SPADE, self).__init__()
+        self.eps = 1e-5
+        assert upsample in ['nearest', 'bilinear']
+        self.upsample = upsample
+
+        padding = (kernel_size) // 2
+
+        self.gamma = nn.Conv2d(
+            y_ch, x_ch, kernel_size=kernel_size, padding=padding, bias=False
+        )
+        self.beta = nn.Conv2d(
+            y_ch, x_ch, kernel_size=kernel_size, padding=padding, bias=False
+        )
+
+        # we assume a distribution at each spatial cell of the tensor,
+        # so the statistics are computed over the batch only
+        self.bn = nn.BatchNorm2d(x_ch, affine=False)
+
+    def forward(self, x, y):
+
+        y = F.interpolate(y, size=x.size()[-2:], mode=self.upsample)
+
+        x_normalized = self.bn(x)
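+
+        # SPADE modulation: out = (1 + gamma(y)) * BN(x) + beta(y)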
+        # no ReLU here: gamma and beta must be able to subtract from the signal
+        gamma = self.gamma(y)
+        beta = self.beta(y)
+
+        return (1+gamma) * x_normalized + beta
+
+
+class SelfSPADE(nn.Module):
+
+    def __init__(self, x_ch, y_ch, kernel_size=3, upsample='nearest'):
+        super(SelfSPADE, self).__init__()
+        self.eps = 1e-5
+        assert upsample in ['nearest', 'bilinear']
+        self.upsample = upsample
+
+        padding = (kernel_size) // 2
+
+        self.gamma = nn.Conv2d(
+            y_ch+x_ch, x_ch, kernel_size=kernel_size, padding=padding, bias=False
+        )
+        self.beta = nn.Conv2d(
+            y_ch+x_ch, x_ch, kernel_size=kernel_size, padding=padding, bias=False
+        )
+
+        self.adapt = nn.Conv2d(
+            y_ch+x_ch, x_ch, kernel_size=1, padding=0, bias=False
+        )
+
+        # we assume a distribution at each spatial cell of the tensor,
+        # so the statistics are computed over the batch only
+        self.bn = nn.BatchNorm2d(x_ch, affine=False)
+
+
+    def forward(self, x, y):
+
+        y = F.interpolate(y, size=x.size()[-2:], mode=self.upsample)
+
+        x = torch.cat([x, y], dim=1)
+
+        x_normalized = self.bn(self.adapt(x))
+
+        # no ReLU here: gamma and beta must be able to subtract from the signal
+        gamma = 0.1 * self.gamma(x)
+        beta = 0.1 * self.beta(x)
+
+        return (gamma) * x_normalized + beta
diff --git a/saic_depth_completion/utils/__init__.py b/saic_depth_completion/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/saic_depth_completion/utils/experiment.py b/saic_depth_completion/utils/experiment.py
new file mode 100644
index 0000000..f7b1818
--- /dev/null
+++ b/saic_depth_completion/utils/experiment.py
@@ -0,0 +1,104 @@
+import os
+import yaml
+import logging
+import shutil
+from easydict import EasyDict as edict
+from saic_depth_completion.utils.registry import Registry
+
+
+parsers = Registry()
+
+def setup_experiment(cfg, config_file, postfix="", log_dir="./logs/", tensorboard_dir="./tensorboard/",
+                     delimiter="|", logger=None, training=True, debug=False):
+    if logger is None:
+        logger = logging.getLogger(__name__)
+
+    experiment = edict()
+    experiment.name = delimiter.join(
+        parsers[cfg.model.arch](cfg.model) + parse_train_params(cfg.train)
+    )
+    logger.info("Experiment name: {}".format(experiment.name))
+    if postfix:
+        experiment.name = experiment.name + "-" + postfix
+    experiment.dir = os.path.join(
+        log_dir, experiment.name
+    )
+    experiment.snapshot_dir = os.path.join(
+        log_dir, experiment.name, "snapshots"
+    )
+    experiment.tensorboard_dir = os.path.join(
+        tensorboard_dir, experiment.name
+    )
+
+    if not debug:
+        logger.info("Experiment dir: {}".format(experiment.dir))
+        os.makedirs(experiment.snapshot_dir, exist_ok=not training)
+        logger.info("Snapshot dir: {}".format(experiment.snapshot_dir))
+        os.makedirs(experiment.tensorboard_dir, exist_ok=not training)
+        logger.info("Tensorboard dir: {}".format(experiment.tensorboard_dir))
+
+        if training:
+            shutil.copy2(config_file, experiment.dir)
+
+    return experiment
+
+
+
+@parsers.register("DM-LRN")
+def parse_dm_lrn(model_cfg):
+    model_params = [model_cfg.arch, model_cfg.modulation]
+    backbone_params = [
+        model_cfg.backbone.arch,
+        "imagenet" if model_cfg.backbone.imagenet else "",
+        str(model_cfg.backbone.norm_layer).split('.')[-1][:-2],
+    ]
+    criterion = []
+    for spec in model_cfg.criterion:
+        if len(spec) == 3:
+            criterion.append(
+                "(" +
+                str(spec[1]) + "*" + spec[0] +
+                "(" + ",".join( [str(i) for i in spec[2]] ) +")" +
+                ")"
+            )
+        else:
+            criterion.append(
+                "(" + str(spec[1]) + "*" + spec[0] + ")"
+            )
+    loss = "+".join(criterion)
+
+    return model_params + backbone_params + [loss]
+
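+# The parsed tokens are joined with the delimiter into the experiment name,
+# e.g. (illustrative, depends on the config):
+# "DM-LRN|SEAN|efficientnet-b4|imagenet|BatchNorm2d|(1.0*LogDepthL1Loss)|lr=0.0001"
+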
diff --git a/saic_depth_completion/utils/__init__.py b/saic_depth_completion/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/saic_depth_completion/utils/experiment.py b/saic_depth_completion/utils/experiment.py
new file mode 100644
index 0000000..f7b1818
--- /dev/null
+++ b/saic_depth_completion/utils/experiment.py
@@ -0,0 +1,104 @@
+import os
+import yaml
+import logging
+import shutil
+from easydict import EasyDict as edict
+from saic_depth_completion.utils.registry import Registry
+
+
+parsers = Registry()
+
+def setup_experiment(cfg, config_file, postfix="", log_dir="./logs/", tensorboard_dir="./tensorboard/",
+                     delimiter="|", logger=None, training=True, debug=False):
+    if logger is None:
+        logger = logging.getLogger(__name__)
+
+    experiment = edict()
+    experiment.name = delimiter.join(
+        parsers[cfg.model.arch](cfg.model) + parse_train_params(cfg.train)
+    )
+    logger.info("Experiment name: {}".format(experiment.name))
+    if postfix:
+        experiment.name = experiment.name + "-" + postfix
+    experiment.dir = os.path.join(
+        log_dir, experiment.name
+    )
+    experiment.snapshot_dir = os.path.join(
+        log_dir, experiment.name, "snapshots"
+    )
+    experiment.tensorboard_dir = os.path.join(
+        tensorboard_dir, experiment.name
+    )
+
+    if not debug:
+        logger.info("Experiment dir: {}".format(experiment.dir))
+        os.makedirs(experiment.snapshot_dir, exist_ok=not training)
+        logger.info("Snapshot dir: {}".format(experiment.snapshot_dir))
+        os.makedirs(experiment.tensorboard_dir, exist_ok=not training)
+        logger.info("Tensorboard dir: {}".format(experiment.tensorboard_dir))
+
+        if training:
+            shutil.copy2(config_file, experiment.dir)
+
+    return experiment
+
+
+
+@parsers.register("DM-LRN")
+def parse_dm_lrn(model_cfg):
+    model_params = [model_cfg.arch, model_cfg.modulation]
+    backbone_params = [
+        model_cfg.backbone.arch,
+        "imagenet" if model_cfg.backbone.imagenet else "",
+        str(model_cfg.backbone.norm_layer).split('.')[-1][:-2],
+    ]
+    criterion = []
+    for spec in model_cfg.criterion:
+        if len(spec) == 3:
+            criterion.append(
+                "("
+                + str(spec[1]) + "*" + spec[0]
+                + "(" + ",".join(str(i) for i in spec[2]) + ")"
+                + ")"
+            )
+        else:
+            criterion.append(
+                "(" + str(spec[1]) + "*" + spec[0] + ")"
+            )
+    loss = "+".join(criterion)
+
+    return model_params + backbone_params + [loss]
+
+@parsers.register("LRN")
+def parse_lrn(model_cfg):
+    model_params = [
+        model_cfg.arch,
+        "CRP" if model_cfg.use_crp else "NoCRP"
+    ]
+    backbone_params = [
+        model_cfg.backbone.arch,
+        "imagenet" if model_cfg.backbone.imagenet else "",
+        str(model_cfg.backbone.norm_layer).split('.')[-1][:-2],
+    ]
+    criterion = []
+    for spec in model_cfg.criterion:
+        if len(spec) == 3:
+            criterion.append(
+                "("
+                + str(spec[1]) + "*" + spec[0]
+                + "(" + ",".join(str(i) for i in spec[2]) + ")"
+                + ")"
+            )
+        else:
+            criterion.append(
+                "(" + str(spec[1]) + "*" + spec[0] + ")"
+            )
+    loss = "+".join(criterion)
+
+    return model_params + backbone_params + [loss]
+
+def parse_train_params(train_cfg):
+    train_params = [
+        "lr=" + str(train_cfg.lr),
+    ]
+    return train_params
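To make the naming scheme concrete, a sketch of the string these parsers produce for a made-up config; every field below is hypothetical and only mirrors the attribute accesses above (the "BerHu" criterion spec in particular is invented for illustration):

from easydict import EasyDict as edict
from torch import nn
from saic_depth_completion.utils.experiment import parsers, parse_train_params

cfg = edict({
    "model": {
        "arch": "DM-LRN",
        "modulation": "SPADE",
        "backbone": {"arch": "efficientnet-b3", "imagenet": True,
                     "norm_layer": nn.BatchNorm2d},
        "criterion": [["BerHu", 1.0]],   # hypothetical loss spec
    },
    "train": {"lr": 1e-4},
})
name = "|".join(parsers[cfg.model.arch](cfg.model) + parse_train_params(cfg.train))
# -> "DM-LRN|SPADE|efficientnet-b3|imagenet|BatchNorm2d|(1.0*BerHu)|lr=0.0001"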
diff --git a/saic_depth_completion/utils/logger.py b/saic_depth_completion/utils/logger.py
new file mode 100644
index 0000000..bc92089
--- /dev/null
+++ b/saic_depth_completion/utils/logger.py
@@ -0,0 +1,27 @@
+import logging
+from colorlog import ColoredFormatter
+
+
+def setup_logger():
+
+    formatter = ColoredFormatter(
+        "%(log_color)s%(asctime)s - %(yellow)s%(name)s: %(white)s%(message)s",
+        "%Y-%m-%d %H:%M:%S",
+        reset=True,
+        log_colors={
+            'DEBUG': 'green',
+            'INFO': 'green',
+            'WARNING': 'red',
+            'ERROR': 'red',
+            'CRITICAL': 'red',
+        }
+    )
+
+    logger = logging.getLogger('saic-dc')
+    handler = logging.StreamHandler()
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+    logger.setLevel(logging.INFO)
+
+    return logger
+
diff --git a/saic_depth_completion/utils/meter.py b/saic_depth_completion/utils/meter.py
new file mode 100644
index 0000000..39d3692
--- /dev/null
+++ b/saic_depth_completion/utils/meter.py
@@ -0,0 +1,92 @@
+from collections import deque
+
+import torch
+
+
+class Statistics(object):
+    def __init__(self, maxlen=20):
+        self.numer = deque(maxlen=maxlen)
+        self.denom = deque(maxlen=maxlen)
+        self.total = 0.0
+        self.count = 0
+
+    def reset(self):
+        self.total = 0.0
+        self.count = 0
+        self.numer.clear()
+        self.denom.clear()
+
+    def update(self, value, n):
+        self.numer.append(value)
+        self.denom.append(n)
+        self.count += n
+        self.total += value
+
+    @property
+    def median(self):
+        numer = torch.tensor(list(self.numer))
+        denom = torch.tensor(list(self.denom))
+        sequence = numer / denom
+        return sequence.median().item()
+
+    @property
+    def avg(self):
+        numer = torch.tensor(list(self.numer))
+        denom = torch.tensor(list(self.denom))
+        avg = numer.sum() / denom.sum()
+        return avg.item()
+
+    @property
+    def global_avg(self):
+        return self.total / self.count
+
+
+class Meter:
+    def __init__(self, metric_fn, maxlen=20):
+        self.metric_fn = metric_fn
+        self.stats = Statistics(maxlen)
+
+    def reset(self):
+        self.stats.reset()
+
+    def update(self, pred, gt):
+        value = self.metric_fn(pred, gt)
+        if isinstance(value, tuple):
+            self.stats.update(value[0].cpu(), value[1])
+        else:
+            self.stats.update(value.item(), 1)
+
+    @property
+    def median(self):
+        return self.stats.median
+    @property
+    def avg(self):
+        return self.stats.avg
+
+    @property
+    def global_avg(self):
+        return self.stats.global_avg
+
+class AggregatedMeter(object):
+    def __init__(self, metrics, maxlen=20, delimiter=' # '):
+        self.delimiter = delimiter
+        self.meters = {
+            k: Meter(v, maxlen) for k, v in metrics.items()
+        }
+
+    def reset(self):
+        for v in self.meters.values():
+            v.reset()
+
+    def update(self, pred, gt):
+        for v in self.meters.values():
+            v.update(pred, gt)
+
+    @property
+    def suffix(self):
+        suffix = []
+        for k, v in self.meters.items():
+            suffix.append(
+                "{}: {:.4f} ({:.4f})".format(k, v.median, v.global_avg)
+            )
+        return self.delimiter.join(suffix)
\ No newline at end of file
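A sketch of how these meters are driven during evaluation; the lambdas stand in for the real metric classes from saic_depth_completion.metrics, which are not part of this file:

import torch
from saic_depth_completion.utils.meter import AggregatedMeter

metrics = {
    "mae": lambda pred, gt: (pred - gt).abs().mean(),
    "mse": lambda pred, gt: ((pred - gt) ** 2).mean(),
}
agg = AggregatedMeter(metrics, maxlen=20)

for _ in range(5):                                   # pretend validation loop
    pred, gt = torch.rand(4, 1, 8, 8), torch.rand(4, 1, 8, 8)
    agg.update(pred, gt)

print(agg.suffix)   # e.g. "mae: 0.3340 (0.3329) # mse: 0.1671 (0.1663)"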
diff --git a/saic_depth_completion/utils/model_zoo.py b/saic_depth_completion/utils/model_zoo.py
new file mode 100644
index 0000000..89dc9bb
--- /dev/null
+++ b/saic_depth_completion/utils/model_zoo.py
@@ -0,0 +1,50 @@
+import os
+import re
+import torch
+from torch.hub import load_state_dict_from_url
+
+
+model_resnet_imagenet = {
+    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+}
+
+hrnet_root = "/dbstore/datasets/HRNet-Image-Classification_weights"
+model_hrnet_imagenet = {
+    "hrnet_w18": os.path.join(hrnet_root, "hrnetv2_w18_imagenet_pretrained.pth"),
+    "hrnet_w18_small_v1": os.path.join(hrnet_root, "hrnet_w18_small_model_v1.pth"),
+    "hrnet_w18_small_v2": os.path.join(hrnet_root, "hrnet_w18_small_model_v2.pth"),
+    "hrnet_w30": os.path.join(hrnet_root, "hrnetv2_w30_imagenet_pretrained.pth"),
+    "hrnet_w32": os.path.join(hrnet_root, "hrnetv2_w32_imagenet_pretrained.pth"),
+    "hrnet_w40": os.path.join(hrnet_root, "hrnetv2_w40_imagenet_pretrained.pth"),
+    "hrnet_w44": os.path.join(hrnet_root, "hrnetv2_w44_imagenet_pretrained.pth"),
+    "hrnet_w48": os.path.join(hrnet_root, "hrnetv2_w48_imagenet_pretrained.pth"),
+    "hrnet_w64": os.path.join(hrnet_root, "hrnetv2_w64_imagenet_pretrained.pth"),
+}
+
+
+model_efficientnet_imagenet = {
+    'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth',
+    'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth',
+    'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth',
+    'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth',
+    'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth',
+    'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth',
+    'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth',
+    'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth',
+}
+
+def _load_state_dict_hrnet(key):
+    state_dict = torch.load(model_hrnet_imagenet[key])
+    return state_dict
+
+def _load_state_dict_resnet(key):
+    state_dict = load_state_dict_from_url(model_resnet_imagenet[key], progress=True)
+    return state_dict
+
+def _load_state_dict_efficientnet(key):
+    state_dict = load_state_dict_from_url(model_efficientnet_imagenet[key], progress=True)
+    return state_dict
\ No newline at end of file
diff --git a/saic_depth_completion/utils/registry.py b/saic_depth_completion/utils/registry.py
new file mode 100644
index 0000000..2e805a1
--- /dev/null
+++ b/saic_depth_completion/utils/registry.py
@@ -0,0 +1,15 @@
+class Registry(dict):
+    def register(self, name):
+        if name in self.keys():
+            raise ValueError("Registry already contains key: {}".format(name))
+
+        def _register(fn):
+            self.update({name: fn})
+            return fn
+
+        return _register
+
+
+MODELS = Registry()
+BACKBONES = Registry()
+
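The Registry is just a dict with a decorator-style register method, the same mechanism the parsers registry in experiment.py relies on. A small sketch with a hypothetical registry and factory (names invented for illustration):

from saic_depth_completion.utils.registry import Registry

BLOCKS = Registry()                  # hypothetical registry for this sketch

@BLOCKS.register("spade")
def build_spade(x_ch, y_ch):
    from saic_depth_completion.ops.spade import SPADE
    return SPADE(x_ch, y_ch)

block = BLOCKS["spade"](64, 32)      # plain dict lookup, then call the factory
# registering "spade" a second time would raise ValueError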
diff --git a/saic_depth_completion/utils/snapshoter.py b/saic_depth_completion/utils/snapshoter.py
new file mode 100644
index 0000000..8f7445f
--- /dev/null
+++ b/saic_depth_completion/utils/snapshoter.py
@@ -0,0 +1,69 @@
+import logging
+import os
+import torch
+
+import numpy as np
+
+class Snapshoter:
+    def __init__(
+        self,
+        model,
+        optimizer=None,
+        scheduler=None,
+        save_dir="",
+        period=10,
+        logger=None,
+    ):
+        self.model = model
+        self.optimizer = optimizer
+        self.scheduler = scheduler
+        self.period = period
+
+        if logger is None:
+            logger = logging.getLogger(__name__)
+        self.logger = logger
+        if not save_dir:
+            self.logger.warning(
+                "Snapshoter's arg 'save_dir' was not passed. Using default value ({}).".format('./')
+            )
+            self.save_dir = "./"
+        else:
+            self.save_dir = save_dir
+
+    def save(self, fname, **kwargs):
+
+        data = dict()
+        data["model"] = self.model.state_dict()
+        data["optimizer"] = self.optimizer.state_dict() if self.optimizer else None
+        data["scheduler"] = self.scheduler.state_dict() if self.scheduler else None
+
+        data.update(kwargs)
+
+        save_file = os.path.join(self.save_dir, "{}.pth".format(fname))
+        self.logger.info("Saving snapshot into {}".format(save_file))
+        torch.save(data, save_file)
+
+
+    def load(self, fname, model_only=False):
+        if os.path.exists(fname):
+            path = fname
+        elif os.path.exists(os.path.join(self.save_dir, fname)):
+            path = os.path.join(self.save_dir, fname)
+        else:
+            self.logger.info("No snapshot found. Initializing model from scratch")
+            return
+
+        self.logger.info("Loading snapshot from {}".format(path))
+        snapshot = torch.load(path, map_location=torch.device("cpu"))
+        self.model.load_state_dict(snapshot.pop("model"))
+
+        if model_only: return snapshot
+
+        if snapshot["optimizer"] is not None and self.optimizer:
+            self.logger.info("Loading optimizer from {}".format(path))
+            self.optimizer.load_state_dict(snapshot.pop("optimizer"))
+        if snapshot["scheduler"] is not None and self.scheduler:
+            self.logger.info("Loading scheduler from {}".format(path))
+            self.scheduler.load_state_dict(snapshot.pop("scheduler"))
+
+        return snapshot
\ No newline at end of file
diff --git a/saic_depth_completion/utils/tensorboard.py b/saic_depth_completion/utils/tensorboard.py
new file mode 100644
index 0000000..23d14a8
--- /dev/null
+++ b/saic_depth_completion/utils/tensorboard.py
@@ -0,0 +1,27 @@
+import torch.utils.tensorboard as tb
+
+from saic_depth_completion.utils import visualize
+
+class Tensorboard:
+    def __init__(self, tb_dir, max_figures=10):
+        self.tb_dir = tb_dir
+        self.max_figures = max_figures
+    def update(self, metrics_dict, epoch, tag="train"):
+        with tb.SummaryWriter(self.tb_dir) as writer:
+            for k, v in metrics_dict.items():
+                writer.add_scalar(k + "/" + tag, v, epoch)
+    def add_figures(self, batch, post_pred, tag="train", epoch=0):
+        with tb.SummaryWriter(self.tb_dir) as writer:
+            B = batch["color"].shape[0]
+            n_figures = min(B, self.max_figures)  # local bound: a small batch must not shrink max_figures for good
+            for it in range(n_figures):
+                fig = visualize.figure(
+                    batch["color"][it], batch["raw_depth"][it],
+                    batch["mask"][it], batch["gt_depth"][it],
+                    post_pred[it]
+                )
+                writer.add_figure(
+                    figure=fig,
+                    tag="{}_epoch_{}_{}".format(tag, epoch, it),
+                    close=True
+                )
diff --git a/saic_depth_completion/utils/tracker.py b/saic_depth_completion/utils/tracker.py
new file mode 100644
index 0000000..b06d85d
--- /dev/null
+++ b/saic_depth_completion/utils/tracker.py
@@ -0,0 +1,40 @@
+import numpy as np
+
+
+class Tracker:
+    def __init__(
+        self, subset, target, snapshoter, init_state=float("inf"), delay=10, compare_fn=np.less, eps=0.03
+    ):
+        self.subset = subset
+        self.target = target
+        self.snapshoter = snapshoter
+        self.state = init_state
+        self.delay = delay
+        self.compare_fn = compare_fn
+        self.epoch_counter = 0
+        self.eps = eps
+
+    def update(self, subset, metric_state):
+        if subset != self.subset: return
+
+        self.epoch_counter += 1
+
+        if self.epoch_counter < self.delay: return
+
+        # save best model
+        if self.compare_fn(metric_state[self.target], self.state):
+            self.state = metric_state[self.target]
+            self.snapshoter.save("snapshot_{}_{:.4f}".format(self.target, self.state))
+            return
+
+        # save models that land in the epsilon neighborhood of the best value
+        if np.abs(self.state - metric_state[self.target]) < self.eps:
+            self.snapshoter.save("snapshot_eps_{}_{:.4f}".format(self.target, metric_state[self.target]))
+
+
+class ComposedTracker:
+    def __init__(self, trackers):
+        self.trackers = trackers
+    def update(self, subset, metric_state):
+        for tracker in self.trackers:
+            tracker.update(subset, metric_state)
\ No newline at end of file
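How Tracker and Snapshoter interact over epochs, sketched with a stand-in model and made-up metric values (the real flow runs inside the training loop):

import os
import torch
from saic_depth_completion.utils.snapshoter import Snapshoter
from saic_depth_completion.utils.tracker import Tracker

os.makedirs("./snapshots", exist_ok=True)
model = torch.nn.Linear(4, 4)                 # stand-in for MetaModel
snapshoter = Snapshoter(model, save_dir="./snapshots")
tracker = Tracker(subset="val", target="mse", snapshoter=snapshoter,
                  delay=2, eps=0.03)

for mse in [0.50, 0.40, 0.35, 0.36, 0.30]:
    tracker.update("val", {"mse": mse})
# the first update falls inside `delay` and is ignored; 0.40, 0.35 and 0.30
# are new best values and get snapshotted; 0.36 lies within eps=0.03 of the
# best value (0.35) and is saved as an "eps" snapshot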
diff --git a/saic_depth_completion/utils/visualize.py b/saic_depth_completion/utils/visualize.py
new file mode 100644
index 0000000..f531578
--- /dev/null
+++ b/saic_depth_completion/utils/visualize.py
@@ -0,0 +1,36 @@
+import numpy as np
+import matplotlib.pyplot as plt
+
+def figure(color, raw_depth, mask, gt, pred, close=False):
+    fig, axes = plt.subplots(3, 2, figsize=(7, 10))
+
+    color = color.cpu().permute(1, 2, 0)
+    raw_depth = raw_depth.cpu()
+    mask = mask.cpu()
+    gt = gt.cpu()
+    pred = pred.detach().cpu()
+
+    vmin = min(gt.min(), pred.min())
+    vmax = max(gt.max(), pred.max())
+
+
+    axes[0, 0].set_title('RGB')
+    axes[0, 0].imshow((color - color.min()) / (color.max() - color.min()))
+
+    axes[0, 1].set_title('raw_depth')
+    img = axes[0, 1].imshow(raw_depth[0], cmap='RdBu_r')
+    fig.colorbar(img, ax=axes[0, 1])
+
+    axes[1, 0].set_title('mask')
+    axes[1, 0].imshow(mask[0])
+
+    axes[1, 1].set_title('gt')
+    img = axes[1, 1].imshow(gt[0], cmap='RdBu_r', vmin=vmin, vmax=vmax)
+    fig.colorbar(img, ax=axes[1, 1])
+
+    axes[2, 0].axis('off')  # bottom-left cell is intentionally unused
+    axes[2, 1].set_title('pred')
+    img = axes[2, 1].imshow(pred[0], cmap='RdBu_r', vmin=vmin, vmax=vmax)
+    fig.colorbar(img, ax=axes[2, 1])
+    if close: plt.close(fig)
+    return fig
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..babf4de
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,18 @@
+import setuptools
+
+
+setuptools.setup(
+    name="saic-depth-completion",
+    version="0.0.1",
+    description="Experiment tools",
+    # long_description=long_description,
+    # long_description_content_type="text/markdown",
+    # url="https://github.com/pypa/sampleproject",
+    packages=setuptools.find_packages(),
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        # "License :: OSI Approved :: MIT License",
+        "Operating System :: OS Independent",
+    ],
+    python_requires='>=3.6',
+)
\ No newline at end of file
diff --git a/tools/test_net.py b/tools/test_net.py
new file mode 100644
index 0000000..4231df9
--- /dev/null
+++ b/tools/test_net.py
@@ -0,0 +1,125 @@
+import torch
+
+import argparse
+
+from saic_depth_completion.data.datasets.matterport import Matterport
+from saic_depth_completion.data.datasets.nyuv2_test import NyuV2Test
+from saic_depth_completion.engine.inference import inference
+from saic_depth_completion.utils.tensorboard import Tensorboard
+from saic_depth_completion.utils.logger import setup_logger
+from saic_depth_completion.utils.experiment import setup_experiment
+from saic_depth_completion.utils.snapshoter import Snapshoter
+from saic_depth_completion.modeling.meta import MetaModel
+from saic_depth_completion.config import get_default_config
+from saic_depth_completion.data.collate import default_collate
+from saic_depth_completion.metrics import Miss, SSIM, DepthL2Loss, DepthL1Loss, DepthRel
+
+def main():
+    parser = argparse.ArgumentParser(description="Some testing params.")
+
+    parser.add_argument(
+        "--default_cfg", dest="default_cfg", type=str, default="arch0", help="Default config"
+    )
+    parser.add_argument(
+        "--config_file", default="", type=str, metavar="FILE", help="path to config file"
+    )
+    parser.add_argument(
+        "--save_dir", default="", type=str, help="Save dir for predictions"
+    )
+    parser.add_argument(
+        "--weights", default="", type=str, metavar="FILE", help="path to model weights"
+    )
+
+    args = parser.parse_args()
+
+    cfg = get_default_config(args.default_cfg)
+    cfg.merge_from_file(args.config_file)
+    cfg.freeze()
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    model = MetaModel(cfg, device)
+
+    logger = setup_logger()
+
+    snapshoter = Snapshoter(model, logger=logger)
+    snapshoter.load(args.weights)
+
+    metrics = {
+        'mse': DepthL2Loss(),
+        'mae': DepthL1Loss(),
+        'd105': Miss(1.05),
+        'd110': Miss(1.10),
+        'd125_1': Miss(1.25),
+        'd125_2': Miss(1.25**2),
+        'd125_3': Miss(1.25**3),
+        'rel': DepthRel(),
+        'ssim': SSIM(),
+    }
+
+    test_datasets = {
+        "test_matterport": Matterport(split="test"),
+        "official_nyu_test": NyuV2Test(split="official_test"),
+        #
+        # # first
+        # "1gr10pv1pd": NyuV2Test(split="1gr10pv1pd"),
+        # "1gr10pv2pd": NyuV2Test(split="1gr10pv2pd"),
+        # "1gr10pv5pd": NyuV2Test(split="1gr10pv5pd"),
+        #
+        # "1gr25pv1pd": NyuV2Test(split="1gr25pv1pd"),
+        # "1gr25pv2pd": NyuV2Test(split="1gr25pv2pd"),
+        # "1gr25pv5pd": NyuV2Test(split="1gr25pv5pd"),
+        #
+        # "1gr40pv1pd": NyuV2Test(split="1gr40pv1pd"),
+        # "1gr40pv2pd": NyuV2Test(split="1gr40pv2pd"),
+        # "1gr40pv5pd": NyuV2Test(split="1gr40pv5pd"),
+        #
+        # #second
+        # "4gr10pv1pd": NyuV2Test(split="4gr10pv1pd"),
+        # "4gr10pv2pd": NyuV2Test(split="4gr10pv2pd"),
+        # "4gr10pv5pd": NyuV2Test(split="4gr10pv5pd"),
+        #
+        # "4gr25pv1pd": NyuV2Test(split="4gr25pv1pd"),
+        # "4gr25pv2pd": NyuV2Test(split="4gr25pv2pd"),
+        # "4gr25pv5pd": NyuV2Test(split="4gr25pv5pd"),
+        #
+        # "4gr40pv1pd": NyuV2Test(split="4gr40pv1pd"),
+        # "4gr40pv2pd": NyuV2Test(split="4gr40pv2pd"),
+        # "4gr40pv5pd": NyuV2Test(split="4gr40pv5pd"),
+        #
+        # # third
+        # "8gr10pv1pd": NyuV2Test(split="8gr10pv1pd"),
+        # "8gr10pv2pd": NyuV2Test(split="8gr10pv2pd"),
+        # "8gr10pv5pd": NyuV2Test(split="8gr10pv5pd"),
+        #
+        # "8gr25pv1pd": NyuV2Test(split="8gr25pv1pd"),
+        # "8gr25pv2pd": NyuV2Test(split="8gr25pv2pd"),
+        # "8gr25pv5pd": NyuV2Test(split="8gr25pv5pd"),
+        #
+        # "8gr40pv1pd": NyuV2Test(split="8gr40pv1pd"),
+        # "8gr40pv2pd": NyuV2Test(split="8gr40pv2pd"),
+        # "8gr40pv5pd": NyuV2Test(split="8gr40pv5pd"),
+
+    }
+    test_loaders = {
+        k: torch.utils.data.DataLoader(
+            dataset=v,
+            batch_size=1,
+            shuffle=False,
+            num_workers=4,
+            collate_fn=default_collate
+        )
+        for k, v in test_datasets.items()
+    }
+
+    inference(
+        model,
+        test_loaders,
+        save_dir=args.save_dir,
+        logger=logger,
+        metrics=metrics,
+    )
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
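The Miss metrics built above presumably implement the standard delta-threshold score used in depth estimation benchmarks; their code is not part of this diff, so as a reference only, here is a plain-torch sketch of that quantity (whether Miss reports this hit rate or its complement is not visible here):

import torch

def delta_accuracy(pred, gt, threshold=1.25):
    # fraction of valid pixels whose max(pred/gt, gt/pred) is below `threshold`
    valid = gt > 0
    ratio = torch.max(pred[valid] / gt[valid], gt[valid] / pred[valid])
    return (ratio < threshold).float().mean()

pred = torch.rand(1, 240, 320) + 0.5
gt = torch.rand(1, 240, 320) + 0.5
print(delta_accuracy(pred, gt, 1.25), delta_accuracy(pred, gt, 1.25 ** 2))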
diff --git a/tools/train_matterport.py b/tools/train_matterport.py
new file mode 100644
index 0000000..7fd038d
--- /dev/null
+++ b/tools/train_matterport.py
@@ -0,0 +1,120 @@
+import torch
+
+import argparse
+
+import numpy as np
+
+from saic_depth_completion.data.datasets.matterport import Matterport
+from saic_depth_completion.engine.train import train
+from saic_depth_completion.utils.tensorboard import Tensorboard
+from saic_depth_completion.utils.logger import setup_logger
+from saic_depth_completion.utils.experiment import setup_experiment
+from saic_depth_completion.utils.snapshoter import Snapshoter
+from saic_depth_completion.utils.tracker import ComposedTracker, Tracker
+from saic_depth_completion.modeling.meta import MetaModel
+from saic_depth_completion.config import get_default_config
+from saic_depth_completion.data.collate import default_collate
+from saic_depth_completion.metrics import Miss, SSIM, DepthL2Loss, DepthL1Loss, DepthRel
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Some training params.")
+    parser.add_argument(
+        "--debug", dest="debug", action="store_true", help="Setup debug mode"
+    )
+    parser.add_argument(
+        "--postfix", dest="postfix", type=str, default="", help="Postfix for experiment's name"
+    )
+    parser.add_argument(
+        "--default_cfg", dest="default_cfg", type=str, default="arch0", help="Default config"
+    )
+    parser.add_argument(
+        "--config_file", default="", type=str, metavar="FILE", help="path to config file"
+    )
+    parser.add_argument(
+        "--snapshot_period", default=10, type=int, help="Snapshot the model once every this many epochs"
+    )
+    args = parser.parse_args()
+
+    cfg = get_default_config(args.default_cfg)
+    cfg.merge_from_file(args.config_file)
+    cfg.freeze()
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    model = MetaModel(cfg, device)
+
+    logger = setup_logger()
+    experiment = setup_experiment(
+        cfg, args.config_file, logger=logger, training=True, debug=args.debug, postfix=args.postfix
+    )
+
+
+    optimizer = torch.optim.Adam(
+        params=model.parameters(), lr=cfg.train.lr
+    )
+    if not args.debug:
+        snapshoter = Snapshoter(
+            model, optimizer, period=args.snapshot_period, logger=logger, save_dir=experiment.snapshot_dir
+        )
+        tensorboard = Tensorboard(experiment.tensorboard_dir)
+        tracker = ComposedTracker([
+            Tracker(subset="test_matterport", target="mse", snapshoter=snapshoter, eps=0.01),
+            Tracker(subset="val_matterport", target="mse", snapshoter=snapshoter, eps=0.01),
+        ])
+    else:
+        snapshoter, tensorboard, tracker = None, None, None
+
+
+    metrics = {
+        'mse': DepthL2Loss(),
+        'mae': DepthL1Loss(),
+        'd105': Miss(1.05),
+        'd110': Miss(1.10),
+        'd125_1': Miss(1.25),
+        'd125_2': Miss(1.25**2),
+        'd125_3': Miss(1.25**3),
+        'rel': DepthRel(),
+        'ssim': SSIM(),
+    }
+
+    train_dataset = Matterport(split="train")
+    train_loader = torch.utils.data.DataLoader(
+        dataset=train_dataset,
+        batch_size=cfg.train.batch_size,
+        shuffle=True,
+        num_workers=4,
+        collate_fn=default_collate
+    )
+
+    val_datasets = {
+        "val_matterport": Matterport(split="val"),
+        "test_matterport": Matterport(split="test"),
+    }
+    val_loaders = {
+        k: torch.utils.data.DataLoader(
+            dataset=v,
+            batch_size=cfg.test.batch_size,
+            shuffle=False,
+            num_workers=4,
+            collate_fn=default_collate
+        )
+        for k, v in val_datasets.items()
+    }
+
+    train(
+        model,
+        train_loader,
+        val_loaders=val_loaders,
+        optimizer=optimizer,
+        snapshoter=snapshoter,
+        epochs=200,
+        logger=logger,
+        metrics=metrics,
+        tensorboard=tensorboard,
+        tracker=tracker
+    )
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
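train() itself is not part of this diff; as a rough picture of how the utilities above are meant to fit together each validation round, a hedged sketch (the batch-dict call convention of MetaModel and the "gt_depth" key are assumptions inferred from tensorboard.py):

from saic_depth_completion.utils.meter import AggregatedMeter

def validate(model, val_loaders, metrics, tensorboard, tracker, epoch):
    agg = AggregatedMeter(metrics, maxlen=20)
    for subset, loader in val_loaders.items():
        agg.reset()
        for batch in loader:
            pred = model(batch)                    # assumed forward contract
            agg.update(pred, batch["gt_depth"])
        state = {k: m.global_avg for k, m in agg.meters.items()}
        tensorboard.update(state, epoch, tag=subset)   # scalars per subset
        tracker.update(subset, state)                  # maybe snapshot best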