From 5d4a7ed40d99aad8fca97fad4750be102522df1e Mon Sep 17 00:00:00 2001 From: JacopoLungo Date: Sun, 23 Jun 2024 12:52:30 +0200 Subject: [PATCH] Initial commit --- .gitignore | 210 +++++ configs/custom_cfg.yaml | 2 + configs/default_cfg.yaml | 56 ++ configs/trees_cfg.yaml | 19 + metadata/.gitignore | 3 + models/.gitignore | 3 + output/.gitignore | 3 + requirements.txt | 290 +++++++ setup.py | 8 + shell/seg_tile.sh | 4 + src/maxarseg/ESAM_segment/segment.py | 110 +++ src/maxarseg/ESAM_segment/segment_utils.py | 326 ++++++++ .../SAM_segment/build_segmentation.py | 111 +++ src/maxarseg/SAM_segment/road_segmentation.py | 196 +++++ .../SAM_segment/segment_from_boxes.py | 54 ++ src/maxarseg/__init__.py | 0 src/maxarseg/assemble/build.py | 51 ++ src/maxarseg/assemble/delimiters.py | 102 +++ src/maxarseg/assemble/filter.py | 41 + src/maxarseg/assemble/g_build_utils.py | 111 +++ src/maxarseg/assemble/gen_gdf.py | 123 +++ src/maxarseg/assemble/holders.py | 768 ++++++++++++++++++ src/maxarseg/assemble/names.py | 61 ++ src/maxarseg/configs.py | 214 +++++ src/maxarseg/detect/detect.py | 154 ++++ src/maxarseg/detect/detect_utils.py | 118 +++ src/maxarseg/efficient_sam/__init__.py | 7 + .../efficient_sam/build_efficient_sam.py | 26 + src/maxarseg/efficient_sam/efficient_sam.py | 310 +++++++ .../efficient_sam/efficient_sam_decoder.py | 315 +++++++ .../efficient_sam/efficient_sam_encoder.py | 257 ++++++ src/maxarseg/efficient_sam/mlp.py | 29 + .../efficient_sam/two_way_transformer.py | 264 ++++++ src/maxarseg/explore_folders.py | 140 ++++ src/maxarseg/geo_datasets/geoDatasets.py | 373 +++++++++ src/maxarseg/main_noTGeo.py | 51 ++ src/maxarseg/main_seg_event_w_config.py | 46 ++ .../main_seg_event_w_config_partitioned.py | 82 ++ src/maxarseg/main_seg_single_tile.py | 107 +++ src/maxarseg/main_seg_tile.py | 92 +++ src/maxarseg/main_seg_tile_glbl_detections.py | 106 +++ src/maxarseg/main_seg_tile_w_config.py | 52 ++ src/maxarseg/output.py | 114 +++ src/maxarseg/plotting_utils.py | 94 +++ src/maxarseg/polygonize.py | 424 ++++++++++ src/maxarseg/samplers/samplers.py | 430 ++++++++++ src/maxarseg/samplers/samplers_utils.py | 256 ++++++ .../scripts/count_build_single_event.py | 180 ++++ .../scripts/count_buildings_and_roads.py | 43 + src/maxarseg/scripts/count_builds.py | 186 +++++ src/maxarseg/scripts/downloadMaxar.py | 86 ++ src/maxarseg/scripts/downloadRoads.py | 28 + src/maxarseg/scripts/make-gis-friendly.py | 29 + 53 files changed, 7265 insertions(+) create mode 100644 .gitignore create mode 100644 configs/custom_cfg.yaml create mode 100644 configs/default_cfg.yaml create mode 100644 configs/trees_cfg.yaml create mode 100644 metadata/.gitignore create mode 100644 models/.gitignore create mode 100644 output/.gitignore create mode 100644 requirements.txt create mode 100644 setup.py create mode 100644 shell/seg_tile.sh create mode 100644 src/maxarseg/ESAM_segment/segment.py create mode 100644 src/maxarseg/ESAM_segment/segment_utils.py create mode 100644 src/maxarseg/SAM_segment/build_segmentation.py create mode 100644 src/maxarseg/SAM_segment/road_segmentation.py create mode 100644 src/maxarseg/SAM_segment/segment_from_boxes.py create mode 100644 src/maxarseg/__init__.py create mode 100644 src/maxarseg/assemble/build.py create mode 100644 src/maxarseg/assemble/delimiters.py create mode 100644 src/maxarseg/assemble/filter.py create mode 100644 src/maxarseg/assemble/g_build_utils.py create mode 100644 src/maxarseg/assemble/gen_gdf.py create mode 100644 src/maxarseg/assemble/holders.py create mode 100644 
src/maxarseg/assemble/names.py create mode 100644 src/maxarseg/configs.py create mode 100644 src/maxarseg/detect/detect.py create mode 100644 src/maxarseg/detect/detect_utils.py create mode 100644 src/maxarseg/efficient_sam/__init__.py create mode 100644 src/maxarseg/efficient_sam/build_efficient_sam.py create mode 100644 src/maxarseg/efficient_sam/efficient_sam.py create mode 100644 src/maxarseg/efficient_sam/efficient_sam_decoder.py create mode 100644 src/maxarseg/efficient_sam/efficient_sam_encoder.py create mode 100644 src/maxarseg/efficient_sam/mlp.py create mode 100644 src/maxarseg/efficient_sam/two_way_transformer.py create mode 100644 src/maxarseg/explore_folders.py create mode 100644 src/maxarseg/geo_datasets/geoDatasets.py create mode 100644 src/maxarseg/main_noTGeo.py create mode 100644 src/maxarseg/main_seg_event_w_config.py create mode 100644 src/maxarseg/main_seg_event_w_config_partitioned.py create mode 100644 src/maxarseg/main_seg_single_tile.py create mode 100644 src/maxarseg/main_seg_tile.py create mode 100644 src/maxarseg/main_seg_tile_glbl_detections.py create mode 100644 src/maxarseg/main_seg_tile_w_config.py create mode 100644 src/maxarseg/output.py create mode 100644 src/maxarseg/plotting_utils.py create mode 100644 src/maxarseg/polygonize.py create mode 100644 src/maxarseg/samplers/samplers.py create mode 100644 src/maxarseg/samplers/samplers_utils.py create mode 100644 src/maxarseg/scripts/count_build_single_event.py create mode 100644 src/maxarseg/scripts/count_buildings_and_roads.py create mode 100644 src/maxarseg/scripts/count_builds.py create mode 100644 src/maxarseg/scripts/downloadMaxar.py create mode 100644 src/maxarseg/scripts/downloadRoads.py create mode 100644 src/maxarseg/scripts/make-gis-friendly.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b79be79 --- /dev/null +++ b/.gitignore @@ -0,0 +1,210 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. 
+#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# IDE and editors +.idea/ +.vscode/ +outputs/ +.vscode/tasks.json + +.python-version + +#gpu memory snapshots +*.pickle +data + +#scalene profiler outputs +scalene/ + +#user dependent jobs +j_jobs/ +jobs_mosaic/ +jobs/ + +#test files +notebooks/to_delete.ipynb +src/maxarseg/geo_datasets/datasets.py + +*.geojson + +#shape files +*.shp +*.shx +*.dbf +*.prj +*.cpg +*.qix +*.qmd +*.sbn +*.sbx +*.xml +*.lock + +#geopacakge files +*.gpkg + +*.csv +*.npy +*.tif +*.parquet +*.paquet + +stats/ +labelled/ +notebooks/ + +morocco_builds/ + +#recycle bin +zz_recBin/ diff --git a/configs/custom_cfg.yaml b/configs/custom_cfg.yaml new file mode 100644 index 0000000..cac9255 --- /dev/null +++ b/configs/custom_cfg.yaml @@ -0,0 +1,2 @@ +event: + ix: 5 \ No newline at end of file diff --git a/configs/default_cfg.yaml b/configs/default_cfg.yaml new file mode 100644 index 0000000..1c1a050 --- /dev/null +++ b/configs/default_cfg.yaml @@ -0,0 +1,56 @@ +event: + ix: None + when: "pre" + +models: + gd: + bs: 1 + size: 600 + stride: 400 + device: "cuda:0" + root_path: "./models/GDINO" + config_file_path: "./models/GDINO/configs/GroundingDINO_SwinT_OGC.py" + weight_path: "./models/GDINO/weights/groundingdino_swint_ogc.pth" + text_prompt: "bush" + box_threshold: 0.15 + text_threshold: 0.30 + + df: + bs: 16 + size: 600 + patch_overlap: 0.25 + device: + - 0 #to indicate cuda:0 input [int(0)], to not specify use 'auto' + box_threshold: 0.1 + + esam: + bs: 1 + num_parall_queries: 200 + size: 1024 + stride: 768 + device: "cuda:0" + root_path: "./models/EfficientSAM" + +detection: + trees: + use_GD: true + use_DF: false + nms_threshold: 0.5 + min_ratio_GD_boxes_edges: 0.5 + max_area_boxes_mt2: 6000 + perc_reduce_tree_boxes: 0 + + buildings: + ext_mt_build_box: 0 + +segmentation: + general: + clean_mask: true + rmv_holes_area_th: 80 + rmv_small_obj_area_th: 80 + + roads: + road_width_mt: 5 + +output: + out_dir_root: "./output" diff --git a/configs/trees_cfg.yaml b/configs/trees_cfg.yaml new file mode 100644 index 0000000..a5a50c3 --- /dev/null +++ b/configs/trees_cfg.yaml @@ -0,0 +1,19 @@ +models: + df: + bs: 16 + size: 600 + patch_overlap: 0.25 + device: + - 0 #to indicate cuda:0 input [int(0)], to not specify use 'auto' + box_threshold: 0.1 + + gd: + size: 1024 + stride: 1024 + box_threshold: 0.12 + +detection: + trees: + 
nms_threshold: 0.5 + min_ratio_GD_boxes_edges: 0.5 + max_area_boxes_mt2: 7000 # best in range 6000-8000 \ No newline at end of file diff --git a/metadata/.gitignore b/metadata/.gitignore new file mode 100644 index 0000000..317b32c --- /dev/null +++ b/metadata/.gitignore @@ -0,0 +1,3 @@ +# ignore all files except for this one +* +!.gitignore \ No newline at end of file diff --git a/models/.gitignore b/models/.gitignore new file mode 100644 index 0000000..317b32c --- /dev/null +++ b/models/.gitignore @@ -0,0 +1,3 @@ +# ignore all files except for this one +* +!.gitignore \ No newline at end of file diff --git a/output/.gitignore b/output/.gitignore new file mode 100644 index 0000000..317b32c --- /dev/null +++ b/output/.gitignore @@ -0,0 +1,3 @@ +# ignore all files except for this one +* +!.gitignore \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..30c7eba --- /dev/null +++ b/requirements.txt @@ -0,0 +1,290 @@ +addict==2.4.0 +aenum==3.1.15 +affine==2.4.0 +aiohttp==3.9.3 +aiosignal==1.3.1 +alabaster==0.7.16 +albumentations==1.4.2 +annotated-types==0.6.0 +antlr4-python3-runtime==4.9.3 +anyio==4.0.0 +argon2-cffi==23.1.0 +argon2-cffi-bindings==21.2.0 +arrow==1.3.0 +asttokens==2.4.1 +async-lru==2.0.4 +async-timeout==4.0.3 +attrs==23.1.0 +austin-dist==3.6.0 +Babel==2.13.1 +backcall==0.2.0 +beautifulsoup4==4.12.2 +bitsandbytes==0.41.0 +bleach==6.1.0 +bqplot==0.12.42 +branca==0.6.0 +cachetools==5.3.2 +certifi==2023.7.22 +cffi==1.16.0 +charset-normalizer==3.3.1 +click==8.1.7 +click-plugins==1.1.1 +cligj==0.7.2 +cogeo-mosaic==7.1.0 +color-operations==0.1.1 +coloredlogs==15.0.1 +colour==0.1.5 +comm==0.1.4 +commonmark==0.9.1 +contourpy==1.1.1 +cycler==0.12.1 +debugpy==1.8.0 +decorator==5.1.1 +deepforest==1.3.3 +defusedxml==0.7.1 +docopt==0.6.2 +docstring_parser==0.16 +docutils==0.20.1 +efficientnet_pytorch==0.7.1 +einops==0.7.0 +exceptiongroup==1.1.3 +executing==2.0.0 +fastjsonschema==2.18.1 +filelock==3.12.4 +fiona==1.9.5 +flatbuffers==24.3.25 +folium==0.14.0 +fonttools==4.43.1 +fqdn==1.5.1 +frozenlist==1.4.1 +fsspec==2024.3.0 +gdown==4.7.1 +geojson==3.0.1 +geopandas==0.14.0 +groundingdino-py==0.4.0 +h11==0.14.0 +httpcore==0.18.0 +httpx==0.25.0 +huggingface-hub==0.21.4 +humanfriendly==10.0 +hydra-core==1.3.2 +idna==3.4 +igraph==0.11.4 +imagecodecs==2024.1.1 +imageio==2.34.0 +imagesize==1.4.1 +importlib_metadata==7.0.2 +importlib_resources==6.3.1 +ipyevents==2.0.2 +ipyfilechooser==0.6.0 +ipykernel==6.26.0 +ipyleaflet==0.17.4 +ipython==8.12.3 +ipython-genutils==0.2.0 +ipytree==0.2.2 +ipywidgets==8.1.1 +isoduration==20.11.0 +jedi==0.19.1 +Jinja2==3.1.2 +joblib==1.3.2 +json5==0.9.14 +jsonargparse==4.27.6 +jsonpointer==2.4 +jsonschema==4.19.1 +jsonschema-specifications==2023.7.1 +jupyter==1.0.0 +jupyter-console==6.6.3 +jupyter-events==0.8.0 +jupyter-lsp==2.2.0 +jupyter_client==8.5.0 +jupyter_core==5.4.0 +jupyter_server==2.9.1 +jupyter_server_terminals==0.4.4 +jupyterlab==4.0.7 +jupyterlab-pygments==0.2.2 +jupyterlab-widgets==3.0.9 +jupyterlab_server==2.25.0 +kiwisolver==1.4.5 +kornia==0.7.2 +kornia_rs==0.1.2 +lazy_loader==0.3 +leafmap==0.27.1 +lightly==1.4.7 +lightly-utils==0.0.2 +lightning==2.2.1 +lightning-utilities==0.11.0 +loguru==0.7.2 +markdown-it-py==3.0.0 +MarkupSafe==2.1.3 +matplotlib==3.8.0 +matplotlib-inline==0.1.6 +mdurl==0.1.2 +mistune==3.0.2 +morecantile==5.0.0 +mpmath==1.3.0 +multidict==6.0.5 +multimethod==1.11.2 +munch==4.0.0 +nbclient==0.8.0 +nbconvert==7.16.3 +nbformat==5.9.2 +nest-asyncio==1.5.8 +networkx==3.2.1 
+notebook==7.0.6 +notebook_shim==0.2.3 +numexpr==2.8.7 +numpy==1.26.1 +nvidia-cublas-cu11==11.11.3.6 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu11==11.8.87 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu11==11.8.89 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu11==11.8.89 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu11==8.7.0.84 +nvidia-cudnn-cu12==8.9.2.26 +nvidia-cufft-cu11==10.9.0.58 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu11==10.3.0.86 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu11==11.4.1.48 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu11==11.7.5.86 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu11==2.19.3 +nvidia-nccl-cu12==2.19.3 +nvidia-nvjitlink-cu12==12.4.99 +nvidia-nvtx-cu11==11.8.86 +nvidia-nvtx-cu12==12.1.105 +omegaconf==2.3.0 +onnx==1.16.0 +onnxruntime==1.17.1 +onnxsim==0.4.36 +opencv-python==4.9.0.80 +opencv-python-headless==4.9.0.80 +overrides==7.4.0 +packaging==23.2 +pandas==2.1.1 +pandocfilters==1.5.0 +parso==0.8.3 +pexpect==4.8.0 +pickleshare==0.7.5 +Pillow==10.1.0 +pipdeptree==2.16.2 +platformdirs==3.11.0 +pretrainedmodels==0.7.4 +progressbar2==4.4.2 +prometheus-client==0.17.1 +prompt-toolkit==3.0.39 +protobuf==5.26.0 +psutil==5.9.6 +ptyprocess==0.7.0 +pure-eval==0.2.2 +pyarrow==15.0.2 +pycocotools==2.0.7 +pycparser==2.21 +pydantic==2.6.4 +pydantic-settings==2.2.1 +pydantic_core==2.16.3 +Pygments==2.16.1 +pyparsing==3.1.1 +pyproj==3.6.1 +pyquadkey2==0.2.2 +pyshp==2.3.1 +PySocks==1.7.1 +pystac==1.9.0 +pystac-client==0.7.5 +python-box==7.1.1 +python-dateutil==2.8.2 +python-dotenv==1.0.0 +python-json-logger==2.0.7 +python-utils==3.8.2 +pytorch-lightning==2.2.1 +pytz==2023.3.post1 +PyYAML==6.0.1 +pyzmq==25.1.1 +qtconsole==5.4.4 +QtPy==2.4.1 +rasterio==1.3.9 +recommonmark==0.7.1 +referencing==0.30.2 +regex==2023.12.25 +requests==2.31.0 +reverse_geocoder==1.5.1 +rfc3339-validator==0.1.4 +rfc3986-validator==0.1.1 +rich==13.7.1 +rio-tiler==6.2.4 +rpds-py==0.10.6 +Rtree==1.2.0 +ruamel.yaml==0.18.6 +ruamel.yaml.clib==0.2.8 +safetensors==0.4.2 +scikit-image==0.22.0 +scikit-learn==1.4.1.post1 +scipy==1.12.0 +scooby==0.9.2 +segment_anything @ git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588 +segmentation-models-pytorch==0.3.3 +Send2Trash==1.8.2 +shapely==2.0.2 +six==1.16.0 +slidingwindow==0.0.14 +sniffio==1.3.0 +snowballstemmer==2.2.0 +snuggs==1.4.7 +soupsieve==2.5 +Sphinx==7.2.6 +sphinxcontrib-applehelp==1.0.8 +sphinxcontrib-devhelp==1.0.6 +sphinxcontrib-htmlhelp==2.0.5 +sphinxcontrib-jsmath==1.0.1 +sphinxcontrib-qthelp==1.0.7 +sphinxcontrib-serializinghtml==1.1.10 +stack-data==0.6.3 +supermorecado==0.1.2 +supervision==0.6.0 +sympy==1.12 +tdqm==0.0.1 +tensorboardX==2.6.2.2 +terminado==0.17.1 +texttable==1.7.0 +threadpoolctl==3.4.0 +tifffile==2024.2.12 +timm==0.9.2 +tinycss2==1.2.1 +tokenizers==0.15.2 +tomli==2.0.1 +torch==2.2.1 +torchaudio==2.2.1 +torchgeo==0.5.2 +torchmetrics==1.3.2 +torchpack @ git+https://github.com/zhijian-liu/torchpack.git@3a5a9f7ac665444e1eb45942ee3f8fc7ffbd84e5 +torchprofile==0.0.4 +torchvision==0.17.1 +tornado==6.3.3 +tqdm==4.66.1 +traitlets==5.12.0 +traittypes==0.2.1 +transformers==4.38.2 +triton==2.2.0 +types-python-dateutil==2.8.19.14 +typeshed_client==2.5.1 +typing_extensions==4.10.0 +tzdata==2023.3 +uri-template==1.3.0 +urllib3==2.0.7 +wcwidth==0.2.8 +webcolors==1.13 +webencodings==0.5.1 +websocket-client==1.6.4 +whitebox==2.3.1 +whiteboxgui==2.3.0 +widgetsnbextension==4.0.9 +xmltodict==0.13.0 +xyzservices==2023.10.1 +yapf==0.40.2 +yarg==0.1.9 
+yarl==1.9.4 +zipp==3.18.1 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..f3993d8 --- /dev/null +++ b/setup.py @@ -0,0 +1,8 @@ +from setuptools import setup, find_packages + +setup( + name='maxarseg', + version='0.1', + packages=find_packages(where='src'), + package_dir={'': 'src'}, +) \ No newline at end of file diff --git a/shell/seg_tile.sh b/shell/seg_tile.sh new file mode 100644 index 0000000..52630df --- /dev/null +++ b/shell/seg_tile.sh @@ -0,0 +1,4 @@ +python -B -O -m maxarseg.main_seg_event_w_config \ + --out_dir_root "./output/test" \ + --config "./configs/trees_cfg.yaml" \ + --event_ix 1 \ No newline at end of file diff --git a/src/maxarseg/ESAM_segment/segment.py b/src/maxarseg/ESAM_segment/segment.py new file mode 100644 index 0000000..c3d3732 --- /dev/null +++ b/src/maxarseg/ESAM_segment/segment.py @@ -0,0 +1,110 @@ +import numpy as np +import torch + + +def ESAM_from_inputs(original_img_tsr: torch.tensor, #b, c, h, w + input_points: torch.tensor, #b, max_queries, 2, 2 + input_labels: torch.tensor, #b, max_queries, 2 + efficient_sam, + num_parall_queries: int = 50, + device = 'cpu', + empty_cuda_cache = True): + + img_b_tsr = original_img_tsr.div(255) + batch_size, _, input_h, input_w = img_b_tsr.shape + + img_b_tsr = img_b_tsr.to(device) + input_points = input_points.to(device) + input_labels = input_labels.to(device) + + image_embeddings = efficient_sam.get_image_embeddings(img_b_tsr) + + stop = input_points.shape[1] + if stop > 0: #if there is at least a query in a single image in the batch + for i in range(0, stop , num_parall_queries): + start_idx = i + end_idx = min(i + num_parall_queries, stop) + #TODO: check if multimask_output False is faster + predicted_logits, predicted_iou = efficient_sam.predict_masks(image_embeddings, + input_points[:, start_idx: end_idx], + input_labels[:, start_idx: end_idx], + multimask_output=True, + input_h = input_h, + input_w = input_w, + output_h=input_h, + output_w=input_w) + + if i == 0: + #print('predicetd_logits:', predicted_logits.shape) + np_complete_masks = predicted_logits[:,:,0].cpu().detach().numpy() + else: + np_complete_masks = np.concatenate((np_complete_masks, predicted_logits[:,:,0].cpu().detach().numpy()), axis=1) + #TODO: check if empty_cuda_cache False is faster + if empty_cuda_cache: + del predicted_logits, predicted_iou + torch.cuda.empty_cache() + else: #if there are no queries (in any image in the batch) + np_complete_masks = np.ones((batch_size, 0, input_h, input_w)) * float('-inf') #equal to set False on all the mask + + return np_complete_masks #shape (b, masks, h, w) + +def ESAM_from_inputs_fast(original_img_tsr: torch.Tensor, #b, c, h, w + input_points: torch.Tensor, #b, max_queries, 2, 2 + input_labels: torch.Tensor, #b, max_queries, 2 + efficient_sam, + num_tree_boxes, #(b, 1) + num_parall_queries: int = 5, + device = 'cpu'): + + num_tree_boxes = int(num_tree_boxes[0]) + + original_img_tsr = original_img_tsr.div(255) + batch_size, _, input_h, input_w = original_img_tsr.shape + + original_img_tsr = original_img_tsr.to(device) + input_points = input_points.to(device) + input_labels = input_labels.to(device) + with torch.no_grad(): + image_embeddings = efficient_sam.get_image_embeddings(original_img_tsr) + + tree_build_mask = torch.full((2, input_h, input_w), float('-inf'), dtype = torch.float32, device = device) + num_batch_tree_only = num_tree_boxes // num_parall_queries + trees_in_mixed_batch = round(num_parall_queries * (num_tree_boxes/num_parall_queries - num_tree_boxes // 
num_parall_queries))
+
+    stop = input_points.shape[1]
+
+    for y, i in enumerate(range(0, stop, num_parall_queries)):
+        start_idx = i
+        end_idx = min(i + num_parall_queries, stop)
+
+        with torch.no_grad():
+            predicted_logits, predicted_iou = efficient_sam.predict_masks(image_embeddings,
+                                                input_points[:, start_idx: end_idx],
+                                                input_labels[:, start_idx: end_idx],
+                                                multimask_output=True,
+                                                input_h = input_h,
+                                                input_w = input_w,
+                                                output_h=input_h,
+                                                output_w=input_w)
+
+        masks = predicted_logits[0,:,0]#.cpu().detach().numpy() # (num_img, prompt, multi, h, w) -> (max_queries, h, w)
+
+        # poly
+        # append to geodataframe
+
+        if y < num_batch_tree_only or input_points[0, start_idx: end_idx].shape[0] == trees_in_mixed_batch: #only trees
+            tree_build_mask[0] = torch.max(tree_build_mask[0], torch.max(masks, dim=0).values)
+        elif y > num_batch_tree_only or trees_in_mixed_batch == 0: #only build
+            tree_build_mask[1] = torch.max(tree_build_mask[1], torch.max(masks, dim=0).values)
+        else: #trees and build
+            tree_build_mask[0] = torch.max(tree_build_mask[0], torch.max(masks[:trees_in_mixed_batch], dim=0).values)
+            tree_build_mask[1] = torch.max(tree_build_mask[1], torch.max(masks[trees_in_mixed_batch:], dim=0).values)
+
+    # map the remaining -inf values to 0 TODO: move outside
+    tree_build_mask[0] = torch.where(tree_build_mask[0] == float('-inf'), torch.tensor(0, dtype = torch.float32, device = device), tree_build_mask[0])
+    tree_build_mask[1] = torch.where(tree_build_mask[1] == float('-inf'), torch.tensor(0, dtype = torch.float32, device = device), tree_build_mask[1])
+
+    tree_build_mask = tree_build_mask.cpu().detach().numpy()
+    return tree_build_mask #shape (2, h, w): stacked tree and building masks
\ No newline at end of file
diff --git a/src/maxarseg/ESAM_segment/segment_utils.py b/src/maxarseg/ESAM_segment/segment_utils.py
new file mode 100644
index 0000000..a4d0291
--- /dev/null
+++ b/src/maxarseg/ESAM_segment/segment_utils.py
@@ -0,0 +1,326 @@
+import numpy as np
+from typing import List, Tuple, Union
+from skimage import morphology
+from shapely.geometry import Polygon
+from scipy.signal.windows import tukey
+
+
+def get_input_pts_and_lbs(tree_boxes_b: List, #list of arrays of shape (query_img_x, 4)
+                          building_boxes_b: List,
+                          max_detect: int):
+    input_lbs = []
+    input_pts = []
+    pts_pad_value = -10
+    lbs_pad_value = 0
+    for tree_detec, build_detec in zip(tree_boxes_b, building_boxes_b):
+        tree_build_detect = np.concatenate((tree_detec, build_detec)) #(query_img_x, 4)
+        num_query_img_x = tree_build_detect.shape[0]
+        lbs = np.array([[2,3]] * num_query_img_x).reshape(-1,2) #(query_img_x, 2)
+
+        pad_len = max_detect - num_query_img_x
+        pad_width = ((0,pad_len),(0, 0))
+        padded_tree_build_detect = np.pad(tree_build_detect, pad_width, constant_values=pts_pad_value)
+        img_input_pts = np.expand_dims(padded_tree_build_detect, axis = 0).reshape(-1,2,2) # (max_queries, 2, 2)
+        input_pts.append(img_input_pts)
+
+        padded_lbs = np.pad(lbs, pad_width, constant_values = lbs_pad_value)# (max_queries, 2)
+        input_lbs.append(padded_lbs)
+
+    return np.array(input_pts), np.array(input_lbs) # (batch_size, max_queries, 2, 2), (batch_size, max_queries, 2)
+
+
+def discern(all_mask_b: np.array, num_trees4img:np.array, num_build4img: np.array):
+    """
+    Discern the masks of the trees, buildings and padding from the all_mask_b array
+    Inputs:
+        all_mask_b: np.array of shape (b, masks, h, w)
+        num_trees4img: np.array of shape (b,)
+        num_build4img: np.array of shape (b,)
+
+    Outputs:
+        tree_mask_b: np.array of shape (b, h, w)
+        build_mask_b: np.array
of shape (b, h, w) + pad_mask_b: np.array of shape (b, h, w) + """ + h, w = all_mask_b.shape[2:] + tree_mask_b = build_mask_b = pad_mask_b = np.full((1, h, w), False) + + all_mask_b = np.greater_equal(all_mask_b, 0) #from logits to bool + + for all_mask, tree_ix, build_ix in zip (all_mask_b, num_trees4img, num_build4img): + #all_mask.shape = (num_mask, h, w) + tree_mask = all_mask[ : tree_ix].any(axis=0) #(h, w) + tree_mask_b = np.concatenate((tree_mask_b, tree_mask[None, ...]), axis=0) #(b, h, w) + + build_mask = all_mask[tree_ix : (tree_ix + build_ix)].any(axis=0) + build_mask_b = np.concatenate((build_mask_b, build_mask[None, ...]), axis=0) + + pad_mask = all_mask[(tree_ix + build_ix) : ].any(axis=0) + pad_mask_b = np.concatenate((pad_mask_b, pad_mask[None, ...]), axis=0) + + return tree_mask_b[1:], build_mask_b[1:], pad_mask_b[1:] #all (b, h, w), slice out the first element + +def discern_mode(all_mask_b: np.array, num_trees4img:np.array, num_build4img: np.array, mode: str = 'bchw'): + """ + Discern the masks of the trees, buildings and padding from the all_mask_b array + Inputs: + all_mask_b: np.array of shape (b, masks, h, w) + num_trees4img: np.array of shape (b,) + num_build4img: np.array of shape (b,) + mode: 'bchw' or 'cbhw'. To specify the output dimension. [batch channel height width] or [channel batch height width] + + Outputs: + out: np.array of shape (b, c, h, w) or (c, b, h, w) + """ + h, w = all_mask_b.shape[2:] + tree_mask_b = build_mask_b = pad_mask_b = np.full((1, h, w), False) + + all_mask_b = np.greater_equal(all_mask_b, 0) #from logits to bool + + + + for all_mask, tree_ix, build_ix in zip (all_mask_b, num_trees4img, num_build4img): + #all_mask.shape = (num_mask, h, w) + tree_mask = all_mask[ : tree_ix].any(axis=0) # Squash the tree masks. Get shape (h, w) + tree_mask_b = np.concatenate((tree_mask_b, tree_mask[None, ...]), axis=0) #(b, h, w) + + build_mask = all_mask[tree_ix : (tree_ix + build_ix)].any(axis=0) # Squash the build masks. Get shape (h, w) + build_mask_b = np.concatenate((build_mask_b, build_mask[None, ...]), axis=0) + + pad_mask = all_mask[(tree_ix + build_ix) : ].any(axis=0) + pad_mask_b = np.concatenate((pad_mask_b, pad_mask[None, ...]), axis=0) + + if mode == 'bchw': + out = np.stack((tree_mask_b[1:], build_mask_b[1:], pad_mask_b[1:]), axis=1) # (b, c, h, w) , slice out the first element of dim 1 + elif mode == 'cbhw': + out = np.stack((tree_mask_b[1:], build_mask_b[1:], pad_mask_b[1:]), axis=0) # (c, b, h, w) , slice out the first element of dim 1 + return out + +def discern_mode_smooth(all_mask_b: np.array, num_trees4img:np.array, num_build4img: np.array, mode: str = 'bchw'): + """ + Discern the masks of the trees, buildings and padding from the all_mask_b array + Inputs: + all_mask_b: np.array of shape (b, masks, h, w) + num_trees4img: np.array of shape (b,) + num_build4img: np.array of shape (b,) + mode: 'bchw' or 'cbhw'. To specify the output dimension. [batch channel height width] or [channel batch height width] + + Outputs: + out: np.array of shape (b, c, h, w) or (c, b, h, w) + """ + h, w = all_mask_b.shape[2:] + tree_mask_b = build_mask_b = pad_mask_b = np.full((1, h, w), float('-inf'), dtype=np.float32) + + for all_mask, tree_ix, build_ix in zip (all_mask_b, num_trees4img, num_build4img): + #all_mask.shape = (num_mask, h, w) + tree_mask = all_mask[ : tree_ix].max(axis=0, initial = float('-inf')) # Squash the tree masks. 
Get shape (h, w)
+        tree_mask_b = np.concatenate((tree_mask_b, tree_mask[None, ...]), axis=0) #(b, h, w)
+
+        build_mask = all_mask[tree_ix : (tree_ix + build_ix)].max(axis=0, initial = float('-inf')) # Squash the build masks. Get shape (h, w)
+        build_mask_b = np.concatenate((build_mask_b, build_mask[None, ...]), axis=0)
+
+        pad_mask = all_mask[(tree_ix + build_ix) : ].max(axis=0, initial = float('-inf'))
+        pad_mask_b = np.concatenate((pad_mask_b, pad_mask[None, ...]), axis=0)
+
+    if mode == 'bchw':
+        out = np.stack((tree_mask_b[1:], build_mask_b[1:], pad_mask_b[1:]), axis=1) # (b, c, h, w) , slice out the first element of dim 1
+    elif mode == 'cbhw':
+        out = np.stack((tree_mask_b[1:], build_mask_b[1:], pad_mask_b[1:]), axis=0) # (c, b, h, w) , slice out the first element of dim 1
+    return out
+
+def rmv_mask_b_overlap(overlapping_masks_b: np.array): #(b, c, h, w)
+    """
+    Remove overlap between the masks, giving priority in inverse order of the masks:
+    the third (building) mask has priority over the second (tree) mask, and so on.
+    """
+
+    disjoined_masks_b = np.copy(overlapping_masks_b)
+    for i in range(overlapping_masks_b.shape[1] - 1):
+        sum_mask = np.sum(overlapping_masks_b[:,i:], axis=1)
+        disjoined_masks_b[:,i] = np.where(sum_mask > 1, False, overlapping_masks_b[:, i])
+
+    return disjoined_masks_b
+
+#Use the batch version of this function
+def rmv_mask_overlap(overlapping_masks: np.array):
+    """
+    Remove overlap between the masks, giving priority in inverse order of the masks:
+    the third (building) mask has priority over the second (tree) mask, and so on.
+    Inputs:
+        overlapping_masks: np.array of shape (c, h, w)
+    Outputs:
+        no_overlap_masks: np.array of shape (c, h, w)
+    """
+    no_overlap_masks = np.copy(overlapping_masks)
+    for i in range(overlapping_masks.shape[0] - 1):
+        sum_mask = np.sum(overlapping_masks[i:], axis=0)
+        no_overlap_masks[i] = np.where(sum_mask > 1, False, overlapping_masks[i])
+
+    return no_overlap_masks
+
+def write_canvas(canvas: np.array,
+                 patch_masks_b: np.array,
+                 img_ixs: np.array,
+                 stride: int,
+                 total_cols: int) -> np.array:
+    """
+    Write the patch masks in the canvas
+    Inputs:
+        canvas: np.array of shape (channel, h_tile, w_tile)
+        patch_masks_b: np.array of shape (b, channel, h_patch, w_patch)
+        img_ixs: np.array of shape (b,)
+    """
+    size = patch_masks_b.shape[-1]
+    #print("img_ixs", img_ixs)
+    for img_ix, patch_mask in zip(img_ixs, patch_masks_b):
+        rows_changed = img_ix // total_cols
+        cols_changed = img_ix % total_cols
+        inv_base = (canvas.shape[1] - 1 - size) - (stride * rows_changed)
+        base = (stride * cols_changed)
+        canva_writable_space = canvas[:, inv_base: inv_base + size, base: base + size].shape[1:] #useful when the border of the canvas is reached
+        #print('\ncanvas slice', canvas[:, inv_base: inv_base + size, base: base + size].shape)
+        #print('patch', patch_mask[:, :canva_writable_space[0], :canva_writable_space[1]].shape)
+        canvas[:, inv_base: inv_base + size, base: base + size] = patch_mask[:, :canva_writable_space[0], :canva_writable_space[1]]
+
+    return canvas
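
The priority rule implemented by `rmv_mask_b_overlap` / `rmv_mask_overlap` above is easy to verify on a toy input. A minimal numpy sketch (same arithmetic as the functions, made-up masks):

```python
# Toy check of the overlap-priority rule: later channels win,
# earlier channels lose the contested pixels.
import numpy as np

tree = np.array([[1, 1, 0]], dtype=bool)   # channel 0 (lower priority)
build = np.array([[0, 1, 1]], dtype=bool)  # channel 1 (higher priority)
masks = np.stack((tree, build))            # (c, h, w)

no_overlap = np.copy(masks)
for i in range(masks.shape[0] - 1):
    sum_mask = np.sum(masks[i:], axis=0)   # count how many channels claim each pixel
    no_overlap[i] = np.where(sum_mask > 1, False, masks[i])

print(no_overlap[0])  # [[ True False False]] -> the contested middle pixel went to 'build'
print(no_overlap[1])  # [[False  True  True]]
```
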
+
+def write_canvas_geo(canvas: np.array,
+                     patch_masks_b: np.array,
+                     top_lft_indexes: List,
+                     smooth: bool) -> np.array:
+    """
+    Write the patch masks in the canvas.
+
+    Args:
+        canvas (np.array): The canvas to write the patch masks on. It should have shape (channel, h_tile, w_tile).
+        patch_masks_b (np.array): The patch masks to be written on the canvas. It should have shape (b, channel, h_patch, w_patch).
+        top_lft_indexes (List): The top left indexes of each patch mask in the canvas.
+        smooth (bool): If True, it expects patch_mask to contain logits, otherwise it should contain bools.
+
+    Returns:
+        np.array: The updated canvas with the patch masks written on it.
+    """
+
+    size = patch_masks_b.shape[-1]
+    for patch_mask, top_left_index in zip(patch_masks_b, top_lft_indexes):
+        I = np.s_[:, top_left_index[0]: top_left_index[0] + size, top_left_index[1]: top_left_index[1] + size] #slice of the canvas where the patch is written
+        #max_idxs is useful when the border of the canvas is reached; it contains the height and width that can be written on the canvas
+        max_idxs = canvas[I].shape[1:]
+
+        #print('\ncanvas slice', canvas[I].shape)
+        #print('patch', patch_mask[:, :max_idxs[0], :max_idxs[1]].shape)
+        if smooth:
+            canvas[I] = np.maximum(canvas[I], patch_mask[:, :max_idxs[0], :max_idxs[1]]) #element-wise max between the canvas and the patch
+        #elif smooth == 'avg':
+        #    canvas[I] = (canvas[I] + patch_mask[:, :max_idxs[0], :max_idxs[1]]) / 2 #TODO: this is wrong
+        else:
+            canvas[I] = patch_mask[:, :max_idxs[0], :max_idxs[1]]
+
+    return canvas
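
The next function blends overlapping patches with a Tukey window instead of overwriting them. Its accumulate-then-normalize scheme can be sketched in one dimension (toy sizes; the final division by `weights` is assumed to happen downstream in the pipeline and is shown here only for completeness):

```python
# 1-D sketch of window-weighted blending of overlapping patches:
# accumulate logits * window and the window itself, then normalize.
import numpy as np
from scipy.signal.windows import tukey

size = 8
canvas = np.zeros(16)            # stand-in for one canvas channel
weights = np.zeros(16)
window = tukey(size, alpha=0.5)  # tapers towards 0 at the patch borders

for start, patch_logits in [(0, np.ones(size)), (4, 2 * np.ones(size)), (8, np.ones(size))]:
    canvas[start:start + size] += patch_logits * window
    weights[start:start + size] += window

blended = canvas / np.maximum(weights, 1e-8)  # guard against unwritten pixels
print(blended.round(2))
```
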
+
+
+def write_canvas_geo_window(canvas: np.ndarray,
+                            weights: np.ndarray,
+                            patch_masks_b: np.ndarray,
+                            top_lft_indexes: List):
+    """
+    Write the patch masks in the canvas, blending overlapping patches with a window function.
+
+    Args:
+        canvas (np.array): The canvas to write the patch masks on. It should have shape (channel, h_tile, w_tile).
+        weights (np.array): The canvas that accumulates the window weights used to weight the contribution of each patch. It should have shape (channel, h_tile, w_tile).
+        patch_masks_b (np.array): The patch masks (logits) to be written on the canvas. It should have shape (b, channel, h_patch, w_patch).
+        top_lft_indexes (List): The top left indexes of each patch mask in the canvas.
+
+    Returns:
+        np.array: The updated canvas and the updated weights.
+    """
+    # initialize the window, a 2d array of the same size as the patch
+    # cosine window: window = np.outer(np.hanning(patch_masks_b.shape[-1]), np.hanning(patch_masks_b.shape[-1]))
+
+    # tukey window
+    window1d = tukey(patch_masks_b.shape[-1], alpha=0.5) # TODO: make it conditional
+    window = np.outer(window1d, window1d)
+
+    mask_size = patch_masks_b.shape[-1]
+
+    for patch_mask, top_left_index in zip(patch_masks_b, top_lft_indexes):
+        c_start_y = max(0, top_left_index[0])
+        c_start_x = max(0, top_left_index[1])
+        c_end_y = min(canvas.shape[1], top_left_index[0] + mask_size)
+        c_end_x = min(canvas.shape[2], top_left_index[1] + mask_size)
+
+        m_start_y = max(0, -top_left_index[0])
+        m_start_x = max(0, -top_left_index[1])
+        m_end_y = m_start_y + (c_end_y - c_start_y)
+        m_end_x = m_start_x + (c_end_x - c_start_x)
+
+        # if m_start_y != 0 or m_start_x != 0 or m_end_y != mask_size or m_end_x != mask_size:
+        #     pass
+
+        canvas[0, c_start_y: c_end_y, c_start_x: c_end_x] = canvas[0, c_start_y: c_end_y, c_start_x: c_end_x] + patch_mask[0, m_start_y: m_end_y, m_start_x: m_end_x] * window[m_start_y: m_end_y, m_start_x: m_end_x]
+        canvas[1, c_start_y: c_end_y, c_start_x: c_end_x] = canvas[1, c_start_y: c_end_y, c_start_x: c_end_x] + patch_mask[1, m_start_y: m_end_y, m_start_x: m_end_x] * window[m_start_y: m_end_y, m_start_x: m_end_x]
+
+        weights[c_start_y: c_end_y, c_start_x: c_end_x] = weights[c_start_y: c_end_y, c_start_x: c_end_x] + window[m_start_y: m_end_y, m_start_x: m_end_x]
+
+    return canvas, weights
+
+
+def clean_masks(masks: np.ndarray, area_threshold = 80, min_size = 80) -> np.ndarray:
+    """
+    Cleans the input masks by removing small holes and objects.
+
+    Args:
+        masks (np.array): The input masks to be cleaned. Can be a single mask or a stack of masks.
+        area_threshold (int, optional): The area threshold for removing small holes. Defaults to 80.
+        min_size (int, optional): The minimum size for removing small objects. Defaults to 80.
+
+    Returns:
+        np.array: The cleaned masks, with the same dimensions as the input masks.
+    """
+    single_mask = False
+    if len(masks.shape) == 2:
+        single_mask = True
+        masks = np.expand_dims(masks, axis=0)
+
+    clear_masks = []
+    masks_int = masks.astype(np.uint8)
+    for mask in masks_int:
+        clear_mask = morphology.remove_small_holes(mask, area_threshold = area_threshold)
+        clear_mask = morphology.remove_small_objects(clear_mask, min_size = min_size)
+        #clear_mask = morphology.binary_opening(mask)
+        #clear_mask = morphology.binary_closing(clear_mask)
+        clear_masks.append(clear_mask)
+
+    if single_mask:
+        clear_masks = clear_masks[0]
+    else:
+        clear_masks = np.stack(clear_masks, axis=0)
+
+    return clear_masks
+
+def merge_masks(masks: np.ndarray):
+    """
+    Merges multiple masks into a single mask.
+    Labels:
+        road = 0
+        tree = 1
+        building = 2
+        background = 255
+
+    Args:
+        masks (list): A list of masks to be merged.
+
+    Returns:
+        np.ndarray: The merged mask.
+    """
+    merged_mask = np.full_like(masks[0], fill_value=255, dtype=np.uint8)
+    for i, mask in enumerate(masks):
+        merged_mask[mask.astype(bool)] = i
+
+    return merged_mask
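
A minimal end-of-pipeline sketch of the two helpers above, assuming the package has been installed (e.g. `pip install -e .`); the input masks are random stand-ins:

```python
# Clean the boolean road/tree/building masks, then merge them into one
# uint8 label map (road=0, tree=1, building=2, background=255).
import numpy as np
from maxarseg.ESAM_segment.segment_utils import clean_masks, merge_masks

road, tree, build = np.random.rand(3, 256, 256) > 0.95  # toy boolean masks
cleaned = clean_masks(np.stack((road, tree, build)), area_threshold=80, min_size=80)
label_map = merge_masks(cleaned)
print(label_map.shape, np.unique(label_map))  # (256, 256) and a subset of [0 1 2 255]
```
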
diff --git a/src/maxarseg/SAM_segment/build_segmentation.py b/src/maxarseg/SAM_segment/build_segmentation.py
new file mode 100644
index 0000000..198f19b
--- /dev/null
+++ b/src/maxarseg/SAM_segment/build_segmentation.py
@@ -0,0 +1,111 @@
+import pandas as pd
+import geopandas as gpd
+from shapely.geometry import shape
+import numpy as np
+import torch
+
+
+def building_gdf(country, csv_path = './metadata/buildings_dataset_links.csv', dataset_crs = None, quiet = False):
+    """
+    Returns a geodataframe with the buildings of the country passed as input.
+    It downloads the dataset from a link in the buildings_dataset_links.csv file.
+    Coordinates are converted to the crs passed as input.
+    Inputs:
+        country: the country whose buildings should be downloaded. Example: 'Tanzania'
+        csv_path: the path to the buildings_dataset_links.csv file
+        dataset_crs: the crs to which the building coordinates are converted
+        quiet: if True, it doesn't print anything
+    """
+    dataset_links = pd.read_csv(csv_path)
+    country_links = dataset_links[dataset_links.Location == country]
+    #TODO: possibly also filter on the quadkey of the event
+    if not quiet:
+        print(f"Found {len(country_links)} links for {country}")
+
+    gdfs = []
+    for _, row in country_links.iterrows():
+        df = pd.read_json(row.Url, lines=True)
+        df["geometry"] = df["geometry"].apply(shape)
+        gdf_down = gpd.GeoDataFrame(df, crs=4326)
+        gdfs.append(gdf_down)
+
+    gdfs = pd.concat(gdfs)
+    if dataset_crs is not None: #if a dataset crs is given, convert to it
+        gdfs = gdfs.to_crs(dataset_crs)
+    return gdfs
+
+
+def rel_polyg_coord(geodf:gpd.GeoDataFrame,
+                    ref_coords:tuple,
+                    res):
+    """
+    Returns the relative coordinates of a polygon w.r.t. a reference bbox.
+    Goes from absolute geo coords to relative coords in the image.
+
+    Inputs:
+        geodf: dataframe with polygons in the 'geometry' column
+        ref_coords: a tuple in the format (minx, miny, maxx, maxy)
+        res: resolution of the image
+    Returns:
+        a list of lists of tuples with the relative coordinates of the bboxes [[(p1_minx1, p1_miny1), (p1_minx2, p1_miny2), ...], [(p2_minx1, p2_miny1), (p2_minx2, p2_miny2), ...], ...]
+    """
+    result = []
+    ref_minx, ref_maxy = ref_coords[0], ref_coords[3] #coords of top left corner
+
+    for geom in geodf['geometry']:
+        x_s, y_s = geom.exterior.coords.xy
+        rel_x_s = (np.array(x_s) - ref_minx) / res
+        rel_y_s = (ref_maxy - np.array(y_s)) / res
+        rel_coords = list(zip(rel_x_s, rel_y_s))
+        result.append(rel_coords)
+    return result
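
A worked instance of the geo-to-pixel conversion above, with made-up numbers (a projected CRS in meters and a 0.5 m/px resolution):

```python
# bbox (minx, miny, maxx, maxy) = (100, 100, 612, 612) in meters; res = 0.5 m/px.
import numpy as np

ref_minx, ref_maxy, res = 100.0, 612.0, 0.5
x_s = np.array([105.0, 110.0])       # polygon vertices, absolute coordinates
y_s = np.array([607.0, 602.0])
rel_x_s = (x_s - ref_minx) / res     # pixels from the left edge -> [10. 20.]
rel_y_s = (ref_maxy - y_s) / res     # pixels from the top edge  -> [10. 20.]
print(list(zip(rel_x_s, rel_y_s)))   # [(10.0, 10.0), (20.0, 20.0)]
```
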
+
+def segment_buildings(predictor, building_boxes, img4Sam: np.array, use_bbox = True, use_center_points = False):
+    """
+    Segment the buildings in the image using the predictor passed as input.
+    The image has to be encoded before calling this function.
+    Inputs:
+        predictor: the predictor to use for the segmentation
+        building_boxes: a list of tuples containing the building's bounding boxes in format (minx, miny, maxx, maxy) = (top left corner, bottom right corner)
+        img4Sam: the image previously encoded
+        use_bbox: if True, the bounding boxes are used for the segmentation
+        use_center_points: if True, the center points of the bounding boxes are used for the segmentation
+
+    Returns:
+        mask: a np array of shape (1, h, w). The mask is True where there is a building, False elsewhere
+        bboxes: a list of tuples containing the bounding boxes of the buildings used for the segmentation
+        #!used_points: a np array of shape (n, 2) where n is the number of buildings. The array contains the center points of the bounding boxes of the buildings in the image
+    """
+
+    building_boxes_t = torch.tensor(building_boxes, device=predictor.device)
+
+    transformed_boxes = None
+    if use_bbox:
+        transformed_boxes = predictor.transform.apply_boxes_torch(building_boxes_t, img4Sam.shape[:2])
+
+    transformed_points = None
+    transformed_point_labels = None
+    """if use_center_points: #TODO: fix the use of points, at the moment it doesn't work
+        point_coords = torch.tensor([[(sublist[0] + sublist[2])/2, (sublist[1] + sublist[3])/2] for sublist in building_boxes_t], device=predictor.device)
+        point_labels = torch.tensor([1] * point_coords.shape[0], device=predictor.device)[:, None]
+        transformed_points = predictor.transform.apply_coords_torch(point_coords, img4Sam.shape[:2]).unsqueeze(1)
+        transformed_point_labels = point_labels[:, None]"""
+
+    masks, _, _ = predictor.predict_torch(
+        point_coords=transformed_points,
+        point_labels=transformed_point_labels,
+        boxes=transformed_boxes,
+        multimask_output=False,
+    )
+    #masks is a tensor (n, 1, h, w) where n = number of masks = number of input boxes
+    mask = np.any(masks.cpu().numpy(), axis = 0)
+
+    used_boxes = None
+    if use_bbox:
+        used_boxes = building_boxes
+
+    used_points = None
+    """if use_center_points:
+        used_points = point_coords.cpu().numpy()"""
+
+    return mask, used_boxes, used_points #returns all np arrays
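
A hypothetical call sequence for `segment_buildings`; the checkpoint path, CUDA device and blank image are placeholders, while the model-loading calls are the standard `segment_anything` API:

```python
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
from maxarseg.SAM_segment.build_segmentation import segment_buildings

sam = sam_model_registry["vit_h"](checkpoint="./models/sam_vit_h.pth")  # assumed local checkpoint
predictor = SamPredictor(sam.to("cuda"))

img = np.zeros((1024, 1024, 3), dtype=np.uint8)       # stand-in for an RGB tile patch (H, W, 3)
predictor.set_image(img)                              # run the image encoder once
boxes = [(100, 100, 200, 220), (400, 380, 470, 460)]  # (minx, miny, maxx, maxy) in pixels
mask, used_boxes, _ = segment_buildings(predictor, boxes, img)
print(mask.shape)  # (1, 1024, 1024)
```
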
diff --git a/src/maxarseg/SAM_segment/road_segmentation.py b/src/maxarseg/SAM_segment/road_segmentation.py
new file mode 100644
index 0000000..e50d950
--- /dev/null
+++ b/src/maxarseg/SAM_segment/road_segmentation.py
@@ -0,0 +1,196 @@
+import numpy as np
+import geopandas as gpd
+from shapely.geometry import shape, Polygon, LineString, MultiPoint, Point
+from typing import List, Union
+from rasterio.features import rasterize
+from skimage import morphology
+
+
+
+def rel_road_lines(geodf: gpd.GeoDataFrame,
+                   query_bbox_poly: Polygon,
+                   res):
+    """
+    Given a GeoDataFrame containing LineStrings with geo coords,
+    returns the relative coordinates of those lines w.r.t. a reference bbox.
+
+    Inputs:
+        geodf: GeoDataFrame containing the LineStrings
+        query_bbox_poly: Polygon of the reference bbox
+        res: resolution of the image
+    Returns:
+        result: list of LineStrings with the relative coordinates
+    """
+    ref_coords = query_bbox_poly.bounds
+    ref_minx, ref_maxy = ref_coords[0], ref_coords[3] #coords of top left corner
+
+    result = []
+    for line in geodf.geometry:
+        x_s, y_s = line.coords.xy
+
+        rel_x_s = (np.array(x_s) - ref_minx) / res
+        rel_y_s = (ref_maxy - np.array(y_s)) / res
+        rel_coords = list(zip(rel_x_s, rel_y_s))
+        line = LineString(rel_coords)
+        result.append(line)
+    return result
+
+def line2points(lines: Union[LineString, List[LineString]], points_dist) -> List[Point]:
+    """
+    Given a single or a list of shapely.LineString,
+    returns a list of shapely points along all the lines, spaced by points_dist
+    """
+    if not isinstance(lines, list):
+        lines = [lines]
+    points = []
+    for line in lines:
+        points.extend([line.interpolate(dist) for dist in np.arange(0, line.length, points_dist)])
+    return points
+
+def get_offset_lines(lines: Union[LineString, List[LineString]], distance=35):
+    """
+    Create two offset lines, one on each side of the input shapely.LineString(s), at the given distance
+    """
+    if not isinstance(lines, list):
+        lines = [lines]
+
+    offset_lines = []
+    for line in lines:
+        for side in [-1, +1]:
+            offset_lines.append(line.offset_curve(side * distance))
+    return offset_lines
+
+def clear_roads(lines: Union[LineString, List[LineString]], bg_points, distance) -> List[Point]:
+    """
+    Given a list of shapely.LineString and a list of shapely.Point,
+    remove bg points that may be on the road
+    """
+    candidate_bg_pts = bg_points
+    final_bg_pts = set(bg_points)
+
+    if not isinstance(lines, list):
+        lines = [lines]
+
+    for line in lines:
+        line_space = line.buffer(distance)
+        for point in candidate_bg_pts:
+            if line_space.contains(point):
+                final_bg_pts.discard(point)
+
+    return list(final_bg_pts)
+
+def rmv_rnd_fraction(points, fraction_to_keep):
+    """
+    Removes a random fraction of the points
+    """
+    np.random.shuffle(points)
+    points = points[:int(len(points)*fraction_to_keep)]
+    return points
+
+def rmv_pts_out_img(points: np.array, sample_size) -> np.array:
+    """
+    Given a np.array of points (n, 2),
+    removes points outside the image
+    """
+    if len(points) != 0:
+        points = points[np.logical_and(np.logical_and(points[:, 0] >= 0, points[:, 0] < sample_size), np.logical_and(points[:, 1] >= 0, points[:, 1] < sample_size))]
+    return points
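
The prompt-point helpers above can be exercised in isolation. A small shapely sketch (pixel coordinates and spacings are made up; requires shapely >= 2.0 for `offset_curve`):

```python
from shapely.geometry import LineString

road = LineString([(0, 50), (200, 50)])
road_pts = [road.interpolate(d) for d in range(0, int(road.length), 50)]  # as in line2points
offsets = [road.offset_curve(side * 35) for side in (-1, 1)]              # as in get_offset_lines

print([(p.x, p.y) for p in road_pts])  # [(0.0, 50.0), (50.0, 50.0), (100.0, 50.0), (150.0, 50.0)]
print([o.bounds for o in offsets])     # one parallel line 35 px on each side of the road
```
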
+
+def segment_roads(predictor,
+                  road_lines: Union[LineString, List[LineString]],
+                  sample_size,
+                  img4Sam = None,
+                  road_point_dist = 50,
+                  bg_point_dist = 80,
+                  offset_distance = 50,
+                  do_clean_mask = True):
+    """
+    Segment the roads in the image using the predictor passed as input.
+    If the image is passed as input it is encoded on the fly, otherwise it has to be encoded before calling this function.
+
+    Inputs:
+        predictor: the predictor to use for the segmentation
+        road_lines: a list of shapely.LineString containing the roads
+        sample_size: the size of the image
+        img4Sam: the image to encode, if not already encoded
+        road_point_dist: the distance between two points on the road
+        bg_point_dist: the distance between two points on the road's offset lines
+        offset_distance: the offset distance
+        do_clean_mask: if True, the mask is cleaned by removing parts outside the offset lines
+
+    Returns:
+        final_mask: a np array of shape (1, h, w). The mask is True where there is a road, False elsewhere
+        final_pt_coords4Sam: a np array of shape (n, 2) where n is the number of points. The array contains the coordinates of the points used for the segmentation
+        final_labels4Sam: a np array of shape (n,) where n is the number of points. The array contains the labels of the points used for the segmentation
+    """
+
+    #Decide if encoding here or outside the function
+    if img4Sam is not None:
+        predictor.set_image(img4Sam)
+
+    #initialize an empty mask
+    final_mask = np.full((sample_size, sample_size), False)
+
+    final_pt_coords4Sam = []
+    final_labels4Sam = []
+
+    if not isinstance(road_lines, list):
+        road_lines = [road_lines]
+
+    for road in road_lines:
+        road_pts = line2points(road, road_point_dist) #turn the road into a list of shapely points
+        np_roads_pts = np.array([list(pt.coords)[0] for pt in road_pts]) #turn the shapely points into a numpy array
+        np_roads_pts = rmv_pts_out_img(np_roads_pts, sample_size) #remove road points outside the image
+        np_road_labels = np.array([1]*np_roads_pts.shape[0]) #create the labels for the road points
+
+        bg_lines = get_offset_lines(road, offset_distance) #create two offset lines from the road
+        bg_pts = line2points(bg_lines, bg_point_dist) #turn the offset lines into a list of shapely points
+        bg_pts = clear_roads(road_lines, bg_pts, offset_distance - 4) #remove bg points that may be on other roads
+        np_bg_pts = np.array([list(pt.coords)[0] for pt in bg_pts]) #turn the shapely points into a numpy array
+        np_bg_pts = rmv_pts_out_img(np_bg_pts, sample_size) #remove bg points outside the image
+        np_bg_labels = np.array([0]*np_bg_pts.shape[0]) #create the labels for the bg points
+
+        if len(np_bg_labels) == 0 or len(np_road_labels) < 2: #skip the road if there are no bg points or fewer than 2 road points
+            continue
+
+        pt_coords4Sam = np.concatenate((np_roads_pts, np_bg_pts)) #tmp list
+        labels4Sam = np.concatenate((np_road_labels, np_bg_labels))
+
+        final_pt_coords4Sam.extend(pt_coords4Sam.tolist()) #global list
+        final_labels4Sam.extend(labels4Sam.tolist()) #global list
+
+        mask, _, _ = predictor.predict(
+            point_coords=pt_coords4Sam,
+            point_labels=labels4Sam,
+            multimask_output=False,
+        )
+        final_mask = np.logical_or(final_mask, mask[0])
+
+    if do_clean_mask:
+        final_mask = clean_mask(road_lines, final_mask, offset_distance - 10) #TODO: possibly add a parameter for the additional_cleaning
+
+    return final_mask[np.newaxis, :], np.array(final_pt_coords4Sam), np.array(final_labels4Sam)
+
+def clean_mask(road_lines: Union[LineString, List[LineString]],
+               final_mask_2d: np.array,
+               offset_distance,
+               additional_cleaning = False):
+    """
+    Clean the mask by removing parts outside the offset lines.
+    The additional_cleaning parameter is used to remove small holes and small objects from the mask.
+    """
+
+    if not isinstance(road_lines, list):
+        road_lines = [road_lines]
+
+    line_buffers = [line.buffer(offset_distance) for line in road_lines]
+    buffer_roads = rasterize(line_buffers, out_shape=final_mask_2d.shape)
+    clear_mask = np.logical_and(final_mask_2d, buffer_roads)
+
+    if additional_cleaning: #TODO: check what these functions actually do and tune the parameters
+        clear_mask = morphology.remove_small_holes(clear_mask, area_threshold=500)
+        clear_mask = morphology.remove_small_objects(clear_mask, min_size=500)
+        clear_mask = morphology.binary_opening(clear_mask)
+        clear_mask = morphology.binary_closing(clear_mask)
+
+    return clear_mask
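
The buffer-and-intersect step of `clean_mask` can be shown in isolation (toy 100x100 pixel grid; geometry coordinates are interpreted in pixel space because no transform is passed to `rasterize`):

```python
# Pixels outside the buffered road corridor are zeroed out.
import numpy as np
from rasterio.features import rasterize
from shapely.geometry import LineString

mask = np.ones((100, 100), dtype=bool)                   # pretend SAM marked everything as road
corridor = [LineString([(0, 50), (99, 50)]).buffer(10)]  # corridor with a 10 px half-width
corridor_raster = rasterize(corridor, out_shape=mask.shape)
cleaned = np.logical_and(mask, corridor_raster)
print(cleaned.sum(), "of", mask.sum(), "pixels kept")
```
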
diff --git a/src/maxarseg/SAM_segment/segment_from_boxes.py b/src/maxarseg/SAM_segment/segment_from_boxes.py
new file mode 100644
index 0000000..273d8aa
--- /dev/null
+++ b/src/maxarseg/SAM_segment/segment_from_boxes.py
@@ -0,0 +1,54 @@
+import torch
+import numpy as np
+
+def segment_from_boxes(predictor, boxes, img4Sam, use_bbox = True, use_center_points = False):
+    """
+    Segment the buildings in the image using the predictor passed as input.
+    The image has to be encoded before calling this function.
+    Inputs:
+        predictor: the predictor to use for the segmentation
+        boxes: a list of tuples or a 2d np.array (b, 4) containing the building's bounding boxes. A single box is in the format (minx, miny, maxx, maxy) = (top left corner, bottom right corner)
+        img4Sam: the image previously encoded
+        use_bbox: if True, the bounding boxes are used for the segmentation
+        use_center_points: if True, the center points of the bounding boxes are used for the segmentation
+
+    Returns:
+        mask: a np array of shape (1, h, w). The mask is True where there is a building, False elsewhere
+        bboxes: a list of tuples containing the bounding boxes of the buildings used for the segmentation
+        #!used_points: a np array of shape (n, 2) where n is the number of buildings. The array contains the center points of the bounding boxes of the buildings in the image
+    """
+
+    boxes_t = torch.tensor(boxes, device=predictor.device)
+
+    #if use_bboxes
+    transformed_boxes = None
+    if use_bbox:
+        transformed_boxes = predictor.transform.apply_boxes_torch(boxes_t, img4Sam.shape[:2])
+
+    #if use points
+    transformed_points = None
+    transformed_point_labels = None
+    if use_center_points: #TODO: fix the use of points, at the moment it doesn't work
+        point_coords = torch.tensor([[(sublist[0] + sublist[2])/2, (sublist[1] + sublist[3])/2] for sublist in boxes_t], device=predictor.device)
+        point_labels = torch.tensor([1] * point_coords.shape[0], device=predictor.device)[:, None]
+        transformed_points = predictor.transform.apply_coords_torch(point_coords, img4Sam.shape[:2]).unsqueeze(1)
+        transformed_point_labels = point_labels[:, None]
+
+    masks, _, _ = predictor.predict_torch(
+        point_coords=transformed_points,
+        point_labels=transformed_point_labels,
+        boxes=transformed_boxes,
+        multimask_output=False,
+    )
+    #masks is a tensor of shape (n, 1, h, w) where n is the number of masks (= number of input boxes)
+    mask = np.any(masks.cpu().numpy(), axis = 0)
+
+    used_boxes = None
+    if use_bbox:
+        used_boxes = boxes
+
+    used_points = None
+    if use_center_points:
+        used_points = point_coords.cpu().numpy()
+
+    return mask, used_boxes, used_points #returns all np arrays
\ No newline at end of file
diff --git a/src/maxarseg/__init__.py b/src/maxarseg/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/maxarseg/assemble/build.py b/src/maxarseg/assemble/build.py
new file mode 100644
index 0000000..5f7c12a
--- /dev/null
+++ b/src/maxarseg/assemble/build.py
@@ -0,0 +1,51 @@
+from pathlib import Path
+import sys
+import geopandas as gpd
+from maxarseg.geo_datasets import geoDatasets
+from pyquadkey2 import quadkey
+import pandas as pd
+from shapely import geometry
+import json
+from typing import List, Tuple, Union
+import time
+import os
+from torchgeo.datasets import stack_samples
+from maxarseg import samplers
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+import numpy as np
+from maxarseg.ESAM_segment import segment
+import torch
+import rasterio
+from rasterio.features import rasterize
+
+###########################
+# Retrieve bbox coordinates
+###########################
+
+def old_get_bbox_roads(mosaic_bbox: Union[List[Tuple], Tuple[Tuple]], region_name, roads_root = '/nfs/projects/overwatch/maxar-segmentation/microsoft-roads'):
+    """
+    Get a gdf containing the roads that intersect the mosaic_bbox.
+    Input:
+        mosaic_bbox: Bounding box of the mosaic in format (lon, lat). Example: ((-16.5, 13.5), (-15.5, 14.5))
+        region_name: Name of the region. Example: 'AfricaWest-Full'
+        roads_root: Root directory of the roads datasets
+    """
+    if region_name[-4:] != '.tsv':
+        region_name = region_name + '.tsv'
+
+    roads_root = Path(roads_root)
+    road_df = pd.read_csv(roads_root/region_name, names =['country', 'geometry'], sep='\t')
+    road_df['geometry'] = road_df['geometry'].apply(json.loads).apply(lambda d: geometry.shape(d.get('geometry')))
+    road_gdf = gpd.GeoDataFrame(road_df, crs=4326)
+
+    (minx, miny), (maxx, maxy) = mosaic_bbox
+    vertices = [(minx, miny), (maxx, miny), (maxx, maxy), (minx, maxy), (minx, miny)] #lon lat
+    query_bbox_poly = geometry.Polygon(vertices)
+
+    hits = road_gdf.geometry.intersects(query_bbox_poly)
+
+    return road_gdf[hits]
+
+
+
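
The per-row format assumed by the parser above is `country<TAB>GeoJSON feature`. A toy row (made up) shows the round trip:

```python
import io
import json
import geopandas as gpd
import pandas as pd
from shapely import geometry

tsv = 'GMB\t{"type": "Feature", "geometry": {"type": "LineString", "coordinates": [[-16.6, 13.4], [-16.5, 13.5]]}}\n'
road_df = pd.read_csv(io.StringIO(tsv), names=['country', 'geometry'], sep='\t')
road_df['geometry'] = road_df['geometry'].apply(json.loads).apply(lambda d: geometry.shape(d.get('geometry')))
road_gdf = gpd.GeoDataFrame(road_df, crs=4326)
print(road_gdf.total_bounds)  # [-16.6  13.4 -16.5  13.5]
```
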
diff --git a/src/maxarseg/assemble/delimiters.py b/src/maxarseg/assemble/delimiters.py
new file mode 100644
index 0000000..ff18994
--- /dev/null
+++ b/src/maxarseg/assemble/delimiters.py
@@ -0,0 +1,102 @@
+from pathlib import Path
+import sys
+import geopandas as gpd
+from typing import List, Tuple, Union
+from maxarseg.assemble import names
+import pyproj
+import glob
+
+def get_mosaic_bbox(event_name, mosaic_name, path_mosaic_metatada = './metadata/from_github_maxar_metadata/datasets', extra_mt = 0, return_proj_coords = False):
+    """
+    Get the bbox of a mosaic. It returns the coordinates of the bottom left and top right corners.
+    Input:
+        event_name: Example: 'Gambia-flooding-8-11-2022'
+        mosaic_name: It could be an element of the output of get_mosaics_names(). Example: '104001007A565700'
+        path_mosaic_metatada: Path to the folder containing the geojson
+        extra_mt: Extra meters added to all bbox sides. The center of the bbox remains the same. (To be sure all elements are included)
+        return_proj_coords: If True, it returns the coordinates in the projection of the mosaic.
+    Output:
+        pair of coordinates in format (lon, lat), or (x, y) if return_proj_coords is True
+    """
+    path_mosaic_metatada = Path(path_mosaic_metatada)
+    file_name = mosaic_name + '.geojson'
+    geojson_path = path_mosaic_metatada / event_name / file_name
+    try:
+        gdf = gpd.read_file(geojson_path)
+    except:
+        file_pattern = str(path_mosaic_metatada / event_name / mosaic_name) + '*inv.geojson'
+        file_list = glob.glob(f"{file_pattern}")
+        assert len(file_list) == 1, f"Found {len(file_list)} files with pattern {file_pattern}. Expected 1 file."
+        gdf = gpd.read_file(file_list[0])
+
+    minx = sys.maxsize
+    miny = sys.maxsize
+    maxx = 0
+    maxy = 0
+
+    for _, row in gdf.iterrows():
+        tmp_minx, tmp_miny, tmp_maxx, tmp_maxy = [float(el) for el in row['proj:bbox'].split(',')]
+        if tmp_minx < minx:
+            minx = tmp_minx
+        if tmp_miny < miny:
+            miny = tmp_miny
+        if tmp_maxx > maxx:
+            maxx = tmp_maxx
+        if tmp_maxy > maxy:
+            maxy = tmp_maxy
+
+    #enlarge bbox
+    minx -= (extra_mt/2)
+    miny -= (extra_mt/2)
+    maxx += (extra_mt/2)
+    maxy += (extra_mt/2)
+    if not return_proj_coords:
+        source_crs = gdf['proj:epsg'].values[0]
+        target_crs = pyproj.CRS('EPSG:4326')
+        transformer = pyproj.Transformer.from_crs(source_crs, target_crs)
+
+        bott_left_lat, bott_left_lon = transformer.transform(minx, miny)
+        top_right_lat, top_right_lon = transformer.transform(maxx, maxy)
+
+        return ((bott_left_lon, bott_left_lat), (top_right_lon, top_right_lat)), gdf['proj:epsg'].values[0]
+
+    return ((minx, miny), (maxx, maxy)), gdf['proj:epsg'].values[0]
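
Note the unpacking order above: `Transformer.from_crs` without `always_xy=True` follows the authority axis order, so transforming into EPSG:4326 yields (lat, lon). A quick check with made-up UTM zone 28N coordinates:

```python
import pyproj

t = pyproj.Transformer.from_crs(32628, 4326)   # UTM 28N -> WGS84, authority axis order
lat, lon = t.transform(331000, 1600000)        # (easting, northing) in, (lat, lon) out
print(lat, lon)                                # roughly 14.5, -16.6

t_xy = pyproj.Transformer.from_crs(32628, 4326, always_xy=True)
lon2, lat2 = t_xy.transform(331000, 1600000)   # with always_xy=True the output is (lon, lat)
```
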
+
+def get_event_bbox(event_name, extra_mt = 0, when = None, return_proj_coords = False):
+
+    minx = sys.maxsize
+    miny = sys.maxsize
+    maxx = 0
+    maxy = 0
+
+    crs_set = set()
+    first_crs = None
+    for mosaic_name in names.get_mosaics_names(event_name, when = when):
+        ((tmp_minx, tmp_miny), (tmp_maxx, tmp_maxy)), crs = get_mosaic_bbox(event_name, mosaic_name, extra_mt = extra_mt, return_proj_coords = True)
+        first_crs = crs if first_crs is None else first_crs
+        transformer = pyproj.Transformer.from_crs(crs, first_crs)
+        tmp_minx, tmp_miny = transformer.transform(tmp_minx, tmp_miny)
+        tmp_maxx, tmp_maxy = transformer.transform(tmp_maxx, tmp_maxy)
+
+        crs_set.add(crs)
+        if tmp_minx < minx:
+            minx = tmp_minx
+        if tmp_miny < miny:
+            miny = tmp_miny
+        if tmp_maxx > maxx:
+            maxx = tmp_maxx
+        if tmp_maxy > maxy:
+            maxy = tmp_maxy
+
+    if not return_proj_coords:
+
+        source_crs = first_crs #list(crs_set)[0]
+        target_crs = pyproj.CRS('EPSG:4326')
+        transformer = pyproj.Transformer.from_crs(source_crs, target_crs)
+
+        bott_left_lat, bott_left_lon = transformer.transform(minx, miny)
+        top_right_lat, top_right_lon = transformer.transform(maxx, maxy)
+
+        return (bott_left_lon, bott_left_lat), (top_right_lon, top_right_lat)
+
+    return (minx, miny), (maxx, maxy)
\ No newline at end of file
diff --git a/src/maxarseg/assemble/filter.py b/src/maxarseg/assemble/filter.py
new file mode 100644
index 0000000..67cf2e8
--- /dev/null
+++ b/src/maxarseg/assemble/filter.py
@@ -0,0 +1,41 @@
+from shapely import geometry
+import geopandas as gpd
+from typing import List, Tuple, Union
+
+def filter_gdf_w_bbox(gbl_gdf: gpd.GeoDataFrame, bbox: Union[List[Tuple], Tuple[Tuple]]) -> gpd.GeoDataFrame:
+    """
+    Filter a geodataframe with a bbox.
+    Input:
+        gbl_gdf: the geodataframe to be filtered
+        bbox: Bounding box in format (lon, lat). Example: ((-16.5, 13.5), (-15.5, 14.5))
+    Output:
+        a filtered geodataframe
+
+    """
+    (minx, miny), (maxx, maxy) = bbox
+    vertices = [(minx, miny), (maxx, miny), (maxx, maxy), (minx, maxy), (minx, miny)] #lon lat
+    query_bbox_poly = geometry.Polygon(vertices)
+
+    hits = gbl_gdf.geometry.intersects(query_bbox_poly)
+    #TODO: maybe this also works...
+    #hits = gbl_gdf.sindex.query(query_bbox_poly)
+
+    return gbl_gdf[hits]
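
The `intersects` scan above and the spatial-index variant below agree on a toy frame. Note that `sindex.query` without a `predicate` argument returns bounding-box candidates, which is exact for points but a superset for other geometries:

```python
import geopandas as gpd
from shapely import geometry

gdf = gpd.GeoDataFrame(geometry=[geometry.Point(x, 0.5) for x in range(1000)], crs=4326)
bbox_poly = geometry.box(10, 0, 20, 1)

slow = gdf[gdf.geometry.intersects(bbox_poly)]  # checks every row
fast = gdf.iloc[gdf.sindex.query(bbox_poly)]    # prunes with the R-tree first
assert set(slow.index) == set(fast.index)
```
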
Example: ((-16.5, 13.5), (-15.5, 14.5)) + Output: + a filtered geodataframe + + """ + (minx, miny), (maxx, maxy) = bbox + vertices = [(minx, miny), (maxx, miny), (maxx, maxy), (minx, maxy), (minx, miny)] #lon lat + query_bbox_poly = geometry.Polygon(vertices) + + hits = gbl_gdf.sindex.query(query_bbox_poly) + + return gbl_gdf[hits] \ No newline at end of file diff --git a/src/maxarseg/assemble/g_build_utils.py b/src/maxarseg/assemble/g_build_utils.py new file mode 100644 index 0000000..107eca2 --- /dev/null +++ b/src/maxarseg/assemble/g_build_utils.py @@ -0,0 +1,111 @@ +import os +import shapely +import geopandas as gpd +from typing import Tuple, Optional + +import functools +import glob +import gzip +import multiprocessing +import os +import shutil +import tempfile +from typing import List, Optional, Tuple + +import geopandas as gpd +from IPython import display +import pandas as pd +#import s2geometry as s2 +import s2cell as s2 +import shapely +#import tensorflow as tf +import tqdm.notebook + +data_type = 'polygons' +BUILDING_DOWNLOAD_PATH = (f'gs://open-buildings-data/v3/{data_type}_s2_level_6_gzip_no_header') + +def get_filename_and_region_dataframe(your_own_wkt_polygon: str) -> Tuple[str, gpd.geodataframe.GeoDataFrame]: + """Returns output filename and a geopandas dataframe with one region row.""" + data_type = 'polygons' + filename = f'open_buildings_v3_{data_type}_your_own_wkt_polygon.csv.gz' + region_df = gpd.GeoDataFrame(geometry=gpd.GeoSeries.from_wkt([your_own_wkt_polygon]), crs='EPSG:4326') + if not isinstance(region_df.iloc[0].geometry, shapely.geometry.polygon.Polygon) and \ + not isinstance(region_df.iloc[0].geometry, shapely.geometry.multipolygon.MultiPolygon): + raise ValueError("`your_own_wkt_polygon` must be a POLYGON or MULTIPOLYGON.") + print(f'Preparing your_own_wkt_polygon.') + return filename, region_df + +def get_bounding_box_s2_covering_tokens(region_geometry: shapely.geometry.base.BaseGeometry) -> List[str]: + region_bounds = region_geometry.bounds + s2_lat_lng_rect = s2.S2LatLngRect_FromPointPair(s2.S2LatLng_FromDegrees(region_bounds[1], region_bounds[0]), + s2.S2LatLng_FromDegrees(region_bounds[3], region_bounds[2])) + coverer = s2.S2RegionCoverer() + # NOTE: Should be kept in-sync with s2 level in BUILDING_DOWNLOAD_PATH. + coverer.set_fixed_level(6) + coverer.set_max_cells(1000000) + return [cell.ToToken() for cell in coverer.GetCovering(s2_lat_lng_rect)] + +def s2_token_to_shapely_polygon(s2_token: str) -> shapely.geometry.polygon.Polygon: + s2_cell = s2.S2Cell(s2.S2CellId_FromToken(s2_token, len(s2_token))) + coords = [] + for i in range(4): + s2_lat_lng = s2.S2LatLng(s2_cell.GetVertex(i)) + coords.append((s2_lat_lng.lng().degrees(), s2_lat_lng.lat().degrees())) + return shapely.geometry.Polygon(coords) + +def download_s2_token(s2_token: str, region_df: gpd.geodataframe.GeoDataFrame) -> Optional[str]: + """Downloads the matching CSV file with polygons for the `s2_token`. + + NOTE: Only polygons inside the region are kept. + NOTE: Passing output via a temporary file to reduce memory usage. + + Args: + s2_token: S2 token for which to download the CSV file with building polygons. + The S2 token should be at the same level as the files in BUILDING_DOWNLOAD_PATH. + region_df: A geopandas dataframe with only one row that contains the region for which to keep polygons. 
+ + Returns: + Either filepath which contains a gzipped CSV without header for the `s2_token` subfiltered to only contain building polygons inside the region + or None which means that there were no polygons inside the region for this `s2_token`. + """ + s2_cell_geometry = s2_token_to_shapely_polygon(s2_token) + region_geometry = region_df.iloc[0].geometry + prepared_region_geometry = shapely.prepared.prep(region_geometry) + # If the s2 cell doesn't intersect the country geometry at all then we can + # know that all rows would be dropped so instead we can just return early. + if not prepared_region_geometry.intersects(s2_cell_geometry): + return None + try: + # Using tf.io.gfile.GFile gives better performance than passing the GCS path + # directly to pd.read_csv. + #with tf.io.gfile.GFile(os.path.join(BUILDING_DOWNLOAD_PATH, f'{s2_token}_buildings.csv.gz'), 'rb') as gf: + with open(os.path.join(BUILDING_DOWNLOAD_PATH, f'{s2_token}_buildings.csv.gz'), 'rb') as gf: + # If the s2 cell is fully covered by country geometry then can skip + # filtering as we need all rows. + if prepared_region_geometry.covers(s2_cell_geometry): + with tempfile.NamedTemporaryFile(mode='w+b', delete=False) as tmp_f: + shutil.copyfileobj(gf, tmp_f) + return tmp_f.name + # Else take the slow path. + # NOTE: We read in chunks to save memory. + csv_chunks = pd.read_csv(gf, chunksize=2000000, dtype=object, compression='gzip', header=None) + tmp_f = tempfile.NamedTemporaryFile(mode='w+b', delete=False) + tmp_f.close() + for csv_chunk in csv_chunks: + points = gpd.GeoDataFrame(geometry=gpd.points_from_xy(csv_chunk[1], csv_chunk[0]), crs='EPSG:4326') + # sjoin 'within' was faster than using shapely's 'within' directly. + points = gpd.sjoin(points, region_df, predicate='within') + csv_chunk = csv_chunk.iloc[points.index] + csv_chunk.to_csv( + tmp_f.name, + mode='ab', + index=False, + header=False, + compression={ + 'method': 'gzip', + 'compresslevel': 1 + }) + return tmp_f.name + + except FileNotFoundError: + return None \ No newline at end of file diff --git a/src/maxarseg/assemble/gen_gdf.py b/src/maxarseg/assemble/gen_gdf.py new file mode 100644 index 0000000..2c77e9c --- /dev/null +++ b/src/maxarseg/assemble/gen_gdf.py @@ -0,0 +1,123 @@ +import pandas as pd +from pathlib import Path +import geopandas as gpd +from pyquadkey2 import quadkey +import pandas as pd +from shapely import geometry +import json +from maxarseg import assemble +import time +from shapely.wkt import loads + + +def intersecting_qks(bott_left_lon_lat: tuple, top_right_lon_lat: tuple, min_level=7, max_level=9): + """ + Get the quadkeys that intersect the given bbox. + Input: + bott_left_lon_lat: Tuple with the coordinates of the bottom left corner. Example: (-16.5, 13.5) + top_right_lon_lat: Tuple with the coordinates of the top right corner. 
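Two caveats about g_build_utils.py as it stands. First, it imports s2cell as s2 but still calls the s2geometry SWIG API (S2LatLngRect_FromPointPair, S2RegionCoverer, S2CellId_FromToken), which s2cell does not provide, so the S2 helpers only run with the s2geometry bindings installed under that name. Second, download_s2_token opens BUILDING_DOWNLOAD_PATH with the builtin open, which cannot read gs:// URLs (the original tf.io.gfile.GFile line is commented out), so every call currently ends in the FileNotFoundError branch and returns None. A minimal sketch of one way to restore the remote read, assuming fsspec with the gcsfs backend is available (an assumption, not something this patch declares):

import fsspec  # requires gcsfs for the gs:// scheme

def open_building_csv(s2_token: str):
    # stream the gzipped per-cell CSV from GCS instead of the local filesystem
    path = f'{BUILDING_DOWNLOAD_PATH}/{s2_token}_buildings.csv.gz'
    return fsspec.open(path, 'rb').open()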
Example: (-15.5, 14.5) + min_level: Minimum level of the quadkeys + max_level: Maximum level of the quadkeys + """ + bott_left_lon, bott_left_lat = bott_left_lon_lat + top_right_lon, top_right_lat = top_right_lon_lat + qk_bott_left = quadkey.from_geo((bott_left_lat, bott_left_lon), level=max_level) #lat, lon + qk_top_right = quadkey.from_geo((top_right_lat, top_right_lon), level=max_level) #lat, lon + hits = qk_bott_left.difference(qk_top_right) #all the max_level qks in the rectangle between the two corners + candidate_hits = set() + + for hit in hits: + current_qk = hit + for _ in range(min_level, max_level): + current_qk = current_qk.parent() + candidate_hits.add(current_qk) + hits.extend(candidate_hits) + return [int(str(hit)) for hit in hits] + +def qk_building_gdf(qk_list, csv_path = 'metadata/buildings_dataset_links.csv', dataset_crs = None, quiet = False): + """ + Returns a geodataframe with the buildings in the quadkeys given as input. + It downloads the dataset from a link in the buildings_dataset_links.csv file. + Coordinates are converted to the crs passed as input. + + Inputs: + qk_list: the list of quadkeys to look for in the csv + csv_path: path to the csv mapping each quadkey to its download link + dataset_crs: the crs in which to convert the coordinates of the buildings + quiet: if True, it doesn't print anything + """ + building_links_df = pd.read_csv(csv_path) + country_links = building_links_df[building_links_df['QuadKey'].isin(qk_list)] + + if not quiet: + print(f"\nBuildings: found {len(country_links)} links matching: {qk_list}") + + if len(country_links) == 0: + print("MS-Buildings: No buildings for this region") + return gpd.GeoDataFrame() + gdfs = [] + for _, row in country_links.iterrows(): + df = pd.read_json(row.Url, lines=True) + df["geometry"] = df["geometry"].apply(geometry.shape) + gdf_down = gpd.GeoDataFrame(df, crs=4326) + gdfs.append(gdf_down) + + gdfs = pd.concat(gdfs) + if dataset_crs is not None: #if the crs is passed, convert the coordinates + gdfs = gdfs.to_crs(dataset_crs) + return gdfs + +def google_building_gdf(event_name, bbox): + """ + Generate a GeoDataFrame containing Google building data filtered by a bounding box. + + Args: + - event_name (str): The name of the event. + - bbox (tuple): A tuple representing the bounding box coordinates in the format (minx, miny, maxx, maxy). + + Returns: + - gdf_filtered (GeoDataFrame): A GeoDataFrame containing the filtered Google building data. + """ + root = '/nfs/projects/overwatch/maxar-segmentation/google-open-buildings' + f_name = 'open_buildings_v3_' + event_name + '.csv' + file_path = Path(root) / f_name + df = pd.read_csv(file_path) #TODO: read only the needed columns and maybe filter on box confidence + df['geometry'] = df['geometry'].apply(loads) + gdf = gpd.GeoDataFrame(df, geometry='geometry', crs=4326) + gdf_filtered = assemble.filter.filter_gdf_w_bbox(gdf, bbox) + return gdf_filtered + +def get_region_road_gdf(region_name, roads_root = '/nfs/projects/overwatch/maxar-segmentation/microsoft-roads'): + #TODO: try to speed up the reading of the road data + """ + Get a gdf containing the roads of a region. + Input: + region_name: Name of the region.
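A small sketch of the quadkey arithmetic intersecting_qks relies on, using only the pyquadkey2 calls that appear above; the corner coordinates are illustrative:

from pyquadkey2 import quadkey

qk_bl = quadkey.from_geo((13.5, -16.5), level=9)   # bottom-left corner as (lat, lon)
qk_tr = quadkey.from_geo((14.5, -15.5), level=9)   # top-right corner as (lat, lon)
hits = qk_bl.difference(qk_tr)                     # level-9 quadkeys spanning the bbox rectangle
parents = {h.parent() for h in hits}               # coarser candidates, as intersecting_qks builds for lower levels
qk_ints = [int(str(h)) for h in hits]              # integer form, matching the 'QuadKey' csv column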
Example: 'AfricaWest-Full' + roads_root: Root directory of the roads datasets + """ + start_time = time.time() + print(f'Roads: reading roads for the whole {region_name} region') + if region_name[-4:] != '.tsv': + region_name = region_name + '.tsv' + + def custom_json_loads(s): + try: + return geometry.shape(json.loads(s)['geometry']) + except: + return geometry.LineString() + + roads_root = Path(roads_root) + if region_name != 'USA.tsv': + print('Roads: not in USA. Region name:', region_name) + region_road_df = pd.read_csv(roads_root/region_name, names =['country', 'geometry'], sep='\t') + else: + print('is USA:', region_name) + region_road_df = pd.read_csv(roads_root/region_name, names =['geometry'], sep='\t') + #region_road_df['geometry'] = region_road_df['geometry'].apply(json.loads).apply(lambda d: geometry.shape(d.get('geometry'))) + #slightly faster + region_road_df['geometry'] = region_road_df['geometry'].apply(custom_json_loads) + region_road_gdf = gpd.GeoDataFrame(region_road_df, crs=4326) + end_time = time.time() + elapsed_time = end_time - start_time + print(f"Elapsed time for reading roads: {elapsed_time:.2f} seconds") + return region_road_gdf \ No newline at end of file diff --git a/src/maxarseg/assemble/holders.py b/src/maxarseg/assemble/holders.py new file mode 100644 index 0000000..7231b25 --- /dev/null +++ b/src/maxarseg/assemble/holders.py @@ -0,0 +1,768 @@ +#Generic +from pathlib import Path +from tqdm import tqdm +import threading + +import os +import sys +from time import time, perf_counter +import numpy as np +import rasterio +from rasterio.features import rasterize +import warnings +import torch +from torch.utils.data import DataLoader +from torchgeo.datasets import stack_samples +import torchvision +import geopandas as gpd +from typing import Tuple +import matplotlib.pyplot as plt +from shapely import geometry +import json +import pandas as pd + +#My functions +from maxarseg.assemble import delimiters, filter, gen_gdf, names +from maxarseg.ESAM_segment import segment, segment_utils +from maxarseg.samplers import samplers, samplers_utils +from maxarseg.geo_datasets import geoDatasets +from maxarseg.configs import SegmentConfig, DetectConfig +from maxarseg.detect import detect, detect_utils +from maxarseg import output +from maxarseg import plotting_utils + +#GroundingDino +from groundingdino.util.inference import load_model as GD_load_model +from groundingdino.util.inference import predict as GD_predict + +#Deep forest +from deepforest import main + +#esam +from maxarseg.efficient_sam.build_efficient_sam import build_efficient_sam_vitt + +# Ignore all warnings +warnings.filterwarnings('ignore') + + +class Mosaic: + def __init__(self, + name, + event + ): + + #Mosaic + self.name = name + self.event = event + self.bbox, self.crs = delimiters.get_mosaic_bbox(self.event.name, + self.name, + self.event.maxar_metadata_path, + extra_mt=1000) + + self.when = list((self.event.maxar_root / self.event.name).glob('**/*'+self.name))[0].parts[-2] + self.tiles_paths = list((self.event.maxar_root / self.event.name / self.when / self.name).glob('*.tif')) + self.tiles_num = len(self.tiles_paths) + + #Check if img is bw + with rasterio.open(self.tiles_paths[0]) as src: + num_bands = src.count + if num_bands == 1: + print(f'Image {self.tiles_paths[0]} is in black and white', ) + self.is_rgb = False + else: + self.is_rgb = True + + #Roads + self.road_gdf = None + self.proj_road_gdf = None + self.road_num = None + + #Buildings + self.build_gdf = None + self.proj_build_gdf = None + 
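# Road and building GeoDataFrames start as None and are filled lazily by
# set_road_gdf() / set_build_gdf(): a Mosaic stays cheap to construct and the
# expensive vector reads happen only for mosaics that actually get segmented.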
self.sindex_proj_build_gdf = None + self.build_num = None + + #models + self.GD_model = None + self.DF_model = None + self.ESAM_model = None + + def __str__(self) -> str: + return self.name + + def set_road_gdf(self): + if self.event.road_gdf is None: + self.event.set_road_gdf() + + self.road_gdf = filter.filter_gdf_w_bbox(self.event.road_gdf, self.bbox) + self.proj_road_gdf = self.road_gdf.to_crs(self.crs) + self.road_num = len(self.road_gdf) + print(f'Roads in {self.name} mosaic: {self.road_num}') + + def set_build_gdf(self): + qk_hits = gen_gdf.intersecting_qks(*self.bbox) + self.build_gdf = gen_gdf.qk_building_gdf(qk_hits, csv_path = self.event.buildings_ds_links_path) + + if len(self.build_gdf) == 0: #here use google buildings + try: + self.build_gdf = gen_gdf.google_building_gdf(event_name=self.event.name, bbox=self.bbox) + if len(self.build_gdf) == 0: + self.build_gdf = None + self.proj_build_gdf = None + print('No buildings found for this mosaic either in Ms Buildings or in Google Open Buildings') + return False + except Exception as e: + print('No buildings found for this mosaic either in Ms Buildings or in Google Open Buildings') + return False + + self.proj_build_gdf = self.build_gdf.to_crs(self.crs) + self.sindex_proj_build_gdf = self.proj_build_gdf.sindex + + # Method no more used + def seg_road_tile(self, tile_path, aoi_mask) -> np.ndarray: + with rasterio.open(tile_path) as src: + transform = src.transform + tile_h = src.height + tile_w = src.width + tile_shape = (tile_h, tile_w) + + cfg = self.event.cfg + #aoi_mask = rasterize(tile_aoi_gdf.geometry, out_shape = tile_shape, fill=False, default_value=True, transform = transform) + + query_bbox_poly = samplers_utils.path_2_tile_aoi(tile_path) + road_lines = self.proj_road_gdf[self.proj_road_gdf.geometry.intersects(query_bbox_poly)] + + if len(road_lines) != 0: + buffered_lines = road_lines.geometry.buffer(cfg.get('segmentation/roads/road_width_mt')) + road_mask = rasterize(buffered_lines, out_shape=(tile_h, tile_w), transform=transform) + road_mask = np.where(aoi_mask, road_mask, False) + else: + print('No roads') + road_mask = np.zeros((tile_h, tile_w)) + return road_mask #shape: (h, w) + + def polyg_road_tile(self, tile_aoi_gdf: gpd.GeoDataFrame) -> gpd.GeoSeries: + cfg = self.event.cfg + road_lines = samplers_utils.filter_road_gdf_vs_aois_gdf(self.proj_road_gdf, tile_aoi_gdf) + if len(road_lines) != 0: + buffered_lines = road_lines.geometry.buffer(cfg.get('segmentation/roads/road_width_mt')) + intersected_buffered_lines_ser = samplers_utils.intersection_road_gdf_vs_aois_gdf(buffered_lines, tile_aoi_gdf) + else : + print('No roads') + intersected_buffered_lines_ser = gpd.GeoSeries() + return intersected_buffered_lines_ser + + def detect_trees_tile_DeepForest(self, tile_path) -> Tuple[np.ndarray, ...]: + cfg = self.event.cfg + if self.DF_model is None: + self.DF_model = main.deepforest(config_args = { 'devices' : cfg.get('models/df/device'), + 'retinanet': {'score_thresh': cfg.get('models/df/box_threshold')}, + 'accelerator': 'cuda', + 'batch_size': cfg.get('models/df/bs')}) + self.DF_model.use_release() + + boxes_df = self.DF_model.predict_tile(tile_path, + return_plot = False, + patch_size = cfg.get('models/df/size'), + patch_overlap = cfg.get('models/df/patch_overlap')) + + + boxes = boxes_df.iloc[:, :4].values + score = boxes_df['score'].values + + return boxes, score + + def noTGeo_detect_trees_tile_GD(self, tile_path, tile_aoi_gdf: gpd.GeoDataFrame, aoi_mask) -> Tuple[np.ndarray, np.ndarray]: + cfg = self.event.cfg + 
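# NOTE: the GroundingDINO weights are re-loaded from disk on every tile and the
# model is deleted at the end of the method (see the TODO below); caching it on
# the Event, as done for self.event.efficient_sam, would avoid the repeated
# load when GPU memory allows.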
#load model + model = GD_load_model(cfg.get('models/gd/config_file_path'), cfg.get('models/gd/weight_path')).to(cfg.get('models/gd/device')) + print('\n- GD model device:', next(model.parameters()).device) + + dataset = geoDatasets.SingleTileDataset(str(tile_path), tile_aoi_gdf, aoi_mask) + sampler = samplers.SinglePatchSampler(dataset, patch_size=cfg.get('models/esam/size'), stride=cfg.get('models/esam/stride')) + dataloader = DataLoader(dataset, sampler=sampler, collate_fn=geoDatasets.single_sample_collate_fn) + + glb_tile_tree_boxes = torch.empty(0, 4) + all_logits = torch.empty(0) + + for batch in tqdm(dataloader, total = len(dataloader), desc="Detecting Trees with GDino"): + img_b = batch['image'].permute(0,2,3,1).numpy().astype('uint8') + + for img, img_top_left_index in zip(img_b, batch['top_lft_index']): + image_transformed = detect_utils.GD_img_load(img) + tree_boxes, logits, phrases = GD_predict(model, + image_transformed, + cfg.get('models/gd/text_prompt'), + cfg.get('models/gd/box_threshold'), + cfg.get('models/gd/text_threshold'), + device = cfg.get('models/gd/device')) + + rel_xyxy_tree_boxes = detect_utils.GDboxes2SamBoxes(tree_boxes, img_shape = cfg.get('models/gd/size')) + top_left_xy = np.array([img_top_left_index[1], #from an index to xyxy + img_top_left_index[0], + img_top_left_index[1], + img_top_left_index[0]]) + + #turn boxes from patch xyxy coords to global xyxy coords + glb_xyxy_tree_boxes = rel_xyxy_tree_boxes + top_left_xy + + glb_tile_tree_boxes = np.concatenate((glb_tile_tree_boxes, glb_xyxy_tree_boxes)) + all_logits = np.concatenate((all_logits, logits)) + + #del model and free GPU + #TODO: if enough space in GPU, keep the model loaded + del model + + return glb_tile_tree_boxes, all_logits + + def detect_trees_tile_GD(self, tile_path, tile_aoi_gdf: gpd.GeoDataFrame, aoi_mask) -> Tuple[np.ndarray, np.ndarray]: + cfg = self.event.cfg + #load model + model = GD_load_model(cfg.get('models/gd/config_file_path'), cfg.get('models/gd/weight_path')).to(cfg.get('models/gd/device')) + print('\n- GD model device:', next(model.parameters()).device) + + dataset = geoDatasets.MxrSingleTileNoEmpty(str(tile_path), tile_aoi_gdf, aoi_mask=aoi_mask) + sampler = samplers.BatchGridGeoSampler(dataset, batch_size=cfg.get('models/gd/bs'), size=cfg.get('models/gd/size'), stride=cfg.get('models/gd/stride')) + dataloader = DataLoader(dataset , batch_sampler=sampler, collate_fn=stack_samples) + + glb_tile_tree_boxes = torch.empty(0, 4) + all_logits = torch.empty(0) + + for batch in tqdm(dataloader, total = len(dataloader), desc="Detecting Trees with GDino"): + img_b = batch['image'].permute(0,2,3,1).numpy().astype('uint8') + + for img, img_top_left_index in zip(img_b, batch['top_lft_index']): + image_transformed = detect_utils.GD_img_load(img) + tree_boxes, logits, phrases = GD_predict(model, + image_transformed, + cfg.get('models/gd/text_prompt'), + cfg.get('models/gd/box_threshold'), + cfg.get('models/gd/text_threshold'), + device = cfg.get('models/gd/device')) + + rel_xyxy_tree_boxes = detect_utils.GDboxes2SamBoxes(tree_boxes, img_shape = cfg.get('models/gd/size')) + top_left_xy = np.array([img_top_left_index[1], #from an index to xyxy + img_top_left_index[0], + img_top_left_index[1], + img_top_left_index[0]]) + + #turn boxes from patch xyxy coords to global xyxy coords + glb_xyxy_tree_boxes = rel_xyxy_tree_boxes + top_left_xy + + glb_tile_tree_boxes = np.concatenate((glb_tile_tree_boxes, glb_xyxy_tree_boxes)) + all_logits = np.concatenate((all_logits, logits)) + + #del model and free 
GPU + #TODO: if enough space in GPU, keep the model loaded + del model + + return glb_tile_tree_boxes, all_logits + + def detect_trees_tile(self, tile_path, tile_aoi_gdf, aoi_mask, georef = True): + with rasterio.open(tile_path) as src: + to_xy = src.xy + crs = src.crs + + cfg = self.event.cfg + if cfg.get('detection/trees/use_DF'): + deepForest_glb_tile_tree_boxes, deepForest_scores = self.detect_trees_tile_DeepForest(tile_path) + if cfg.get('detection/trees/use_GD'): + GD_glb_tile_tree_boxes, GD_scores = self.noTGeo_detect_trees_tile_GD(tile_path, tile_aoi_gdf, aoi_mask) + + if cfg.get('detection/trees/use_DF') and cfg.get('detection/trees/use_GD'): + glb_tile_tree_boxes = np.concatenate((GD_glb_tile_tree_boxes, deepForest_glb_tile_tree_boxes)) + glb_tile_tree_scores = np.concatenate((GD_scores, deepForest_scores)) + elif cfg.get('detection/trees/use_DF'): + glb_tile_tree_boxes = deepForest_glb_tile_tree_boxes + glb_tile_tree_scores = deepForest_scores + elif cfg.get('detection/trees/use_GD'): + glb_tile_tree_boxes = GD_glb_tile_tree_boxes + glb_tile_tree_scores = GD_scores + + print('Number of tree boxes before filtering: ', len(glb_tile_tree_boxes)) + + keep_ix_box_area = detect_utils.filter_on_box_area_mt2(glb_tile_tree_boxes, + max_area_mt2 = cfg.get('detection/trees/max_area_boxes_mt2'), + box_format = 'xyxy') + glb_tile_tree_boxes = glb_tile_tree_boxes[keep_ix_box_area] + glb_tile_tree_scores = glb_tile_tree_scores[keep_ix_box_area] + print('boxes area filtering: ', len(keep_ix_box_area) - np.sum(keep_ix_box_area), 'boxes removed') + + keep_ix_box_ratio = detect_utils.filter_on_box_ratio(glb_tile_tree_boxes, + min_edges_ratio = cfg.get('detection/trees/min_ratio_GD_boxes_edges'), + box_format = 'xyxy') + glb_tile_tree_boxes = glb_tile_tree_boxes[keep_ix_box_ratio] + glb_tile_tree_scores = glb_tile_tree_scores[keep_ix_box_ratio] + print('box edge ratio filtering:', len(keep_ix_box_ratio) - np.sum(keep_ix_box_ratio), 'boxes removed') + + keep_ix_nms = torchvision.ops.nms(torch.tensor(glb_tile_tree_boxes), torch.tensor(glb_tile_tree_scores.astype(np.float64)), cfg.get('detection/trees/nms_threshold')) + len_bf_nms = len(glb_tile_tree_boxes) + glb_tile_tree_boxes = glb_tile_tree_boxes[keep_ix_nms] + glb_tile_tree_scores = glb_tile_tree_scores[keep_ix_nms] + print('nms filtering:', len_bf_nms - len(keep_ix_nms), 'boxes removed') + + if len(glb_tile_tree_boxes.shape) == 1: + glb_tile_tree_boxes = np.expand_dims(glb_tile_tree_boxes, axis = 0) + if glb_tile_tree_scores.size == 1: + glb_tile_tree_scores = np.expand_dims(glb_tile_tree_scores, axis = 0) + + if georef: #create a gdf with the boxes in proj coordinates + for i, box in enumerate(glb_tile_tree_boxes): + #need to invert x and y to go from col row to row col index + try: + glb_tile_tree_boxes[i] = np.array(to_xy(box[1], box[0]) + to_xy(box[3], box[2])) + # catch and print + except Exception as e: + print(f'Error in box {i}: {e}') + cols = {'score': list(glb_tile_tree_scores), + 'geometry': [samplers_utils.xyxyBox2Polygon(box) for box in glb_tile_tree_boxes]} + + gdf = gpd.GeoDataFrame(cols, crs = crs) + + if self.event.cross_wlb == True: + #keep only tree detections that are inside tile_aoi_gdf + gdf = gdf[gdf['geometry'].apply(lambda x: tile_aoi_gdf.intersects(x).any())] + print("Not in aoi:", len(glb_tile_tree_scores) - len(gdf), "boxes removed") + print('Number of tree boxes after all filtering: ', len(gdf)) + return gdf + + return glb_tile_tree_boxes #xyxy format, global (tile) index + + def 
noTGeo_seg_glb_tree_and_build_tile(self, tile_path: str, tile_aoi_gdf: gpd.GeoDataFrame, aoi_mask: np.ndarray): + cfg = self.event.cfg + if self.build_gdf is None: #set buildings at mosaic level + self.set_build_gdf() + + tile_building_gdf = self.proj_build_gdf.iloc[self.sindex_proj_build_gdf.query(samplers_utils.path_2_tile_aoi(tile_path))] + + trees_gdf = self.detect_trees_tile(tile_path, tile_aoi_gdf = tile_aoi_gdf, aoi_mask = aoi_mask, georef = True) + + dataset = geoDatasets.SingleTileDataset(str(tile_path), tile_aoi_gdf, aoi_mask) + sampler = samplers.SinglePatchSampler(dataset, patch_size=cfg.get('models/esam/size'), stride=cfg.get('models/esam/stride')) + dataloader = DataLoader(dataset, sampler=sampler, collate_fn=geoDatasets.single_sample_collate_fn) + + canvas = np.zeros((2,) + dataset.tile_shape, dtype=np.float32) # dim (2, h_tile, w_tile). Dim 0 is: tree, build. Patch logits are summed here and averaged via `weights` after the loop + weights = np.zeros(dataset.tile_shape, dtype=np.float32) # dim (h_tile, w_tile), per-pixel count of contributing patches + for _, batch in tqdm(enumerate(dataloader), total = len(dataloader), desc = "Segmenting"): + original_img_tsr = batch['image'] + + #TREES + #get the tree boxes in batches and the number of trees for each image + #tree_boxes_b is a list of arrays of shape (n, 4), where n is the number of tree boxes + if len(trees_gdf) == 0: + tree_boxes_b = [np.empty((0, 4))] + num_trees4img = [0] + else: + tree_boxes_b, num_trees4img = detect.get_batch_boxes(batch['bbox'], + proj_gdf = trees_gdf, + dataset_res = dataset.res, + ext_mt = 0) + + #BUILDINGS + #get the building boxes in batches and the number of buildings for each image + #building_boxes_b is a list of arrays of shape (n, 4), where n is the number of building boxes + building_boxes_b, num_build4img = detect.get_refined_batch_boxes(batch['bbox'], + proj_gdf = tile_building_gdf, + dataset_res = dataset.res, + ext_mt = cfg.get('detection/buildings/ext_mt_build_box')) + + if num_trees4img[0] > 0 or num_build4img[0] > 0: + + max_detect = max(num_trees4img + num_build4img) + + #obtain the right input for the ESAM model (trees + buildings) + input_points, input_labels = segment_utils.get_input_pts_and_lbs(tree_boxes_b, building_boxes_b, max_detect) + + # segment the image, producing one mask per prompted box; + # num_parall_queries bounds how many prompts run at once to fit in GPU memory + tree_build_mask = segment.ESAM_from_inputs_fast(original_img_tsr = original_img_tsr, + input_points = torch.from_numpy(input_points), + input_labels = torch.from_numpy(input_labels), + num_tree_boxes= num_trees4img, + efficient_sam = self.event.efficient_sam, + device = cfg.get('models/esam/device'), + num_parall_queries = cfg.get('models/esam/num_parall_queries')) + + else: + #print('no prompts in patch, skipping...') + tree_build_mask = np.zeros((2, *original_img_tsr.shape[2:]), dtype = np.float32) #(2, h, w) + + canvas, weights = segment_utils.write_canvas_geo_window(canvas = canvas, + weights = weights, + patch_masks_b = np.expand_dims(tree_build_mask, axis=0), + top_lft_indexes = batch['top_lft_index'], + ) + + canvas = np.divide(canvas, weights, out=np.zeros_like(canvas), where=weights!=0) #average overlapping patches + canvas = np.greater(canvas, 0) #turn logits into bool + canvas = np.where(aoi_mask, canvas, False) + return canvas + + def seg_glb_tree_and_build_tile_fast(self, tile_path: str, tile_aoi_gdf: gpd.GeoDataFrame, aoi_mask: np.ndarray): + cfg = self.event.cfg + if self.build_gdf is None: #set buildings at mosaic level + self.set_build_gdf() + + tile_building_gdf = self.proj_build_gdf.iloc[self.sindex_proj_build_gdf.query(samplers_utils.path_2_tile_aoi(tile_path))] + + trees_gdf = self.detect_trees_tile(tile_path, tile_aoi_gdf = tile_aoi_gdf, aoi_mask = aoi_mask, georef = True) + + dataset = geoDatasets.MxrSingleTileNoEmpty(str(tile_path), tile_aoi_gdf, aoi_mask) + sampler = samplers.BatchGridGeoSampler(dataset, + batch_size=cfg.get('models/esam/bs'), + size=cfg.get('models/esam/size'), + stride=cfg.get('models/esam/stride')) + dataloader = DataLoader(dataset, batch_sampler=sampler, collate_fn=stack_samples) + + canvas = np.zeros((2,) + samplers_utils.tile_sizes(dataset), dtype=np.float32) # dim (2, h_tile, w_tile). Dim 0 is: tree, build + weights = np.zeros(samplers_utils.tile_sizes(dataset), dtype=np.float32) # dim (h_tile, w_tile) + for batch_ix, batch in tqdm(enumerate(dataloader), total = len(dataloader), desc = "Segmenting"): + original_img_tsr = batch['image'] + + #TREES + #get the tree boxes in batches and the number of trees for each image + #tree_boxes_b is a list of arrays of shape (n, 4), where n is the number of tree boxes + if len(trees_gdf) == 0: + tree_boxes_b = [np.empty((0, 4))] + num_trees4img = [0] + else: + tree_boxes_b, num_trees4img = detect.get_batch_boxes(batch['bbox'], + proj_gdf = trees_gdf, + dataset_res = dataset.res, + ext_mt = 0) + + #BUILDINGS + #get the building boxes in batches and the number of buildings for each image + #building_boxes_b is a list of arrays of shape (n, 4), where n is the number of building boxes + building_boxes_b, num_build4img = detect.get_refined_batch_boxes(batch['bbox'], + proj_gdf = tile_building_gdf, + dataset_res = dataset.res, + ext_mt = cfg.get('detection/buildings/ext_mt_build_box')) + + if num_trees4img[0] > 0 or num_build4img[0] > 0: + + max_detect = max(num_trees4img + num_build4img) + + #obtain the right input for the ESAM model (trees + buildings) + input_points, input_labels = segment_utils.get_input_pts_and_lbs(tree_boxes_b, building_boxes_b, max_detect) + + # segment the image, producing one mask per prompted box; + # num_parall_queries bounds how many prompts run at once to fit in GPU memory + tree_build_mask = segment.ESAM_from_inputs_fast(original_img_tsr = original_img_tsr, + input_points = torch.from_numpy(input_points), + input_labels = torch.from_numpy(input_labels), + num_tree_boxes= num_trees4img, + efficient_sam = self.event.efficient_sam, + device = cfg.get('models/esam/device'), + num_parall_queries = cfg.get('models/esam/num_parall_queries')) + + else: + #print('no prompts in patch, skipping...') + tree_build_mask = np.zeros((2, *original_img_tsr.shape[2:]), dtype = np.float32) #(2, h, w) + + canvas, weights = segment_utils.write_canvas_geo_window(canvas = canvas, + weights = weights, + patch_masks_b = np.expand_dims(tree_build_mask, axis=0), + top_lft_indexes = batch['top_lft_index'], + ) + + canvas = np.divide(canvas, weights, out=np.zeros_like(canvas), where=weights!=0) + canvas = np.greater(canvas, 0) #turn logits into bool + canvas = np.where(aoi_mask, canvas, False) + return canvas + + def segment_tile(self, tile_path, out_dir_root, overwrite = False, separate_masks = True): + """ + Segment a single tile and write its tree/building/road outputs under out_dir_root. + overwrite: if False, skip the tile when one of its output files already exists. + separate_masks: if True, write one file per class, otherwise a single multi-class file. + """ + + #create output folder if it does not exist + tile_path = Path(tile_path) + out_dir_root = Path(out_dir_root) + out_names = output.gen_names(tile_path, separate_masks) + (out_dir_root / out_names[0]).parent.mkdir(parents=True, exist_ok=True) + if not overwrite:
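# NOTE: with overwrite=False the tile is skipped as soon as one of its expected
# output files already exists; pass overwrite=True (or remove stale outputs) to
# re-run the segmentation for that tile.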
for out_name in out_names: + if (out_dir_root / out_name).exists(): + print(f'File {out_dir_root / out_name} already exists') + return True + #assert not (out_dir_root / out_name).exists(), f'File {out_dir_root / out_name} already exists' + + tile_aoi_gdf = samplers_utils.path_2_tile_aoi_no_water(tile_path, self.event.filtered_wlb_gdf) + + if tile_aoi_gdf.iloc[0].geometry.is_empty: #tile completely on water + pass + # print("\nSave an empty mask") + # thread = threading.Thread(target=self.save_all_blank, + # args=(out_dir_root, tile_path, out_names, separate_masks)) + #self.save_all_blank(out_dir_root, tile_path, out_names, separate_masks) + + else: + #retrieve roads and build at mosaic level if not already done + if self.build_gdf is None: + response = self.set_build_gdf() + if response == False: + return False + if self.road_gdf is None: + self.set_road_gdf() + with rasterio.open(tile_path) as src: + transform = src.transform + tile_shape = (src.height, src.width) + aoi_mask = rasterize(tile_aoi_gdf.geometry, out_shape = tile_shape, fill=False, default_value=True, transform = transform) + + #tree_and_build_mask = self.seg_glb_tree_and_build_tile_fast(tile_path, tile_aoi_gdf, aoi_mask) + tree_and_build_mask = self.noTGeo_seg_glb_tree_and_build_tile(tile_path, tile_aoi_gdf, aoi_mask) + + #thread = threading.Thread(target=self.postprocess_and_save, + # args=(tree_and_build_mask, out_dir_root, tile_path, out_names, tile_aoi_gdf, aoi_mask,separate_masks)) + + #thread.start() + self.postprocess_and_save(tree_and_build_mask, out_dir_root, tile_path, out_names, tile_aoi_gdf, aoi_mask, separate_masks) + + return True + #TODO: not working but should be faster + def seg_and_poly_road_tile(self, tile_path, tile_aoi_gdf): + cfg = self.event.cfg + with rasterio.open(tile_path) as src: + transform = src.transform + tile_h = src.height + tile_w = src.width + + intersected_buffered_lines_ser = self.polyg_road_tile(tile_aoi_gdf) + if len(intersected_buffered_lines_ser) != 0: + road_mask = rasterize(intersected_buffered_lines_ser, out_shape=(tile_h, tile_w), transform=transform) + else: #no roads + print('No roads') + road_mask = np.zeros((tile_h, tile_w)) + + return road_mask, intersected_buffered_lines_ser + + # function that wraps from postprocessing to be used in a separate thread + def postprocess_and_save(self, tree_and_build_mask, out_dir_root, tile_path, out_names, tile_aoi_gdf, aoi_mask, separate_masks = True): + cfg = self.event.cfg + road_mask = self.seg_road_tile(tile_path, aoi_mask) + road_series = self.polyg_road_tile(tile_aoi_gdf) + #road_mask, road_series = self.seg_and_poly_road_tile(tile_path, tile_aoi_gdf) + tree_and_build_mask_copy = tree_and_build_mask.copy() + overlap_masks = np.concatenate((np.expand_dims(road_mask, axis=0), tree_and_build_mask) , axis = 0) + no_overlap_masks = segment_utils.rmv_mask_overlap(overlap_masks) + if cfg.get('segmentation/general/clean_mask'): + print('Cleaning the masks: holes_area_th = ', cfg.get('segmentation/general/rmv_holes_area_th'), 'small_obj_area = ', cfg.get('segmentation/general/rmv_small_obj_area_th')) + no_overlap_masks = segment_utils.clean_masks(no_overlap_masks, + area_threshold = cfg.get('segmentation/general/rmv_holes_area_th'), + min_size = cfg.get('segmentation/general/rmv_small_obj_area_th')) + print('Mask cleaning done') + + output.masks2Tifs(tile_path, + no_overlap_masks, + out_names = out_names, + separate_masks = separate_masks, + out_dir_root = out_dir_root) + try: + output.masks2parquet(tile_path, + tree_and_build_mask_copy, + 
out_dir_root=out_dir_root, + out_names=out_names, road_series=road_series) + except Exception as e: + print(f'Error in saving parquet: {e}') + + def save_all_blank(self, out_dir_root, tile_path, out_names, separate_masks = True): + tile_h, tile_w = samplers_utils.tile_path_2_tile_size(tile_path) + masks = np.zeros((3, tile_h, tile_w)).astype(bool) + output.masks2Tifs(tile_path, + masks, + out_names = out_names, + separate_masks = separate_masks, + out_dir_root = out_dir_root) + + def segment_all_tiles(self, out_dir_root, time_per_tile = None): + if time_per_tile is None: #avoid a shared mutable default list + time_per_tile = [] + mos_seg_tile = 1 + for tile_path in self.tiles_paths: + print('') + print(f'Starting segmenting tile {tile_path}, ({mos_seg_tile}/{self.tiles_num}), ({self.event.segmented_tiles}/{self.event.total_tiles})') + print('') + start_time = perf_counter() + response = self.segment_tile(tile_path, out_dir_root=out_dir_root, separate_masks=False) + end_time = perf_counter() + if response == False: #building footprints are not available for the mosaic, move on to the next mosaic + return time_per_tile, False + execution_time = end_time - start_time + time_per_tile.append(execution_time) + print(f'Finished segmenting tile {tile_path} in {execution_time:.2f} seconds') + print(f'Average time per tile: {np.mean(time_per_tile):.2f} seconds') + self.event.segmented_tiles += 1 + mos_seg_tile += 1 + return time_per_tile, True + +class Event: + def __init__(self, + name, + cfg, + maxar_root = '/nfs/projects/overwatch/maxar-segmentation/maxar-open-data', + maxar_metadata_path = './metadata/from_github_maxar_metadata/datasets', + region = 'infer'): + #Configs + self.cfg = cfg + self.time_per_tile = [] + + #esam + self.efficient_sam = build_efficient_sam_vitt(os.path.join(self.cfg.get('models/esam/root_path'), 'weights/efficient_sam_vitt.pt')).to(self.cfg.get('models/esam/device')) + + #gdino + #self.gdino = + + #Paths + self.maxar_root = Path(maxar_root) + self.buildings_ds_links_path = Path('./metadata/buildings_dataset_links.csv') + self.maxar_metadata_path = Path(maxar_metadata_path) + + #Event + self.name = name + self.when = cfg.get('event/when') + self.region_name = names.get_region_name(self.name) if region == 'infer' else region + self.bbox = delimiters.get_event_bbox(self.name, extra_mt=1000) #TODO: could be optimized by exploiting the mosaic bboxes + self.all_mosaics_names = names.get_mosaics_names(self.name, self.maxar_root, self.when) + + self.wlb_gdf = gpd.read_file('./metadata/eventi_confini_complete.gpkg') + self.filtered_wlb_gdf = self.wlb_gdf[self.wlb_gdf['event names'] == self.name] + if self.filtered_wlb_gdf.iloc[0].geometry is None: + print('Event entirely on land') + self.cross_wlb = False + self.filtered_wlb_gdf = None + else: + print('Event crosses a water boundary') + self.cross_wlb = True + + print(f'Creating event: {self.name}\nRegion: {self.region_name}\nMosaics: {self.all_mosaics_names}') + #Roads + self.road_gdf = None + + #Mosaics + self.mosaics = {} + + #Init mosaics + for m_name in self.all_mosaics_names: + self.mosaics[m_name] = Mosaic(m_name, self) + + self.total_tiles = sum([mosaic.tiles_num for mosaic in self.mosaics.values()]) + self.segmented_tiles = 1 + + def set_seg_config(self, seg_config): + self.seg_config = seg_config + + #Roads methods + def set_road_gdf(self): #set road_gdf for the event + region_road_gdf = gen_gdf.get_region_road_gdf(self.region_name) + self.road_gdf = filter.filter_gdf_w_bbox(region_road_gdf, self.bbox) + + def fast_set_road_gdf(self, roads_root = '/nfs/projects/overwatch/maxar-segmentation/microsoft-roads'): + """ + 
Read the roads of the event's region in chunks and keep only those intersecting the event bbox. + Input: + roads_root: Root directory of the roads datasets (the file is chosen via self.region_name, e.g. 'AfricaWest-Full') + """ + start_time = time() #note: holders.py does `from time import time`, so call time() directly + print(f'Roads: reading roads for the whole {self.region_name} region') + region_name = self.region_name if self.region_name.endswith('.tsv') else self.region_name + '.tsv' + + def custom_json_loads(s): + try: + return geometry.shape(json.loads(s)['geometry']) + except Exception: + return geometry.LineString() + + chunksize = 100_000 + + (minx, miny), (maxx, maxy) = self.bbox + vertices = [(minx, miny), (maxx, miny), (maxx, maxy), (minx, maxy), (minx, miny)] #lon lat + query_bbox_poly = geometry.Polygon(vertices) + + roads_root = Path(roads_root) + names_cols = ['geometry'] if region_name == 'USA.tsv' else ['country', 'geometry'] #the USA file has no country column + print('Roads: region file:', region_name) + filtered_chunks = [] + #read the tsv in chunks and keep, per chunk, only the roads hitting the event bbox + for chunk in pd.read_csv(roads_root/region_name, names=names_cols, sep='\t', chunksize=chunksize): + chunk['geometry'] = chunk['geometry'].apply(custom_json_loads) + chunk_gdf = gpd.GeoDataFrame(chunk, crs=4326) + hits = chunk_gdf.sindex.query(query_bbox_poly) + filtered_chunks.append(chunk_gdf.iloc[hits]) + region_road_gdf = gpd.GeoDataFrame(pd.concat(filtered_chunks), crs=4326) if filtered_chunks else gpd.GeoDataFrame() + elapsed_time = time() - start_time + print(f"Elapsed time for reading roads: {elapsed_time:.2f} seconds") + return region_road_gdf + + def set_mos_road_gdf(self, mosaic_name): #set road_gdf for the mosaic + if self.road_gdf is None: + self.set_road_gdf() + + self.mosaics[mosaic_name].set_road_gdf() + + def set_all_mos_road_gdf(self): #set road_gdf for all the mosaics + for mosaic_name, mosaic in self.mosaics.items(): + if mosaic.road_gdf is None: + self.set_mos_road_gdf(mosaic_name) + + #Buildings methods + def set_build_gdf_in_mos(self, mosaic_name): + self.mosaics[mosaic_name].set_build_gdf() + + def set_build_gdf_all_mos(self): + for mosaic_name, mosaic in self.mosaics.items(): + if mosaic.build_gdf is None: + self.set_build_gdf_in_mos(mosaic_name) + + def get_roi_polygon(self, wkt: bool = False): + poly = samplers_utils.xyxy_2_Polygon(self.bbox) + if wkt: + return poly.wkt + else: + return poly + + def get_mosaic(self, mosaic_name): + return self.mosaics[mosaic_name] + + #Segment methods + def seg_all_mosaics(self, out_dir_root): + mos_count = 1 + for __, mosaic in self.mosaics.items(): + if mosaic.is_rgb: + print(f"Start segmenting mosaic: {mosaic.name}, ({mos_count}/{len(self.mosaics)})") + times, response = mosaic.segment_all_tiles(out_dir_root=out_dir_root, time_per_tile=self.time_per_tile) #appends in place to self.time_per_tile; extending again here would double-count + mos_count += 1 + if response == False: + print(f'Building footprints not available for mosaic: {mosaic.name}. Proceeding to next mosaic...') + self.segmented_tiles += mosaic.tiles_num + continue + else: + print(f"First image of mosaic {mosaic.name} is not rgb, we assume the whole mosaic is not rgb. Skipping it...") + mos_count += 1 + self.segmented_tiles += mosaic.tiles_num + continue + + def seg_mos_by_keys(self, keys, out_dir_root): + mos_count = 1 + for mos_name in keys: + mosaic = self.mosaics[mos_name] + if mosaic.is_rgb: + print(f"Start segmenting mosaic: {mosaic.name}, ({mos_count}/{len(keys)})") + times, response = mosaic.segment_all_tiles(out_dir_root=out_dir_root, time_per_tile=self.time_per_tile) #appends in place to self.time_per_tile + mos_count += 1 + if response == False: + print(f'Building footprints not available for mosaic: {mosaic.name}. Proceeding to next mosaic...') + self.segmented_tiles += mosaic.tiles_num + continue + else: + print(f"First image of mosaic {mosaic.name} is not rgb, we assume the whole mosaic is not rgb. Skipping it...") + mos_count += 1 + self.segmented_tiles += mosaic.tiles_num + continue \ No newline at end of file diff --git a/src/maxarseg/assemble/names.py b/src/maxarseg/assemble/names.py new file mode 100644 index 0000000..3f79ded --- /dev/null +++ b/src/maxarseg/assemble/names.py @@ -0,0 +1,61 @@ +from pathlib import Path +import pandas as pd +import glob +import os + + +def get_region_name(event_name, metadata_root = './metadata'): + """ + Get the region associated with the input event. + It is based on the event_id2State2Region.csv file. + Input: + event_name: Example: 'southafrica-flooding22' + Output: + region name: Example: 'AfricaSouth-Full' + """ + + metadata_root = Path(metadata_root) + df = pd.read_csv(metadata_root / 'evet_id2State2Region.csv') + return df[df['event_id'] == event_name]['region'].values[0] + + +def get_all_events(data_root = '/nfs/projects/overwatch/maxar-segmentation/maxar-open-data'): + """ + Get all the event names in the data_root folder. + Input: + data_root: Example: '/nfs/projects/overwatch/maxar-segmentation/maxar-open-data' + Output: + all_events: List of events. + """ + + data_root = Path(data_root) + all_events = [] + for event_name in glob.glob('**/*.tif', recursive = True, root_dir=data_root): + if event_name.split('/')[0] not in all_events: + all_events.append(event_name.split('/')[0]) + return sorted(list(all_events)) + + +def get_mosaics_names(event_name, data_root = '/nfs/projects/overwatch/maxar-segmentation/maxar-open-data', when = None): + """ + Get all the mosaic names for an event. + Input: + event_name: Example: 'Gambia-flooding-8-11-2022' + data_root: Example: '/nfs/projects/overwatch/maxar-segmentation/maxar-open-data' + when: 'pre' or 'post'. Default matches both + Output: + all_mosaics: List of mosaic names.
Example: ['104001007A565700', '104001007A565800'] + """ + + data_root = Path(data_root) + all_mosaics = [] + if when in ['pre', 'post']: + for mosaic_name in glob.glob('*', root_dir=data_root/event_name/when): + all_mosaics.append(mosaic_name) + elif when is None or when == 'None': + for mosaic_name in glob.glob('**/*', root_dir=data_root/event_name): + #all_mosaics.append(mosaic_name.split('/')[1]) + all_mosaics.append(os.path.split(mosaic_name)[1]) + else: + raise ValueError('Variable when must be: "pre", "post" or "None"') + return all_mosaics \ No newline at end of file diff --git a/src/maxarseg/configs.py b/src/maxarseg/configs.py new file mode 100644 index 0000000..6b9e18a --- /dev/null +++ b/src/maxarseg/configs.py @@ -0,0 +1,214 @@ + +from pathlib import Path +import os +import sys +import yaml + + +from groundingdino.util.inference import load_model as GD_load_model +from maxarseg.efficient_sam.build_efficient_sam import build_efficient_sam_vitt + + +class SegmentConfig: + """ + Config class for the segmentation pipeline. + It contains detection and segmentation parameters as well as the models themselves. + """ + def __init__(self, + batch_size, + size = 600, + stride = 400, + + device = 'cuda', + GD_root = "./models/GDINO", + GD_config_file = "GroundingDINO_SwinT_OGC.py", + GD_weights = "groundingdino_swint_ogc.pth", + + TEXT_PROMPT = 'bush', #'green tree' + BOX_THRESHOLD = 0.15, + TEXT_THRESHOLD = 0.30, + + max_area_GD_boxes_mt2 = 6000, + min_ratio_GD_boxes_edges = 0, + perc_reduce_tree_boxes = 0, + + road_width_mt = 5, + ext_mt_build_box = 0, + + ESAM_root = './models/EfficientSAM', + ESAM_num_parall_queries = 5, + smooth_patch_overlap = False, #if this is false, stride could be equal to size + use_separate_detect_config = False, + + clean_masks_bool = False, + rmv_holes_area_th = 80, + rmv_small_obj_area_th = 80): + + #General + self.batch_size = batch_size + self.size = size + self.stride = stride # Overlap between each patch = (size - stride) + self.device = device + self.smooth_patch_overlap = smooth_patch_overlap + + if not use_separate_detect_config: #if you are not using a separate detect_config then define here all the detection configuration + #Grounding Dino (Trees) + self.GD_root = Path(GD_root) + self.CONFIG_PATH = self.GD_root / GD_config_file + self.WEIGHTS_PATH = self.GD_root / GD_weights + + self.GD_model = GD_load_model(self.CONFIG_PATH, self.WEIGHTS_PATH).to(self.device) + self.TEXT_PROMPT = TEXT_PROMPT + self.BOX_THRESHOLD = BOX_THRESHOLD + self.TEXT_THRESHOLD = TEXT_THRESHOLD + self.max_area_GD_boxes_mt2 = max_area_GD_boxes_mt2 + self.min_ratio_GD_boxes_edges = min_ratio_GD_boxes_edges + self.perc_reduce_tree_boxes = perc_reduce_tree_boxes + + #Efficient SAM + self.efficient_sam = build_efficient_sam_vitt(os.path.join(ESAM_root, 'weights/efficient_sam_vitt.pt')).to(self.device) + self.ESAM_num_parall_queries = ESAM_num_parall_queries + + #Roads + self.road_width_mt = road_width_mt + + #Buildings + self.ext_mt_build_box = ext_mt_build_box + + #Post proc + self.clean_masks_bool = clean_masks_bool + self.rmv_holes_area_th = rmv_holes_area_th + self.rmv_small_obj_area_th = rmv_small_obj_area_th + + if not use_separate_detect_config: + print('\n- GD model device:', next(self.GD_model.parameters()).device) + print('- Efficient SAM device:', next(self.efficient_sam.parameters()).device) + + def __str__(self) -> str: + return f'{self.TEXT_PROMPT = }\n{self.BOX_THRESHOLD = }\n{self.TEXT_THRESHOLD = }\n{self.max_area_GD_boxes_mt2 = }\n{self.min_ratio_GD_boxes_edges = 
}\n{self.perc_reduce_tree_boxes = }\n{self.road_width_mt = }\n{self.ext_mt_build_box = }' +class DetectConfig: + + def __init__(self, + size = 600, + stride = 400, + device = 'cuda', + + GD_batch_size = 1, + GD_root = "./models/GDINO", + GD_config_file = "GroundingDINO_SwinT_OGC.py", + GD_weights = "groundingdino_swint_ogc.pth", + + TEXT_PROMPT = 'bush', #'green tree' + BOX_THRESHOLD = 0.15, + TEXT_THRESHOLD = 0.30, + + DF_patch_size = 800, + DF_patch_overlap = 0.25, + DF_box_threshold = 0.1, + DF_batch_size = 16, + + max_area_GD_boxes_mt2 = 6000, + min_ratio_GD_boxes_edges = 0.0, + perc_reduce_tree_boxes = 0.0, + nms_threshold = 0.5): + + #General + self.size = size + self.stride = stride # Overlap between each patch = (size - stride) + self.device = device + + #Grounding Dino (Trees) + self.GD_batch_size = GD_batch_size + self.GD_root = Path(GD_root) + self.CONFIG_PATH = self.GD_root / GD_config_file + self.WEIGHTS_PATH = self.GD_root / GD_weights + + #self.GD_model = GD_load_model(self.CONFIG_PATH, self.WEIGHTS_PATH).to(self.device) + self.TEXT_PROMPT = TEXT_PROMPT + self.BOX_THRESHOLD = BOX_THRESHOLD + self.TEXT_THRESHOLD = TEXT_THRESHOLD + + #DeepForest + self.DF_patch_size = DF_patch_size + self.DF_patch_overlap = DF_patch_overlap + self.DF_box_threshold = DF_box_threshold + self.DF_device = [int(device.split(':')[-1])] if len(device.split(':')) > 1 else 'auto' #Remove the port number from the device (e.g. 'cuda:0' -> 'cuda') + self.DF_batch_size = DF_batch_size + + #Filtering + self.max_area_GD_boxes_mt2 = max_area_GD_boxes_mt2 + self.min_ratio_GD_boxes_edges = min_ratio_GD_boxes_edges + self.perc_reduce_tree_boxes = perc_reduce_tree_boxes + self.nms_threshold = nms_threshold + + def __str__(self) -> str: + return f'{self.TEXT_PROMPT = }\n{self.BOX_THRESHOLD = }\n{self.TEXT_THRESHOLD = }\n{self.max_area_GD_boxes_mt2 = }\n{self.min_ratio_GD_boxes_edges = }\n{self.perc_reduce_tree_boxes = }' + +def merge_dictionaries_recursively(dict1, dict2): + ''' Update two config dictionaries recursively. + Args: + dict1 (dict): first dictionary to be updated + dict2 (dict): second dictionary which entries should be preferred + ''' + if dict2 is None: return + + for k, v in dict2.items(): + if k not in dict1: + dict1[k] = dict() + if isinstance(v, dict): + merge_dictionaries_recursively(dict1[k], v) + else: + dict1[k] = v + +class Config(object): + """Simple dict wrapper that adds a thin API allowing for slash-based retrieval of + nested elements, e.g. 
cfg.get_config("meta/dataset_name") + """ + def __init__(self, config_path, default_path='./configs/default_cfg.yaml'): + with open(config_path) as cf_file: + cfg = yaml.safe_load( cf_file.read() ) + + if default_path is not None: + with open(default_path) as def_cf_file: + default_cfg = yaml.safe_load( def_cf_file.read() ) + + merge_dictionaries_recursively(default_cfg, cfg) + + self._data = default_cfg + + def get(self, path=None, default=None): + # we need to deep-copy self._data to avoid over-writing its data + sub_dict = dict(self._data) + + if path is None: + return sub_dict + + path_items = path.split("/")[:-1] + data_item = path.split("/")[-1] + + for path_item in path_items: + sub_dict = sub_dict.get(path_item) + + value = sub_dict.get(data_item, default) + if value is None: + raise ValueError(f"Path '{path}' not found in config file") + return value + + def set(self, path, value): + if path is None: + raise ValueError("Path cannot be None") + + path_items = path.split("/") + sub_dict = self._data + + # Traverse the dictionary except for the last key + for path_item in path_items[:-1]: + if path_item not in sub_dict: + sub_dict[path_item] = {} # Create a new dictionary if the key does not exist + sub_dict = sub_dict[path_item] + + # Set the value to the last key in the path + sub_dict[path_items[-1]] = value + + def __str__(self) -> str: + return str(self._data) \ No newline at end of file diff --git a/src/maxarseg/detect/detect.py b/src/maxarseg/detect/detect.py new file mode 100644 index 0000000..e82d4af --- /dev/null +++ b/src/maxarseg/detect/detect.py @@ -0,0 +1,154 @@ +from maxarseg.detect import detect_utils +from groundingdino.util.inference import predict as GD_predict +import numpy as np +from typing import List +import geopandas as gpd +from maxarseg.samplers import samplers, samplers_utils +from maxarseg.geo_datasets import geoDatasets +from torch.utils.data import DataLoader +from torchgeo.datasets.utils import BoundingBox + + + +def get_GD_boxes(img_batch: np.array, #b,h,w,c + GDINO_model, + TEXT_PROMPT, + BOX_THRESHOLD, + TEXT_THRESHOLD, + dataset_res, + device, + max_area_mt2 = 3000, + min_edges_ratio = 0, + reduce_perc = 0): + + batch_tree_boxes4Sam = [] + sample_size = img_batch.shape[1] + num_trees4img = [] + + for img in img_batch: + image_transformed = detect_utils.GD_img_load(img) + tree_boxes, logits, phrases = GD_predict(GDINO_model, image_transformed, TEXT_PROMPT, BOX_THRESHOLD, TEXT_THRESHOLD, device = device) + #tree_boxes4Sam = [] + if len(tree_boxes) > 0: + keep_ix_tree_boxes_area = detect_utils.filter_on_box_area_mt2(tree_boxes, sample_size, dataset_res, max_area_mt2 = max_area_mt2) + keep_ix_tree_boxes_ratio = detect_utils.filter_on_box_ratio(tree_boxes, min_edges_ratio = min_edges_ratio) + keep_ix_tree_boxes = keep_ix_tree_boxes_area & keep_ix_tree_boxes_ratio + + reduced_tree_boxes = detect_utils.reduce_tree_boxes(tree_boxes[keep_ix_tree_boxes], reduce_perc = reduce_perc) + + tree_boxes4Sam = detect_utils.GDboxes2SamBoxes(reduced_tree_boxes, sample_size) + + num_trees4img.append(tree_boxes4Sam.shape[0]) + batch_tree_boxes4Sam.append(tree_boxes4Sam) + else: + num_trees4img.append(0) + batch_tree_boxes4Sam.append(np.empty((0,4))) + return batch_tree_boxes4Sam, np.array(num_trees4img) + +"""def get_GD_boxes_tile_optimized(img_batch: np.array, #b,h,w,c + GDINO_model, + TEXT_PROMPT, + BOX_THRESHOLD, + TEXT_THRESHOLD, + dataset_res, + device, + max_area_mt2 = 3000, + min_edges_ratio = 0, + reduce_perc = 0): + + batch_tree_boxes4Sam = [] + sample_size = 
img_batch.shape[1] + num_trees4img = [] + + for img in img_batch: + image_transformed = detect_utils.GD_img_load(img) + tree_boxes, logits, phrases = GD_predict(GDINO_model, image_transformed, TEXT_PROMPT, BOX_THRESHOLD, TEXT_THRESHOLD, device = device) + #tree_boxes4Sam = [] + if len(tree_boxes) > 0: + keep_ix_tree_boxes_area = detect_utils.filter_on_box_area_mt2(tree_boxes, sample_size, dataset_res, max_area_mt2 = max_area_mt2) + keep_ix_tree_boxes_ratio = detect_utils.filter_on_box_ratio(tree_boxes, min_edges_ratio = min_edges_ratio) + keep_ix_tree_boxes = keep_ix_tree_boxes_area & keep_ix_tree_boxes_ratio + + reduced_tree_boxes = detect_utils.reduce_tree_boxes(tree_boxes[keep_ix_tree_boxes], reduce_perc = reduce_perc) + + tree_boxes4Sam = detect_utils.GDboxes2SamBoxes(reduced_tree_boxes, sample_size) + + num_trees4img.append(tree_boxes4Sam.shape[0]) + batch_tree_boxes4Sam.append(tree_boxes4Sam) + else: + num_trees4img.append(0) + batch_tree_boxes4Sam.append(np.empty((0,4))) + return batch_tree_boxes4Sam, np.array(num_trees4img)""" + +# OLD FUNCTION +def get_batch_buildings_boxes(batch_bbox: List[BoundingBox], proj_buildings_gdf: gpd.GeoDataFrame, dataset_res, ext_mt = 10): + batch_building_boxes = [] + num_build4img = [] + for bbox in batch_bbox: + query_bbox_poly = samplers_utils.boundingBox_2_Polygon(bbox) #from patch bbox to polygon + index_MS_buildings = proj_buildings_gdf.sindex #get spatial index + buildig_hits = index_MS_buildings.query(query_bbox_poly) #query buildinds index + num_build4img.append(len(buildig_hits)) #append number of buildings + #building_boxes = [] + if len(buildig_hits) > 0: #if there are buildings from proj geo cords to indexes + building_boxes = samplers_utils.rel_bbox_coords(proj_buildings_gdf.iloc[buildig_hits], query_bbox_poly.bounds, dataset_res, ext_mt=ext_mt) + building_boxes = np.array(building_boxes) + else: #append empty array if no buildings + building_boxes = np.empty((0,4)) + + batch_building_boxes.append(building_boxes) + + return batch_building_boxes, np.array(num_build4img) + +def get_batch_boxes(batch_bbox: List[BoundingBox], proj_gdf: gpd.GeoDataFrame, dataset_res, ext_mt = 0): + """ + Given a batch of bounding boxes in a proj crs, it returns the boxes in the right coordinates relative to the sampled patch. + It is necessary that the bbox and the gdf are in the same crs. 
+ """ + batch_boxes = [] + num_boxes4img = [] + gdf_index = proj_gdf.sindex + for bbox in batch_bbox: + query_patch_poly = samplers_utils.boundingBox_2_Polygon(bbox) #from patch bbox to polygon + + hits = gdf_index.query(query_patch_poly) #query index + + num_boxes4img.append(len(hits)) #append number of boxes + + if len(hits) > 0: #if there is at least a box in the query_bbox_poly + + boxes = samplers_utils.rel_bbox_coords(geodf = proj_gdf.iloc[hits], + ref_coords = query_patch_poly.bounds, + res = dataset_res, + ext_mt = ext_mt) + boxes = np.array(boxes) + else: #append empty array if no buildings + boxes = np.empty((0,4)) + + batch_boxes.append(boxes) + + return batch_boxes, np.array(num_boxes4img) + +def get_refined_batch_boxes(batch_bbox: List[BoundingBox], proj_gdf: gpd.GeoDataFrame, dataset_res, ext_mt = 0): + if len(batch_bbox) != 1: + raise ValueError("Invalid input: batch_bbox should contain exactly one bounding box (segmentation batch size must be 1)") + bbox = batch_bbox[0] + query_patch_poly = samplers_utils.boundingBox_2_Polygon(bbox) #from patch bbox to polygon + try: + intersec_geom = proj_gdf.intersection(query_patch_poly) + except Exception as e: + proj_gdf['geometry'] = proj_gdf['geometry'].apply(lambda geom: geom.buffer(0) if not geom.is_valid else geom) + intersec_geom = proj_gdf.intersection(query_patch_poly) + valid_gdf = intersec_geom[~intersec_geom.is_empty] + num_boxes4img = [len(valid_gdf)] + if len(valid_gdf) > 0: + boxes = samplers_utils.rel_bbox_coords(geodf = valid_gdf, + ref_coords = query_patch_poly.bounds, + res = dataset_res, + ext_mt = ext_mt) + else: + boxes = np.empty((0,4)) + + batch_boxes = [boxes] + + return batch_boxes, np.array(num_boxes4img) \ No newline at end of file diff --git a/src/maxarseg/detect/detect_utils.py b/src/maxarseg/detect/detect_utils.py new file mode 100644 index 0000000..9cd55a6 --- /dev/null +++ b/src/maxarseg/detect/detect_utils.py @@ -0,0 +1,118 @@ +import torch +from typing import Union, List +from torchvision.ops import box_convert +import torchvision +import numpy as np +import groundingdino.datasets.transforms as T +from PIL import Image + + + +def GDboxes2SamBoxes(boxes: torch.Tensor, img_shape: Union[tuple[float, float], float]): + """ + Convert the boxes from the format cxcywh to the format xyxy. + Inputs: + boxes: torch.Tensor of shape (N, 4). Where boxes are in the format cxcywh (the output of GroundingDINO). + img_shape: tuple (h, w) + img_res: float, the resolution of the image (mt/pxl). + Output: + boxes: torch.Tensor of shape (N, 4). Where boxes are in the format xyxy. + """ + if isinstance(img_shape, (float, int)): + img_shape = (img_shape, img_shape) + + h, w = img_shape + SAM_boxes = boxes.clone() + SAM_boxes = SAM_boxes * torch.Tensor([w, h, w, h]) + SAM_boxes = box_convert(boxes=SAM_boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy() + return SAM_boxes + +def GD_img_load(np_img_rgb: np.array)-> torch.Tensor: + """ + Transform the image from np.array to torch.Tensor and normalize it. 
+ """ + transform = T.Compose( + [ + T.RandomResize([800], max_size=1333), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) + image_pillow = Image.fromarray(np_img_rgb) + # try: + # image_pillow = Image.fromarray(np_img_rgb) + # except: + # print("Error in Image.fromarray") + image_transformed, _ = transform(image_pillow, None) + return image_transformed + +def filter_on_box_area_mt2(boxes, img_shape: Union[tuple[float, float], float] = None, img_res = None, min_area_mt2 = 0, max_area_mt2 = 1500, box_format = 'cxcywh'): + """ + Filter boxes based on min and max area. + Inputs: + boxes: torch.Tensor of shape (N, 4). Where boxes are in the format cxcywh (the output of GroundingDINO) or xyxy (rel coords). + img_shape: tuple (h, w) + img_res: float, the resolution of the image (mt/pxl). + min_area_mt2: float + max_area_mt2: float + box_format: str, the format of the boxes. 'cxcywh' or 'xyxy'. + Output: + keep_ix: torch.Tensor of shape (N,) + """ + if box_format == 'cxcywh': + if isinstance(img_shape, (float, int)): + img_shape = (img_shape, img_shape) + + h, w = img_shape + tmp_boxes = boxes.clone() + tmp_boxes = tmp_boxes * torch.Tensor([w, h, w, h]) + + area_mt2 = torch.prod(tmp_boxes[:,2:], dim=1) * img_res**2 + + elif box_format == 'xyxy': + width = boxes[:, 2] - boxes[:, 0] + height = boxes[:, 3] - boxes[:, 1] + area_mt2 = width * height + + keep_ix = (area_mt2 > min_area_mt2) & (area_mt2 < max_area_mt2) + + return keep_ix + +def filter_on_box_ratio(boxes, min_edges_ratio = 0, box_format = 'cxcywh',): + """ + Filter boxes based on the ratio between the edges. + """ + if box_format == 'cxcywh': + keep_ix = (boxes[:,2] / boxes[:,3] > min_edges_ratio) & (boxes[:,3] / boxes[:,2] > min_edges_ratio) + elif box_format == 'xyxy': + width = boxes[:, 2] - boxes[:, 0] #xmax - xmin + height = boxes[:, 3] - boxes[:, 1] #ymax - ymin + keep_ix = (width / height > min_edges_ratio) & (height / width > min_edges_ratio) + return keep_ix + +def reduce_tree_boxes(boxes, reduce_perc): + """ + Reduce the size of the boxes by reduce_perc. Keeping the center fixed. + Input: + boxes: torch.Tensor of shape (N, 4). Where boxes are in the format cxcywh (the output of GroundingDINO). + reduce_perc: float, the float to reduce the boxes. + Output: + boxes: torch.Tensor of shape (N, 4). Where reduced boxes are in the format cxcywh. + """ + reduced_boxes = boxes.clone() + reduced_boxes[:,2:] = reduced_boxes[:,2:] * (1 - reduce_perc) + return reduced_boxes + +def rel2glb_xyxy(rel_xyxy_tree_boxes, top_left_xy): + """ + Convert the relative coordinates of the boxes to global coordinates. + Inputs: + rel_xyxy_tree_boxes: torch.Tensor of shape (N, 4). Where boxes are in the format xyxy. + top_left_xy: tuple (x, y) + Output: + glb_xyxy_tree_boxes: torch.Tensor of shape (N, 4). Where boxes are in the format xyxy. + """ + glb_xyxy_tree_boxes = rel_xyxy_tree_boxes.clone() + glb_xyxy_tree_boxes[:,[0,2]] += top_left_xy[0] + glb_xyxy_tree_boxes[:,[1,3]] += top_left_xy[1] + return glb_xyxy_tree_boxes \ No newline at end of file diff --git a/src/maxarseg/efficient_sam/__init__.py b/src/maxarseg/efficient_sam/__init__.py new file mode 100644 index 0000000..22a2d29 --- /dev/null +++ b/src/maxarseg/efficient_sam/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+ +from .build_efficient_sam import ( + build_efficient_sam_vitt, + build_efficient_sam_vits, +) diff --git a/src/maxarseg/efficient_sam/build_efficient_sam.py b/src/maxarseg/efficient_sam/build_efficient_sam.py new file mode 100644 index 0000000..360223e --- /dev/null +++ b/src/maxarseg/efficient_sam/build_efficient_sam.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .efficient_sam import build_efficient_sam + +def build_efficient_sam_vitt(checkpoint): + if checkpoint is None: + checkpoint = "weights/efficient_sam_vitt.pt" + return build_efficient_sam( + encoder_patch_embed_dim=192, + encoder_num_heads=3, + checkpoint=checkpoint, + ).eval() + + +def build_efficient_sam_vits(checkpoint): + if checkpoint is None: + checkpoint = "weights/efficient_sam_vits.pt" + return build_efficient_sam( + encoder_patch_embed_dim=384, + encoder_num_heads=6, + checkpoint=checkpoint, + ).eval() diff --git a/src/maxarseg/efficient_sam/efficient_sam.py b/src/maxarseg/efficient_sam/efficient_sam.py new file mode 100644 index 0000000..3a3ba4c --- /dev/null +++ b/src/maxarseg/efficient_sam/efficient_sam.py @@ -0,0 +1,310 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, List, Tuple, Type + +import torch +import torch.nn.functional as F + +from torch import nn, Tensor + +from .efficient_sam_decoder import MaskDecoder, PromptEncoder +from .efficient_sam_encoder import ImageEncoderViT +from .two_way_transformer import TwoWayAttentionBlock, TwoWayTransformer + +class EfficientSam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + decoder_max_num_input_points: int, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [0.485, 0.456, 0.406], + pixel_std: List[float] = [0.229, 0.224, 0.225], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.decoder_max_num_input_points = decoder_max_num_input_points + self.mask_decoder = mask_decoder + self.register_buffer( + "pixel_mean", torch.Tensor(pixel_mean).view(1, 3, 1, 1), False + ) + self.register_buffer( + "pixel_std", torch.Tensor(pixel_std).view(1, 3, 1, 1), False + ) + + @torch.jit.export + def predict_masks( + self, + image_embeddings: torch.Tensor, + batched_points: torch.Tensor, + batched_point_labels: torch.Tensor, + multimask_output: bool, + input_h: int, + input_w: int, + output_h: int = -1, + output_w: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predicts masks given image embeddings and prompts. This only runs the decoder. 
+
+        Arguments:
+          image_embeddings: A tensor of shape [B, C, H, W] or [B*max_num_queries, C, H, W]
+          batched_points: A tensor of shape [B, max_num_queries, num_pts, 2]
+          batched_point_labels: A tensor of shape [B, max_num_queries, num_pts]
+        Returns:
+          A tuple of two tensors:
+            low_res_mask: A tensor of shape [B, max_num_queries, 256, 256] of predicted masks
+            iou_predictions: A tensor of shape [B, max_num_queries] of estimated IOU scores
+        """
+
+        batch_size, max_num_queries, num_pts, _ = batched_points.shape
+        rescaled_batched_points = self.get_rescaled_pts(batched_points, input_h, input_w)
+
+        if num_pts > self.decoder_max_num_input_points:
+            rescaled_batched_points = rescaled_batched_points[
+                :, :, : self.decoder_max_num_input_points, :
+            ]
+            batched_point_labels = batched_point_labels[
+                :, :, : self.decoder_max_num_input_points
+            ]
+        elif num_pts < self.decoder_max_num_input_points:
+            rescaled_batched_points = F.pad(
+                rescaled_batched_points,
+                (0, 0, 0, self.decoder_max_num_input_points - num_pts),
+                value=-1.0,
+            )
+            batched_point_labels = F.pad(
+                batched_point_labels,
+                (0, self.decoder_max_num_input_points - num_pts),
+                value=-1.0,
+            )
+
+        sparse_embeddings = self.prompt_encoder(
+            rescaled_batched_points.reshape(
+                batch_size * max_num_queries, self.decoder_max_num_input_points, 2
+            ),
+            batched_point_labels.reshape(
+                batch_size * max_num_queries, self.decoder_max_num_input_points
+            ),
+        )
+        sparse_embeddings = sparse_embeddings.view(
+            batch_size,
+            max_num_queries,
+            sparse_embeddings.shape[1],
+            sparse_embeddings.shape[2],
+        )
+        low_res_masks, iou_predictions = self.mask_decoder(
+            image_embeddings,
+            self.prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embeddings,
+            multimask_output=multimask_output,
+        )
+        _, num_predictions, low_res_size, _ = low_res_masks.shape
+
+        if output_w > 0 and output_h > 0:
+            output_masks = F.interpolate(
+                low_res_masks, (output_h, output_w), mode="bicubic"
+            )
+            output_masks = torch.reshape(
+                output_masks,
+                (batch_size, max_num_queries, num_predictions, output_h, output_w),
+            )
+        else:
+            output_masks = torch.reshape(
+                low_res_masks,
+                (
+                    batch_size,
+                    max_num_queries,
+                    num_predictions,
+                    low_res_size,
+                    low_res_size,
+                ),
+            )
+        iou_predictions = torch.reshape(
+            iou_predictions, (batch_size, max_num_queries, num_predictions)
+        )
+        sorted_ids = torch.argsort(iou_predictions, dim=-1, descending=True)
+        iou_predictions = torch.take_along_dim(iou_predictions, sorted_ids, dim=2)
+        output_masks = torch.take_along_dim(
+            output_masks, sorted_ids[..., None, None], dim=2
+        )
+        return output_masks, iou_predictions
+
+    def get_rescaled_pts(self, batched_points: torch.Tensor, input_h: int, input_w: int):
+        return torch.stack(
+            [
+                torch.where(
+                    batched_points[..., 0] >= 0,
+                    batched_points[..., 0] * self.image_encoder.img_size / input_w,
+                    -1.0,
+                ),
+                torch.where(
+                    batched_points[..., 1] >= 0,
+                    batched_points[..., 1] * self.image_encoder.img_size / input_h,
+                    -1.0,
+                ),
+            ],
+            dim=-1,
+        )
+
+    @torch.jit.export
+    def get_image_embeddings(self, batched_images) -> torch.Tensor:
+        """
+        Computes image embeddings for the provided batch of images by
+        running only the image encoder.
+
+        Arguments:
+          batched_images: A tensor of shape [B, 3, H, W]
+        Returns:
+          The image embeddings, a tensor of shape [B, C, H, W] produced
+          by the final layer of the encoder.
+ """ + batched_images = self.preprocess(batched_images) + return self.image_encoder(batched_images) + + def forward( + self, + batched_images: torch.Tensor, + batched_points: torch.Tensor, + batched_point_labels: torch.Tensor, + scale_to_original_image_size: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_images: A tensor of shape [B, 3, H, W] + batched_points: A tensor of shape [B, num_queries, max_num_pts, 2] + batched_point_labels: A tensor of shape [B, num_queries, max_num_pts] + + Returns: + A list tuples of two tensors where the ith element is by considering the first i+1 points. + low_res_mask: A tensor of shape [B, 256, 256] of predicted masks + iou_predictions: A tensor of shape [B, max_num_queries] of estimated IOU scores + """ + batch_size, _, input_h, input_w = batched_images.shape + image_embeddings = self.get_image_embeddings(batched_images) + return self.predict_masks( + image_embeddings, + batched_points, + batched_point_labels, + multimask_output=True, + input_h=input_h, + input_w=input_w, + output_h=input_h if scale_to_original_image_size else -1, + output_w=input_w if scale_to_original_image_size else -1, + ) + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + if ( + x.shape[2] != self.image_encoder.img_size + or x.shape[3] != self.image_encoder.img_size + ): + x = F.interpolate( + x, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + ) + return (x - self.pixel_mean) / self.pixel_std + + +def build_efficient_sam(encoder_patch_embed_dim, encoder_num_heads, checkpoint=None): + img_size = 1024 + encoder_patch_size = 16 + encoder_depth = 12 + encoder_mlp_ratio = 4.0 + encoder_neck_dims = [256, 256] + decoder_max_num_input_points = 6 + decoder_transformer_depth = 2 + decoder_transformer_mlp_dim = 2048 + decoder_num_heads = 8 + decoder_upscaling_layer_dims = [64, 32] + num_multimask_outputs = 3 + iou_head_depth = 3 + iou_head_hidden_dim = 256 + activation = "gelu" + normalization_type = "layer_norm" + normalize_before_activation = False + + assert activation == "relu" or activation == "gelu" + if activation == "relu": + activation_fn = nn.ReLU + else: + activation_fn = nn.GELU + + image_encoder = ImageEncoderViT( + img_size=img_size, + patch_size=encoder_patch_size, + in_chans=3, + patch_embed_dim=encoder_patch_embed_dim, + normalization_type=normalization_type, + depth=encoder_depth, + num_heads=encoder_num_heads, + mlp_ratio=encoder_mlp_ratio, + neck_dims=encoder_neck_dims, + act_layer=activation_fn, + ) + + image_embedding_size = image_encoder.image_embedding_size + encoder_transformer_output_dim = image_encoder.transformer_output_dim + + sam = EfficientSam( + image_encoder=image_encoder, + prompt_encoder=PromptEncoder( + embed_dim=encoder_transformer_output_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(img_size, img_size), + ), + decoder_max_num_input_points=decoder_max_num_input_points, + mask_decoder=MaskDecoder( + transformer_dim=encoder_transformer_output_dim, + transformer=TwoWayTransformer( + depth=decoder_transformer_depth, + embedding_dim=encoder_transformer_output_dim, + num_heads=decoder_num_heads, + mlp_dim=decoder_transformer_mlp_dim, + activation=activation_fn, + 
normalize_before_activation=normalize_before_activation, + ), + num_multimask_outputs=num_multimask_outputs, + activation=activation_fn, + normalization_type=normalization_type, + normalize_before_activation=normalize_before_activation, + iou_head_depth=iou_head_depth - 1, + iou_head_hidden_dim=iou_head_hidden_dim, + upscaling_layer_dims=decoder_upscaling_layer_dims, + ), + pixel_mean=[0.485, 0.456, 0.406], + pixel_std=[0.229, 0.224, 0.225], + ) + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f, map_location="cpu") + sam.load_state_dict(state_dict["model"]) + return sam diff --git a/src/maxarseg/efficient_sam/efficient_sam_decoder.py b/src/maxarseg/efficient_sam/efficient_sam_decoder.py new file mode 100644 index 0000000..380f41c --- /dev/null +++ b/src/maxarseg/efficient_sam/efficient_sam_decoder.py @@ -0,0 +1,315 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Tuple, Type + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .mlp import MLPBlock + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + self.invalid_points = nn.Embedding(1, embed_dim) + self.point_embeddings = nn.Embedding(1, embed_dim) + self.bbox_top_left_embeddings = nn.Embedding(1, embed_dim) + self.bbox_bottom_right_embeddings = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. 
+
+        Returns:
+          torch.Tensor: Positional encoding with shape
+            1x(embed_dim)x(embedding_h)x(embedding_w)
+        """
+        return self.pe_layer(self.image_embedding_size).unsqueeze(0)
+
+    def _embed_points(
+        self,
+        points: torch.Tensor,
+        labels: torch.Tensor,
+    ) -> torch.Tensor:
+        """Embeds point prompts."""
+        points = points + 0.5  # Shift to center of pixel
+        point_embedding = self.pe_layer.forward_with_coords(
+            points, self.input_image_size
+        )
+        invalid_label_ids = torch.eq(labels, -1)[:,:,None]
+        point_label_ids = torch.eq(labels, 1)[:,:,None]
+        topleft_label_ids = torch.eq(labels, 2)[:,:,None]
+        bottomright_label_ids = torch.eq(labels, 3)[:,:,None]
+        point_embedding = point_embedding + self.invalid_points.weight[:,None,:] * invalid_label_ids
+        point_embedding = point_embedding + self.point_embeddings.weight[:,None,:] * point_label_ids
+        point_embedding = point_embedding + self.bbox_top_left_embeddings.weight[:,None,:] * topleft_label_ids
+        point_embedding = point_embedding + self.bbox_bottom_right_embeddings.weight[:,None,:] * bottomright_label_ids
+        return point_embedding
+
+    def forward(
+        self,
+        coords,
+        labels,
+    ) -> torch.Tensor:
+        """
+        Embeds point and box-corner prompts, returning sparse embeddings.
+
+        Arguments:
+          coords: A tensor of shape [B, N, 2] of point coordinates.
+          labels: An integer tensor of shape [B, N] where each element is
+            -1 (invalid/padding), 1 (point), 2 (box top-left) or 3 (box bottom-right).
+
+        Returns:
+          torch.Tensor: sparse embeddings for the points and boxes, with shape
+          BxNx(embed_dim), where N is determined by the number of input points
+          and boxes.
+        """
+        return self._embed_points(coords, labels)
+
+
+class PositionEmbeddingRandom(nn.Module):
+    """
+    Positional encoding using random spatial frequencies.
+    """
+
+    def __init__(self, num_pos_feats: int) -> None:
+        super().__init__()
+        self.register_buffer(
+            "positional_encoding_gaussian_matrix", torch.randn((2, num_pos_feats))
+        )
+
+    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+        """Positionally encode points that are normalized to [0,1]."""
+        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+        coords = 2 * coords - 1
+        coords = coords @ self.positional_encoding_gaussian_matrix
+        coords = 2 * np.pi * coords
+        # outputs d_1 x ... x d_n x C shape
+        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+        """Generate positional encoding for a grid of the specified size."""
+        h, w = size
+        device = self.positional_encoding_gaussian_matrix.device
+        grid = torch.ones([h, w], device=device, dtype=torch.float32)
+        y_embed = grid.cumsum(dim=0) - 0.5
+        x_embed = grid.cumsum(dim=1) - 0.5
+        y_embed = y_embed / h
+        x_embed = x_embed / w
+
+        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+        return pe.permute(2, 0, 1)  # C x H x W
+
+    def forward_with_coords(
+        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
+    ) -> torch.Tensor:
+        """Positionally encode points that are not normalized to [0,1]."""
+        coords = coords_input.clone()
+        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+        return self._pe_encoding(coords.to(torch.float))  # B x N x C
+
+
+class MaskDecoder(nn.Module):
+    def __init__(
+        self,
+        *,
+        transformer_dim: int,
+        transformer: nn.Module,
+        num_multimask_outputs: int,
+        activation: Type[nn.Module],
+        normalization_type: str,
+        normalize_before_activation: bool,
+        iou_head_depth: int,
+        iou_head_hidden_dim: int,
+        upscaling_layer_dims: List[int],
+    ) -> None:
+        """
+        Predicts masks given an image and prompt embeddings, using a
+        transformer architecture.
+
+        Arguments:
+          transformer_dim (int): the channel dimension of the transformer
+          transformer (nn.Module): the transformer used to predict masks
+          num_multimask_outputs (int): the number of masks to predict
+            when disambiguating masks
+          activation (nn.Module): the type of activation to use when
+            upscaling masks
+          iou_head_depth (int): the depth of the MLP used to predict
+            mask quality
+          iou_head_hidden_dim (int): the hidden dimension of the MLP
+            used to predict mask quality
+        """
+        super().__init__()
+        self.transformer_dim = transformer_dim
+        self.transformer = transformer
+
+        self.num_multimask_outputs = num_multimask_outputs
+
+        self.iou_token = nn.Embedding(1, transformer_dim)
+        if num_multimask_outputs > 1:
+            self.num_mask_tokens = num_multimask_outputs + 1
+        else:
+            self.num_mask_tokens = 1
+        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+        output_dim_after_upscaling = transformer_dim
+
+        self.final_output_upscaling_layers = nn.ModuleList([])
+        for idx, layer_dims in enumerate(upscaling_layer_dims):
+            self.final_output_upscaling_layers.append(
+                nn.Sequential(
+                    nn.ConvTranspose2d(
+                        output_dim_after_upscaling,
+                        layer_dims,
+                        kernel_size=2,
+                        stride=2,
+                    ),
+                    nn.GroupNorm(1, layer_dims)
+                    if idx < len(upscaling_layer_dims) - 1
+                    else nn.Identity(),
+                    activation(),
+                )
+            )
+            output_dim_after_upscaling = layer_dims
+
+        self.output_hypernetworks_mlps = nn.ModuleList(
+            [
+                MLPBlock(
+                    input_dim=transformer_dim,
+                    hidden_dim=transformer_dim,
+                    output_dim=output_dim_after_upscaling,
+                    num_layers=2,
+                    act=activation,
+                )
+                for i in range(self.num_mask_tokens)
+            ]
+        )
+
+        self.iou_prediction_head = MLPBlock(
+            input_dim=transformer_dim,
+            hidden_dim=iou_head_hidden_dim,
+            output_dim=self.num_mask_tokens,
+            num_layers=iou_head_depth,
+            act=activation,
+        )
+
+    def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        multimask_output: bool,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predict masks given image and prompt embeddings.
+ + Arguments: + image_embeddings: A tensor of shape [B, C, H, W] or [B*max_num_queries, C, H, W] + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings (the batch dimension is broadcastable). + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + + ( + batch_size, + max_num_queries, + sparse_embed_dim_1, + sparse_embed_dim_2, + ) = sparse_prompt_embeddings.shape + + ( + _, + image_embed_dim_c, + image_embed_dim_h, + image_embed_dim_w, + ) = image_embeddings.shape + + # Tile the image embedding for all queries. + image_embeddings_tiled = torch.tile( + image_embeddings[:, None, :, :, :], [1, max_num_queries, 1, 1, 1] + ).view( + batch_size * max_num_queries, + image_embed_dim_c, + image_embed_dim_h, + image_embed_dim_w, + ) + sparse_prompt_embeddings = sparse_prompt_embeddings.reshape( + batch_size * max_num_queries, sparse_embed_dim_1, sparse_embed_dim_2 + ) + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings_tiled, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + ) + if multimask_output and self.num_multimask_outputs > 1: + return masks[:, 1:, :], iou_pred[:, 1:] + else: + return masks[:, :1, :], iou_pred[:, :1] + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat( + [self.iou_token.weight, self.mask_tokens.weight], dim=0 + ) + output_tokens = output_tokens.unsqueeze(0).expand( + sparse_prompt_embeddings.size(0), -1, -1 + ) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + # Expand per-image data in batch direction to be per-mask + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = image_embeddings.shape + hs, src = self.transformer(image_embeddings, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + upscaled_embedding = src.transpose(1, 2).view(b, c, h, w) + + for upscaling_layer in self.final_output_upscaling_layers: + upscaled_embedding = upscaling_layer(upscaled_embedding) + hyper_in_list: List[torch.Tensor] = [] + for i, output_hypernetworks_mlp in enumerate(self.output_hypernetworks_mlps): + hyper_in_list.append(output_hypernetworks_mlp(mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + return masks, iou_pred diff --git a/src/maxarseg/efficient_sam/efficient_sam_encoder.py b/src/maxarseg/efficient_sam/efficient_sam_encoder.py new file mode 100644 index 0000000..73fd7ac --- /dev/null +++ b/src/maxarseg/efficient_sam/efficient_sam_encoder.py @@ -0,0 +1,257 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
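A shape-level sketch of the hypernetwork step at the end of MaskDecoder.predict_masks above; the sizes are illustrative, not values fixed by the model:

import torch

b, num_tokens, c, h, w = 2, 4, 32, 256, 256     # illustrative sizes
hyper_in = torch.randn(b, num_tokens, c)        # one weight vector per mask token, from the MLPs
upscaled_embedding = torch.randn(b, c, h, w)    # output of the upscaling layers
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, num_tokens, h, w)
# masks.shape == (2, 4, 256, 256): each mask logit map is the dot product of a
# token's hypernetwork output with the embedding at every spatial position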
+ +import math +from typing import List, Optional, Tuple, Type + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + img_size, + patch_size, + in_chans, + embed_dim, + ): + super().__init__() + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + bias=True, + ) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads, + qkv_bias, + qk_scale=None, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + def forward(self, x): + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + return x + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + act_layer=nn.GELU, + ): + super().__init__() + self.norm1 = nn.LayerNorm(dim, eps=1e-6) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + ) + self.norm2 = nn.LayerNorm(dim, eps=1e-6) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + ) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +@torch.jit.export +def get_abs_pos( + abs_pos: torch.Tensor, has_cls_token: bool, hw: List[int] +) -> torch.Tensor: + """ + Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token + dimension for the original embeddings. + Args: + abs_pos (Tensor): absolute positional embeddings with (1, num_position, C). + has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token. + hw (Tuple): size of input image tokens. 
+
+    Returns:
+        Absolute positional embeddings after processing with shape (1, H, W, C)
+    """
+    h = hw[0]
+    w = hw[1]
+    if has_cls_token:
+        abs_pos = abs_pos[:, 1:]
+    xy_num = abs_pos.shape[1]
+    size = int(math.sqrt(xy_num))
+    assert size * size == xy_num
+
+    if size != h or size != w:
+        new_abs_pos = F.interpolate(
+            abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
+            size=(h, w),
+            mode="bicubic",
+            align_corners=False,
+        )
+        return new_abs_pos.permute(0, 2, 3, 1)
+    else:
+        return abs_pos.reshape(1, h, w, -1)
+
+
+# Image encoder for efficient SAM.
+class ImageEncoderViT(nn.Module):
+    def __init__(
+        self,
+        img_size: int,
+        patch_size: int,
+        in_chans: int,
+        patch_embed_dim: int,
+        normalization_type: str,
+        depth: int,
+        num_heads: int,
+        mlp_ratio: float,
+        neck_dims: List[int],
+        act_layer: Type[nn.Module],
+    ) -> None:
+        """
+        Args:
+            img_size (int): Input image size.
+            patch_size (int): Patch size.
+            in_chans (int): Number of input image channels.
+            patch_embed_dim (int): Patch embedding dimension.
+            depth (int): Depth of ViT.
+            num_heads (int): Number of attention heads in each ViT block.
+            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+            act_layer (nn.Module): Activation layer.
+        """
+        super().__init__()
+
+        self.img_size = img_size
+        self.image_embedding_size = img_size // (patch_size if patch_size > 0 else 1)
+        self.transformer_output_dim = ([patch_embed_dim] + neck_dims)[-1]
+        self.pretrain_use_cls_token = True
+        pretrain_img_size = 224
+        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, patch_embed_dim)
+        # Initialize absolute positional embedding with pretrain image size.
+        num_patches = (pretrain_img_size // patch_size) * (
+            pretrain_img_size // patch_size
+        )
+        num_positions = num_patches + 1
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, patch_embed_dim))
+        self.blocks = nn.ModuleList()
+        for i in range(depth):
+            vit_block = Block(patch_embed_dim, num_heads, mlp_ratio, True)
+            self.blocks.append(vit_block)
+        self.neck = nn.Sequential(
+            nn.Conv2d(
+                patch_embed_dim,
+                neck_dims[0],
+                kernel_size=1,
+                bias=False,
+            ),
+            LayerNorm2d(neck_dims[0]),
+            nn.Conv2d(
+                neck_dims[0],
+                neck_dims[0],
+                kernel_size=3,
+                padding=1,
+                bias=False,
+            ),
+            LayerNorm2d(neck_dims[0]),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        assert (
+            x.shape[2] == self.img_size and x.shape[3] == self.img_size
+        ), "input image size must match self.img_size"
+        x = self.patch_embed(x)
+        # B C H W -> B H W C
+        x = x.permute(0, 2, 3, 1)
+        x = x + get_abs_pos(
+            self.pos_embed, self.pretrain_use_cls_token, [x.shape[1], x.shape[2]]
+        )
+        num_patches = x.shape[1]
+        assert x.shape[2] == num_patches
+        x = x.reshape(x.shape[0], num_patches * num_patches, x.shape[3])
+        for blk in self.blocks:
+            x = blk(x)
+        x = x.reshape(x.shape[0], num_patches, num_patches, x.shape[2])
+        x = self.neck(x.permute(0, 3, 1, 2))
+        return x
diff --git a/src/maxarseg/efficient_sam/mlp.py b/src/maxarseg/efficient_sam/mlp.py
new file mode 100644
index 0000000..b3be8db
--- /dev/null
+++ b/src/maxarseg/efficient_sam/mlp.py
@@ -0,0 +1,29 @@
+from typing import Type
+
+from torch import nn
+
+
+# Lightly adapted from
+# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
+class MLPBlock(nn.Module):
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        output_dim: int,
+        num_layers: int,
+        act: Type[nn.Module],
+    ) -> None:
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(
+            nn.Sequential(nn.Linear(n, k), act())
+            for n, k in zip([input_dim] + h, [hidden_dim] * num_layers)
+        )
+        self.fc = nn.Linear(hidden_dim, output_dim)
+
+    def forward(self, x):
+        for layer in self.layers:
+            x = layer(x)
+        return self.fc(x)
diff --git a/src/maxarseg/efficient_sam/two_way_transformer.py b/src/maxarseg/efficient_sam/two_way_transformer.py
new file mode 100644
index 0000000..881e76f
--- /dev/null
+++ b/src/maxarseg/efficient_sam/two_way_transformer.py
@@ -0,0 +1,264 @@
+import math
+from typing import Tuple, Type
+import torch
+from torch import nn, Tensor
+from .mlp import MLPBlock
+
+
+
+
+class TwoWayTransformer(nn.Module):
+    def __init__(
+        self,
+        depth: int,
+        embedding_dim: int,
+        num_heads: int,
+        mlp_dim: int,
+        activation: Type[nn.Module],
+        normalize_before_activation: bool,
+        attention_downsample_rate: int = 2,
+    ) -> None:
+        """
+        A transformer decoder that attends to an input image using
+        queries whose positional embedding is supplied.
+
+        Args:
+          depth (int): number of layers in the transformer
+          embedding_dim (int): the channel dimension for the input embeddings
+          num_heads (int): the number of heads for multihead attention. Must
+            divide embedding_dim
+          mlp_dim (int): the channel dimension internal to the MLP block
+          activation (nn.Module): the activation to use in the MLP block
+        """
+        super().__init__()
+        self.depth = depth
+        self.embedding_dim = embedding_dim
+        self.num_heads = num_heads
+        self.mlp_dim = mlp_dim
+        self.layers = nn.ModuleList()
+
+        for i in range(depth):
+            curr_layer = TwoWayAttentionBlock(
+                embedding_dim=embedding_dim,
+                num_heads=num_heads,
+                mlp_dim=mlp_dim,
+                activation=activation,
+                normalize_before_activation=normalize_before_activation,
+                attention_downsample_rate=attention_downsample_rate,
+                skip_first_layer_pe=(i == 0),
+            )
+            self.layers.append(curr_layer)
+
+        self.final_attn_token_to_image = AttentionForTwoWayAttentionBlock(
+            embedding_dim,
+            num_heads,
+            downsample_rate=attention_downsample_rate,
+        )
+        self.norm_final_attn = nn.LayerNorm(embedding_dim)
+
+    def forward(
+        self,
+        image_embedding: Tensor,
+        image_pe: Tensor,
+        point_embedding: Tensor,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+          image_embedding (torch.Tensor): image to attend to. Should be shape
+            B x embedding_dim x h x w for any h and w.
+          image_pe (torch.Tensor): the positional encoding to add to the image. Must
+            have the same shape as image_embedding.
+          point_embedding (torch.Tensor): the embedding to add to the query points.
+            Must have shape B x N_points x embedding_dim for any N_points.
+ + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for idx, layer in enumerate(self.layers): + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module], + normalize_before_activation: bool, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = AttentionForTwoWayAttentionBlock(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = AttentionForTwoWayAttentionBlock( + embedding_dim, + num_heads, + downsample_rate=attention_downsample_rate, + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock( + embedding_dim, + mlp_dim, + embedding_dim, + 1, + activation, + ) + + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = AttentionForTwoWayAttentionBlock( + embedding_dim, + num_heads, + downsample_rate=attention_downsample_rate, + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if not self.skip_first_layer_pe: + queries = queries + query_pe + attn_out = self.self_attn(q=queries, k=queries, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class AttentionForTwoWayAttentionBlock(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. 
+ """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert ( + self.internal_dim % num_heads == 0 + ), "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + self._reset_parameters() + + def _reset_parameters(self) -> None: + # The fan_out is incorrect, but matches pytorch's initialization + # for which qkv is a single 3*embedding_dim x embedding_dim matrix + fan_in = self.embedding_dim + fan_out = 3 * self.internal_dim + # Xavier uniform with our custom fan_out + bnd = math.sqrt(6 / (fan_in + fan_out)) + nn.init.uniform_(self.q_proj.weight, -bnd, bnd) + nn.init.uniform_(self.k_proj.weight, -bnd, bnd) + nn.init.uniform_(self.v_proj.weight, -bnd, bnd) + # out_proj.weight is left with default initialization, like pytorch attention + nn.init.zeros_(self.q_proj.bias) + nn.init.zeros_(self.k_proj.bias) + nn.init.zeros_(self.v_proj.bias) + nn.init.zeros_(self.out_proj.bias) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + return out diff --git a/src/maxarseg/explore_folders.py b/src/maxarseg/explore_folders.py new file mode 100644 index 0000000..8d02626 --- /dev/null +++ b/src/maxarseg/explore_folders.py @@ -0,0 +1,140 @@ +import os + +def create_pre_post_diz_set(event_id, root = '/nfs/projects/overwatch/maxar-segmentation/maxar-open-data/'): + """ + Create a dictionary of sets for pre and post folders. 
+    The key is the mosaic name and the value is a set containing
+    the images in that mosaic.
+    """
+    pre_path = os.path.join(root, event_id, 'pre')
+    post_path = os.path.join(root, event_id, 'post')
+    pre_diz = {}
+    try:
+        for pre_mosaic in os.listdir(pre_path):
+            pre_diz[pre_mosaic] = set()
+            for pre_img in os.listdir(os.path.join(pre_path, pre_mosaic)):
+                pre_diz[pre_mosaic].add(pre_img)
+    except FileNotFoundError: #the event may have no pre folder
+        print('No pre folder')
+
+    post_diz = {}
+    try:
+        for post_mosaic in os.listdir(post_path):
+            post_diz[post_mosaic] = set()
+            for post_img in os.listdir(os.path.join(post_path, post_mosaic)):
+                post_diz[post_mosaic].add(post_img)
+    except FileNotFoundError: #the event may have no post folder
+        print('No post folder')
+
+    return pre_diz, post_diz
+
+def subtraction_between_diz(matching, pre_diz, post_diz):
+    """
+    Helper meant to be used only inside check_matching_pre_post().
+    """
+
+    non_matching = {}
+
+    for k_pre in pre_diz.keys():
+        non_matching[k_pre] = pre_diz[k_pre]
+        if k_pre in matching.keys():
+            non_matching[k_pre] -= matching[k_pre]
+
+    for k_post in post_diz.keys():
+        non_matching[k_post] = post_diz[k_post]
+        if k_post in matching.keys():
+            non_matching[k_post] -= matching[k_post]
+    tmp = {}
+    for k in non_matching.keys():
+        if len(non_matching[k]) != 0:
+            tmp[k] = non_matching[k]
+    non_matching = tmp
+    return non_matching
+
+def check_matching_pre_post(event_id, root = '/nfs/projects/overwatch/maxar-segmentation/maxar-open-data/', verbose=True):
+    """
+    Params:
+        event_id example: 'Gambia-flooding-8-11-2022'
+
+    Returns:
+        - matching: a dictionary with the mosaic name as key and the
+            set of matching images contained in that mosaic as value
+        - non_matching: a dictionary with the mosaic name as key and the
+            set of non matching images in that mosaic as value
+            (useful if you want to delete non matching images)
+        - diz_img_mosaic: a dictionary with the image name as key and the
+            set of mosaics that contain that image as value
+    """
+
+    pre_diz, post_diz = create_pre_post_diz_set(event_id, root = root)
+
+    if verbose:
+        print('Pre')
+        for k in pre_diz.keys():
+            print('-', k, '#img:', len(pre_diz[k]))
+        print('\nPost')
+        for k in post_diz.keys():
+            print('-', k, '#img:', len(post_diz[k]))
+
+    matching = {} #dict keyed by mosaic name, whose value is the set of images that match
+    diz_img_mosaic = {} #dict keyed by image name, whose value is the set of mosaics it belongs to
+
+    for k_pre in pre_diz.keys(): #for each pre mosaic
+        for k_post in post_diz.keys(): #check every post mosaic
+            for img_post in post_diz[k_post]: #and, in particular, every post image
+                if img_post in pre_diz[k_pre]: #if the post image is also present in the pre mosaic
+                    if k_pre not in matching.keys():
+                        matching[k_pre] = set()
+                    if k_post not in matching.keys():
+                        matching[k_post] = set()
+                    matching[k_pre].add(img_post)
+                    matching[k_post].add(img_post)
+                    if img_post not in diz_img_mosaic.keys():
+                        diz_img_mosaic[img_post] = set()
+                    diz_img_mosaic[img_post].add(k_pre)
+                    diz_img_mosaic[img_post].add(k_post)
+
+    non_matching = subtraction_between_diz(matching, pre_diz, post_diz)
+    return matching, non_matching, diz_img_mosaic
+
+def count_tif_files(path):
+    """
+    Given a path, explore all the subfolders and count the number of .tif files.
+ """ + count = 0 + for root, dirs, files in os.walk(path): + for file in files: + if file.endswith(".tif"): + count += 1 + return count + +def compute_stats_on_event(event_id, root = '/nfs/projects/overwatch/maxar-segmentation/maxar-open-data/'): + """ + Params: + event_id: Example: "Gambia-flooding-8-11-2022" + + Returns: + No returns but it prints: + - print the number of images in the pre and post folders + - count the number of matching and non matching images + """ + + matching, non_matching, diz_img_mosaic = check_matching_pre_post(event_id, root = root) + + matching_count = 0 + for k in matching.keys(): + matching_count += len(matching[k]) + total_tif = count_tif_files(os.path.join(root, event_id)) + + print(f'\nMatching: {matching_count} images. {100 * matching_count/total_tif:.2f}%') + + non_matching_count = 0 + for k in non_matching.keys(): + non_matching_count += len(non_matching[k]) + + print(f'\nNon matching: {non_matching_count} images. {100 * non_matching_count/total_tif:.2f}%') diff --git a/src/maxarseg/geo_datasets/geoDatasets.py b/src/maxarseg/geo_datasets/geoDatasets.py new file mode 100644 index 0000000..fcd99ce --- /dev/null +++ b/src/maxarseg/geo_datasets/geoDatasets.py @@ -0,0 +1,373 @@ +from torchgeo.datasets import RasterDataset, IntersectionDataset, BoundingBox +import torch +from torch.utils.data import Dataset +import matplotlib.pyplot as plt +import os +from typing import Any, cast +from torch import Tensor +import re +import sys +import rasterio as rio +from maxarseg.samplers import samplers_utils +from rasterio.features import rasterize +from rasterio.windows import from_bounds + +class noTBoundingBox(): + def __init__(self, minx, maxx, miny, maxy): + self.minx = minx + self.maxx = maxx + self.miny = miny + self.maxy = maxy + +def single_sample_collate_fn(batch): + sample = batch[0] + sample['top_lft_index'] = [sample['top_lft_index']] + sample['bbox'] = [sample['bbox']] + return sample + +class SingleTileDataset(Dataset): + """ + To be used with SinglePatchSampler + """ + def __init__(self, tiff_path, tile_aoi_gdf, aoi_mask): + self.tiff_path = tiff_path + self.tile_aoi_gdf = tile_aoi_gdf + self.aoi_mask = aoi_mask + with rio.open(tiff_path) as src: + self.to_index = src.index + self.to_xy = src.xy + self.tile_shape = (src.height, src.width) + self.height, self.width = self.tile_shape + self.transform = src.transform #pxl to geo + self.bounds = src.bounds #geo coords + self.res = src.res[0] + + def __getitem__(self, bbox): + minx, maxx, miny, maxy = bbox #geo coords + with rio.open(self.tiff_path) as dataset: + window = from_bounds(minx, miny, maxx, maxy, dataset.transform) + patch_data = dataset.read(window=window, boundless=True) + patch_tensor = torch.from_numpy(patch_data).float() + sample = {'image': patch_tensor.unsqueeze(0), #(1,3,h,w) + 'bbox': noTBoundingBox(minx, maxx, miny, maxy), + 'top_lft_index': self.to_index(minx, maxy)} + return sample + +class MxrSingleTile(RasterDataset): + """ + A dataset for reading a single tile. + Returns a dict with: + - crs + - bbox of the sampled patch + - image patch + """ + filename_glob = "*.tif" + is_image = True + parent_tile_bbox_in_item = False #this is a parameter to chose if we want to include the parent tile bbox in the return of __getitem__ + + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + """Retrieve image/mask and metadata indexed by query. 
+
+        Args:
+            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
+
+        Returns:
+            sample of image/mask and metadata at that index
+
+        Raises:
+            IndexError: if query is not found in the index
+        """
+        hits = self.index.intersection(tuple(query), objects=True)
+        filepaths = cast(list[str], [hit.object for hit in hits])
+
+        if not filepaths:
+            raise IndexError(
+                f"query: {query} not found in index with bounds: {self.bounds}"
+            )
+
+        data = self._merge_files(filepaths, query, self.band_indexes)
+
+        sample = {"crs": self.crs, "bbox": query}
+
+        data = data.to(self.dtype)
+
+        if self.is_image:
+            sample["image"] = data
+        else:
+            sample["mask"] = data
+
+        if self.transforms is not None:
+            sample = self.transforms(sample)
+
+        if self.parent_tile_bbox_in_item:
+            #initialize bbox
+            minx = sys.float_info.max
+            maxx = -sys.float_info.max
+            miny = sys.float_info.max
+            maxy = -sys.float_info.max
+            mint = sys.float_info.max
+            maxt = -sys.float_info.max
+
+            #if there are multiple hits (tiles) take the largest bbox (include both tiles)
+            for hit in hits:
+                hit_bbox = BoundingBox(*hit.bounds)
+                minx = min(minx, hit_bbox.minx)
+                maxx = max(maxx, hit_bbox.maxx)
+                miny = min(miny, hit_bbox.miny)
+                maxy = max(maxy, hit_bbox.maxy)
+                mint = min(mint, hit_bbox.mint)
+                maxt = max(maxt, hit_bbox.maxt)
+
+            sample['parent_tile_bbox'] = BoundingBox(minx, maxx, miny, maxy, mint, maxt)
+
+        return sample
+
+    def set_parent_tile_bbox_in_item(self):
+        self.parent_tile_bbox_in_item = True
+
+    def plot(self, sample):
+        minx, maxx, miny, maxy = sample["bbox"].minx, sample["bbox"].maxx, sample["bbox"].miny, sample["bbox"].maxy
+        print('In plot')
+        print('Crs', self.crs)
+        print('bottom left: ', (minx, miny))
+        print('top right: ', (maxx, maxy))
+        image = sample["image"].permute(1, 2, 0).numpy().astype('uint8')
+
+        # Plot the image
+        fig, ax = plt.subplots()
+        ax.imshow(image)
+
+        return fig, ax
+
+class MxrSingleTileNoEmpty(RasterDataset):
+    """
+    A dataset for reading a single tile.
+    Returns a dict with:
+    - crs
+    - bbox of the sampled patch
+    - offset (index of the top left corner of the patch in the original image)
+    - image patch
+    """
+
+    filename_glob = "*.tif"
+    is_image = True
+    def __init__(self, paths, tile_aoi_gdf, aoi_mask):
+        super().__init__(paths)
+        with rio.open(self.files[0]) as src:
+            self.to_index = src.index
+            self.to_xy = src.xy
+            self.transform = src.transform
+            self.tile_shape = (src.height, src.width)
+
+        self.tile_aoi_gdf = tile_aoi_gdf
+        #here the tile aoi must be in a projected crs
+        self.aoi_mask = aoi_mask
+
+    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
+        """Retrieve image/mask and metadata indexed by query.
+
+        Args:
+            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
+
+        Returns:
+            sample of image/mask and metadata at that index
+
+        Raises:
+            IndexError: if query is not found in the index
+        """
+
+        hits = self.index.intersection(tuple(query), objects=True)
+        filepaths = cast(list[str], [hit.object for hit in hits])
+
+        if not filepaths:
+            raise IndexError(
+                f"query: {query} not found in index with bounds: {self.bounds}"
+            )
+
+        data = self._merge_files(filepaths, query, self.band_indexes)
+
+        sample = {"crs": self.crs, "bbox": query, "top_lft_index": self.to_index(query[0], query[3])}
+
+        data = data.to(self.dtype)
+
+        if self.is_image:
+            sample["image"] = data
+        else:
+            sample["mask"] = data
+
+        return sample
+
+    def plot(self, sample):
+        minx, maxx, miny, maxy = sample["bbox"].minx, sample["bbox"].maxx, sample["bbox"].miny, sample["bbox"].maxy
+        print('In plot')
+        print('Crs', self.crs)
+        print('bottom left: ', (minx, miny))
+        print('top right: ', (maxx, maxy))
+        image = sample["image"].permute(1, 2, 0).numpy().astype('uint8')
+
+        # Plot the image
+        fig, ax = plt.subplots()
+        ax.imshow(image)
+
+        return fig, ax
+
+class Maxar(RasterDataset):
+    filename_glob = "*.tif"
+    is_image = True
+    parent_tile_bbox_in_item = False
+
+    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
+        """Retrieve image/mask and metadata indexed by query.
+
+        Args:
+            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
+
+        Returns:
+            sample of image/mask and metadata at that index
+
+        Raises:
+            IndexError: if query is not found in the index
+        """
+        hits = self.index.intersection(tuple(query), objects=True)
+        filepaths = cast(list[str], [hit.object for hit in hits])
+
+        if not filepaths:
+            raise IndexError(
+                f"query: {query} not found in index with bounds: {self.bounds}"
+            )
+
+        data = self._merge_files(filepaths, query, self.band_indexes)
+
+        sample = {"crs": self.crs, "bbox": query}
+
+        data = data.to(self.dtype)
+
+        if self.is_image:
+            sample["image"] = data
+        else:
+            sample["mask"] = data
+
+        if self.transforms is not None:
+            sample = self.transforms(sample)
+
+        if self.parent_tile_bbox_in_item:
+            #initialize bbox
+            minx = sys.float_info.max
+            maxx = -sys.float_info.max
+            miny = sys.float_info.max
+            maxy = -sys.float_info.max
+            mint = sys.float_info.max
+            maxt = -sys.float_info.max
+
+            #if there are multiple hits (tiles) take the largest bbox (include both tiles)
+            for hit in hits:
+                hit_bbox = BoundingBox(*hit.bounds)
+                minx = min(minx, hit_bbox.minx)
+                maxx = max(maxx, hit_bbox.maxx)
+                miny = min(miny, hit_bbox.miny)
+                maxy = max(maxy, hit_bbox.maxy)
+                mint = min(mint, hit_bbox.mint)
+                maxt = max(maxt, hit_bbox.maxt)
+
+            sample['parent_tile_bbox'] = BoundingBox(minx, maxx, miny, maxy, mint, maxt)
+
+        return sample
+
+    def set_parent_tile_bbox_in_item(self):
+        self.parent_tile_bbox_in_item = True
+
+    def plot(self, sample):
+        minx, maxx, miny, maxy = sample["bbox"].minx, sample["bbox"].maxx, sample["bbox"].miny, sample["bbox"].maxy
sample["bbox"].maxx, sample["bbox"].miny, sample["bbox"].maxy + #sx_low = tr.transform(minx, miny) + #dx_high = tr.transform(maxx, maxy) + print('In plot') + print('Crs', self.crs) + print('sx_low: ', (minx, miny)) + print('dx_high: ', (maxx, maxy)) + image = sample["image"].permute(1, 2, 0).numpy().astype('uint8') + + #Da eliminare + #image = sample["image"].permute(1, 2, 0) + #image = torch.clamp(image / 300, min=0, max=1).numpy() + + # Plot the image + fig, ax = plt.subplots() + ax.imshow(image) + + return fig, ax + +class MaxarIntersectionDataset(IntersectionDataset): + def __init__(self,dataset1, dataset2): + super().__init__(dataset1, dataset2) + + def _merge_dataset_indices(self) -> None: + """Create a new R-tree out of the individual indices from two datasets.""" + i = 0 + ds1, ds2 = self.datasets + for hit1 in ds1.index.intersection(ds1.index.bounds, objects=True): + for hit2 in ds2.index.intersection(hit1.bounds, objects=True): + print('In merge') + print('hit1: ', hit1.object) + print('hit2: ', hit2.object) + box1 = BoundingBox(*hit1.bounds) + box2 = BoundingBox(*hit2.bounds) + self.index.insert(id = i, coordinates = tuple(box1 & box2), obj = (hit1.object, hit2.object)) + i += 1 + + if i == 0: + raise RuntimeError("Datasets have no spatiotemporal intersection") + + def plot(self, sample): + imgPre = sample["image"][:3].permute(1, 2, 0) + imgPre = torch.clamp(imgPre / 300, min=0, max=1).numpy() + imgPost = sample["image"][3:].permute(1, 2, 0) + imgPost = torch.clamp(imgPost / 300, min=0, max=1).numpy() + + # Create a figure with two subplots + fig, axs = plt.subplots(1, 2, figsize=(10, 5)) + + # Plot imgPre in the first subplot + axs[0].imshow(imgPre) + axs[0].set_title('Pre') + axs[0].axis('off') + + # Plot imgPost in the second subplot + axs[1].imshow(imgPost) + axs[1].set_title('Post') + axs[1].axis('off') + + # Display the figure + plt.show() diff --git a/src/maxarseg/main_noTGeo.py b/src/maxarseg/main_noTGeo.py new file mode 100644 index 0000000..bd4a44f --- /dev/null +++ b/src/maxarseg/main_noTGeo.py @@ -0,0 +1,51 @@ +import argparse +import torch + +from maxarseg.assemble import names +from maxarseg.assemble import holders +from maxarseg.configs import Config + +torch.set_float32_matmul_precision('medium') + +def main(): + parser = argparse.ArgumentParser(description='Segment Maxar Tiles') + parser.add_argument('--config', required= True, type = str, help='Path to the custom configuration file') + parser.add_argument('--event_ix', type = int, help='Index of the event in the list events_names') + parser.add_argument('--out_dir_root', help='output directory root') + + args = parser.parse_args() + + cfg = Config(config_path = args.config) + + if args.event_ix is not None: + cfg.set('event/ix', args.event_ix) + + if args.out_dir_root is not None: + cfg.set('output/out_dir_root', args.out_dir_root) + + # check if there is cuda, otherwise use cpu + if not torch.cuda.is_available(): + cfg.set('models/gd/device', 'cpu') + cfg.set('models/df/device', 'cpu') + cfg.set('models/esam/device', 'cpu') + + + print(cfg._data) + events_names = names.get_all_events() + event = holders.Event(events_names[cfg.get('event/ix')], cfg = cfg) + print("Selected Event: ", event.name) + + all_mosaics_names = event.all_mosaics_names + m0 = event.mosaics[all_mosaics_names[0]] #bay of bengal or Gambia + print("Selected Mosaic: ", m0.name) + + land_and_water_tile_path = '/nfs/projects/overwatch/maxar-segmentation/maxar-open-data/Gambia-flooding-8-11-2022/pre/105001002BD68F00/033133031231.tif' + 
only_water_tile_path = '/nfs/projects/overwatch/maxar-segmentation/maxar-open-data/Gambia-flooding-8-11-2022/pre/105001002BD68F00/033133031303.tif' + + #tile_path = m0.tiles_paths[22] #bay of bengal + + print("Selected Tile: ", land_and_water_tile_path) + m0.segment_tile(land_and_water_tile_path, args.out_dir_root, separate_masks = False, overwrite = True) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/maxarseg/main_seg_event_w_config.py b/src/maxarseg/main_seg_event_w_config.py new file mode 100644 index 0000000..0591980 --- /dev/null +++ b/src/maxarseg/main_seg_event_w_config.py @@ -0,0 +1,46 @@ +import argparse +import torch + +from maxarseg.assemble import names +from maxarseg.assemble import holders +from maxarseg.configs import Config + +torch.set_float32_matmul_precision('medium') + +def main(): + parser = argparse.ArgumentParser(description='Segment Maxar Tiles') + parser.add_argument('--config', required=True, type = str, help='Path to the custom configuration file') + parser.add_argument('--event_ix', type = int, help='Index of the event in the list events_names') + parser.add_argument('--out_dir_root', help='output directory root') + + args = parser.parse_args() + + cfg = Config(config_path = args.config) + + if args.event_ix is not None: + cfg.set('event/ix', args.event_ix) + + if args.out_dir_root is not None: + cfg.set('output/out_dir_root', args.out_dir_root) + + + # check if there is cuda, otherwise use cpu + if not torch.cuda.is_available(): + cfg.set('models/gd/device', 'cpu') + cfg.set('models/df/device', 'cpu') + cfg.set('models/esam/device', 'cpu') + + print(cfg._data) + events_names = names.get_all_events() + event = holders.Event(events_names[cfg.get('event/ix')], cfg = cfg) + print("Selected Event: ", event.name) + + #all_mosaics_names = event.all_mosaics_names + #m0 = event.mosaics[all_mosaics_names[0]] + #tile_path = '/nfs/projects/overwatch/maxar-segmentation/maxar-open-data/southafrica-flooding22/pre/105001002B1CF200/213113323300.tif' + event.seg_all_mosaics(out_dir_root=cfg.get('output/out_dir_root')) + #m0.segment_all_tiles(out_dir_root=args.out_dir_root) #this segment all tiles in the mosaic + #m0.segment_tile(tile_path, args.out_dir_root, separate_masks = False) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/maxarseg/main_seg_event_w_config_partitioned.py b/src/maxarseg/main_seg_event_w_config_partitioned.py new file mode 100644 index 0000000..ad3c7ac --- /dev/null +++ b/src/maxarseg/main_seg_event_w_config_partitioned.py @@ -0,0 +1,82 @@ +import argparse +from re import I +import torch +import math + +from maxarseg.assemble import names +from maxarseg.assemble import holders +from maxarseg.configs import Config + +torch.set_float32_matmul_precision('medium') + +def balance_groups(nums, k): + # Sort numbers in descending order + nums.sort(reverse=True) + + # Initialize k groups + groups = [[] for _ in range(k)] + sums = [0] * k + + # Distribute numbers into groups + for num in nums: + # Find the group with the smallest sum and add the number to it + idx = sums.index(min(sums)) + groups[idx].append(num) + sums[idx] += num + + return groups + +def main(): + parser = argparse.ArgumentParser(description='Segment Maxar Tiles') + parser.add_argument('--config', required=True, type = str, help='Path to the custom configuration file') + parser.add_argument('--event_ix', type = float, help='Index of the event in the list events_names') + parser.add_argument('--out_dir_root', help='output 
directory root')
+
+    args = parser.parse_args()
+
+    cfg = Config(config_path = args.config)
+
+    if args.event_ix is not None:
+        cfg.set('event/ix', args.event_ix)
+
+    if args.out_dir_root is not None:
+        cfg.set('output/out_dir_root', args.out_dir_root)
+
+    # check if there is cuda, otherwise use cpu
+    if not torch.cuda.is_available():
+        cfg.set('models/gd/device', 'cpu')
+        cfg.set('models/df/device', 'cpu')
+        cfg.set('models/esam/device', 'cpu')
+
+    if '.' not in str(cfg.get('event/ix')):
+        raise ValueError("Event ix should contain a sub-partition, e.g. event/ix = 4.1")
+
+    mos_partition = int(str(cfg.get('event/ix')).split('.')[-1])
+    event_ix = int(str(cfg.get('event/ix')).split('.')[0])
+    print(cfg._data)
+    events_names = names.get_all_events()
+    event = holders.Event(events_names[event_ix], cfg = cfg)
+    print("Selected Event: ", event.name)
+
+    all_mosaics_names = event.all_mosaics_names
+    all_mosaics_names.sort()
+
+    # There are always 10 partitions, x.0 to x.9
+    mos_ix_s = math.floor(len(all_mosaics_names)/10) * mos_partition
+    if mos_partition == 9:
+        mos_ix_e = len(all_mosaics_names)
+    else:
+        mos_ix_e = math.floor(len(all_mosaics_names)/10) * (mos_partition + 1)
+
+    mos_names = all_mosaics_names[mos_ix_s:mos_ix_e]
+
+    part_num_tiles = sum([event.mosaics[k].tiles_num for k in mos_names])
+    print()
+    print(f'This partition will segment {len(mos_names)}/{len(event.mosaics)} mosaics. {part_num_tiles}/{event.total_tiles} images in total')
+    print()
+    event.seg_mos_by_keys(keys = mos_names, out_dir_root=cfg.get('output/out_dir_root'))
+
+if __name__ == "__main__":
+    main()
+
diff --git a/src/maxarseg/main_seg_single_tile.py b/src/maxarseg/main_seg_single_tile.py
new file mode 100644
index 0000000..bc59efb
--- /dev/null
+++ b/src/maxarseg/main_seg_single_tile.py
@@ -0,0 +1,107 @@
+import argparse
+import torch
+
+from maxarseg.assemble import names
+from maxarseg.assemble import holders
+from maxarseg.configs import SegmentConfig, DetectConfig
+
+torch.set_float32_matmul_precision('medium')
+
+def main():
+
+    events_names = names.get_all_events()
+
+    parser = argparse.ArgumentParser(description='Segment Maxar Tiles')
+    #event
+    parser.add_argument('--event_ix', default = 2, type = int, help='Index of the event in the list events_names')
+    parser.add_argument('--when', default = 'pre', choices=['pre', 'post', 'None'], help='Select the pre or post event mosaics')
+
+    #Detect config
+    parser.add_argument('--GD_bs', default = 1, type = int, help = 'Batch size for Grounding Dino')
+    parser.add_argument('--DF_bs', default = 32, type = int, help = 'Batch size for DeepForest')
+
+    parser.add_argument('--device_det', default = 'cuda:0', help='device to use for detection')
+
+    parser.add_argument('--size_det', default = 600, type = int, help = 'Size of the patch for detection')
+    parser.add_argument('--stride_det', default = 400, type = int, help = 'Stride of the patch for detection')
+
+    parser.add_argument('--GD_root', default = "./models/GDINO", help = 'Root of the grounding dino model')
+    parser.add_argument('--GD_config_file', default = "configs/GroundingDINO_SwinT_OGC.py", help = 'Config file of the grounding dino model')
+    parser.add_argument('--GD_weights', default = "weights/groundingdino_swint_ogc.pth", help = 'Weights of the grounding dino model')
+
+    parser.add_argument('--text_prompt', default = 'bush', help = 'Prompt for the grounding dino model')
+    parser.add_argument('--box_threshold', default = 0.12, type = float, help = 'Threshold for the grounding dino model')
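+    # Editorial note: box_threshold filters Grounding DINO detections by box
+    # confidence, while text_threshold (next argument) filters on how strongly
+    # a detection matches the text prompt; lower values keep more, noisier boxes.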
parser.add_argument('--text_threshold', default = 0.30, type = float, help = 'Threshold for the grounding dino model') + + parser.add_argument('--max_area_GD_boxes_mt2', default = 6000, type = int, help = 'Max area of the boxes for the grounding dino model') + parser.add_argument('--min_ratio_GD_boxes_edges', default = 0.5, type = float, help = 'Min ratio between edges of the tree boxes') + parser.add_argument('--perc_reduce_tree_boxes', default = 0, type = float, help = 'Percentage of reduction of the tree boxes') + + #Segment config + parser.add_argument('--bs_seg', default = 1, type = int, help = 'Batch size for the segmentation') + parser.add_argument('--device_seg', default = 'cuda:0', help='device to use') + + parser.add_argument('--size_seg', default = 1024, type = int, help = 'Size of the patch') + parser.add_argument('--stride_seg', default = 1024 - 256, type = int, help = 'Stride of the patch') + + parser.add_argument('--ext_mt_build_box', default = 0, type = int, help = 'Extra meter to enlarge building boxes') + + parser.add_argument('--road_width_mt', default = 5, type = int, help = 'Width of the road') + + #Efficient SAM + parser.add_argument('--ESAM_root', default = './models/EfficientSAM', help = 'Root of the efficient sam model') + parser.add_argument('--ESAM_num_parall_queries', default = 5, type = int, help = 'Set the number of paraller queries to be processed') + parser.add_argument('--out_dir_root', default = "./output/tiff/prova_write_canvas", help='output directory root') + + args = parser.parse_args() + + print("Selected Event: ", events_names[args.event_ix]) + + # check if there is cuda, otherwise use cpu + if not torch.cuda.is_available(): + args.device_det = 'cpu' + args.device_seg = 'cpu' + + det_config = DetectConfig( + GD_batch_size = args.GD_bs, + DF_batch_size = args.DF_bs, + size = args.size_det, + stride = args.stride_det, + device = args.device_det, + GD_root = args.GD_root, + GD_config_file = args.GD_config_file, + GD_weights = args.GD_weights, + TEXT_PROMPT = args.text_prompt, + max_area_GD_boxes_mt2 = args.max_area_GD_boxes_mt2, + min_ratio_GD_boxes_edges = args.min_ratio_GD_boxes_edges, + perc_reduce_tree_boxes = args.perc_reduce_tree_boxes, + ) + + seg_config = SegmentConfig(batch_size = args.bs_seg, + size = args.size_seg, + stride = args.stride_seg, + device = args.device_seg, + road_width_mt=args.road_width_mt, + ext_mt_build_box=args.ext_mt_build_box, + ESAM_root = args.ESAM_root, + ESAM_num_parall_queries = args.ESAM_num_parall_queries, + use_separate_detect_config=True, + clean_masks_bool= True + ) + + event = holders.Event(events_names[args.event_ix], + seg_config = seg_config, + det_config = det_config, + when=args.when) + + all_mosaics_names = event.all_mosaics_names + + m0 = event.mosaics[all_mosaics_names[1]] + + land_and_water_tile_path = '/nfs/projects/overwatch/maxar-segmentation/maxar-open-data/Gambia-flooding-8-11-2022/pre/105001002BD68F00/033133031231.tif' + only_water_tile_path = '/nfs/projects/overwatch/maxar-segmentation/maxar-open-data/Gambia-flooding-8-11-2022/pre/105001002BD68F00/033133031303.tif' + tile_path = m0.tiles_paths[2] + m0.segment_tile(tile_path, args.out_dir_root, separate_masks = False, overwrite = True) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/maxarseg/main_seg_tile.py b/src/maxarseg/main_seg_tile.py new file mode 100644 index 0000000..43af8d8 --- /dev/null +++ b/src/maxarseg/main_seg_tile.py @@ -0,0 +1,92 @@ +import argparse + +from maxarseg.assemble import names +from 
maxarseg.assemble import holders +from maxarseg.configs import SegmentConfig + +def main(): + + events_names = names.get_all_events() + + parser = argparse.ArgumentParser(description='Segment Maxar Tiles') + #event + parser.add_argument('--event_ix', default = 6, type = int, help='Index of the event in the list events_names') + parser.add_argument('--when', default = 'pre', choices=['pre', 'post', 'None'], help='Select the pre or post event mosaics') + + #CONFIG + parser.add_argument('--bs', default = 2, type = int, help = 'Batch size for the dataloader') + parser.add_argument('--device', default = 'cuda:0', help='device to use') + #patch + parser.add_argument('--size', default = 600, type = int, help = 'Size of the patch') + parser.add_argument('--stride', default = 400, type = int, help = 'Stride of the patch') + + #Grounding Dino - Trees + parser.add_argument('--GD_root', default = "./models/GDINO", help = 'Root of the grounding dino model') + parser.add_argument('--GD_config_file', default = "GroundingDINO_SwinT_OGC.py", help = 'Config file of the grounding dino model') + parser.add_argument('--GD_weights', default = "groundingdino_swint_ogc.pth", help = 'Weights of the grounding dino model') + + parser.add_argument('--text_prompt', default = 'green tree', help = 'Prompt for the grounding dino model') + parser.add_argument('--box_threshold', default = 0.15, type = float, help = 'Threshold for the grounding dino model') + parser.add_argument('--text_threshold', default = 0.30, type = float, help = 'Threshold for the grounding dino model') + parser.add_argument('--max_area_GD_boxes_mt2', default = 6000, type = int, help = 'Max area of the boxes for the grounding dino model') + + parser.add_argument('--min_ratio_GD_boxes_edges', default = 0, type = float, help = 'Min ratio between edges of the tree boxes') + parser.add_argument('--perc_reduce_tree_boxes', default = 0, type = float, help = 'Percentage of reduction of the tree boxes') + + #Buildings + parser.add_argument('--ext_mt_build_box', default = 5, type = int, help = 'Width of the building') + + #Roads + parser.add_argument('--road_width_mt', default = 5, type = int, help = 'Width of the road') + + + #Efficient SAM + parser.add_argument('--ESAM_root', default = './models/EfficientSAM', help = 'Root of the efficient sam model') + parser.add_argument('--ESAM_num_parall_queries', default = 5, type = int, help = 'Set the number of paraller queries to be processed') + + + parser.add_argument('--out_dir_root', default = "./output/tiff", help='output directory root') + + args = parser.parse_args() + + print("Selected Event: ", events_names[args.event_ix]) + + config = SegmentConfig(batch_size = args.bs, + size = args.size, + stride = args.stride, + + device = args.device, + GD_root = args.GD_root, + GD_config_file = args.GD_config_file, + GD_weights = args.GD_weights, + + TEXT_PROMPT = args.text_prompt, + BOX_THRESHOLD = args.box_threshold, + TEXT_THRESHOLD = args.text_threshold, + + max_area_GD_boxes_mt2 = args.max_area_GD_boxes_mt2, + min_ratio_GD_boxes_edges = args.min_ratio_GD_boxes_edges, + perc_reduce_tree_boxes = args.perc_reduce_tree_boxes, + + road_width_mt=args.road_width_mt, + ext_mt_build_box=args.ext_mt_build_box, + + ESAM_root = args.ESAM_root, + ESAM_num_parall_queries = args.ESAM_num_parall_queries, + ) + + event = holders.Event(events_names[args.event_ix], seg_config = config, when=args.when) + all_mosaics_names = event.all_mosaics_names + + + #event.seg_all_mosaics() #this segment all the mosiacs in the event + + m0 = 
event.mosaics[all_mosaics_names[0]] + #m0.segment_all_tiles() #this segment all tiles in the mosaic + + m0_tile_17_path = m0.tiles_paths[17] + m0.segment_tile(m0_tile_17_path, args.out_dir_root) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/maxarseg/main_seg_tile_glbl_detections.py b/src/maxarseg/main_seg_tile_glbl_detections.py new file mode 100644 index 0000000..52a9783 --- /dev/null +++ b/src/maxarseg/main_seg_tile_glbl_detections.py @@ -0,0 +1,106 @@ +import argparse +import torch + +from maxarseg.assemble import names +from maxarseg.assemble import holders +from maxarseg.configs import SegmentConfig, DetectConfig + +torch.set_float32_matmul_precision('medium') + +def main(): + + events_names = names.get_all_events() + + parser = argparse.ArgumentParser(description='Segment Maxar Tiles') + #event + parser.add_argument('--event_ix', default = 2, type = int, help='Index of the event in the list events_names') + parser.add_argument('--when', default = 'pre', choices=['pre', 'post', 'None'], help='Select the pre or post event mosaics') + + #Detect config + parser.add_argument('--GD_bs', default = 1, type = int, help = 'Batch size for Grounding Dino') + parser.add_argument('--DF_bs', default = 1, type = int, help = 'Batch size for DeepForest') + + parser.add_argument('--device_det', default = 'cuda:0', help='device to use for detection') + + parser.add_argument('--size_det', default = 600, type = int, help = 'Size of the patch for detection') + parser.add_argument('--stride_det', default = 400, type = int, help = 'Stride of the patch for detection') + + parser.add_argument('--GD_root', default = "./models/GDINO", help = 'Root of the grounding dino model') + parser.add_argument('--GD_config_file', default = "configs/GroundingDINO_SwinT_OGC.py", help = 'Config file of the grounding dino model') + parser.add_argument('--GD_weights', default = "weights/groundingdino_swint_ogc.pth", help = 'Weights of the grounding dino model') + + parser.add_argument('--text_prompt', default = 'green tree', help = 'Prompt for the grounding dino model') + parser.add_argument('--box_threshold', default = 0.15, type = float, help = 'Threshold for the grounding dino model') + parser.add_argument('--text_threshold', default = 0.30, type = float, help = 'Threshold for the grounding dino model') + + parser.add_argument('--max_area_GD_boxes_mt2', default = 6000, type = int, help = 'Max area of the boxes for the grounding dino model') + parser.add_argument('--min_ratio_GD_boxes_edges', default = 0, type = float, help = 'Min ratio between edges of the tree boxes') + parser.add_argument('--perc_reduce_tree_boxes', default = 0, type = float, help = 'Percentage of reduction of the tree boxes') + + #Segment config + parser.add_argument('--bs_seg', default = 1, type = int, help = 'Batch size for the segmentation') + parser.add_argument('--device_seg', default = 'cuda:0', help='device to use') + + parser.add_argument('--size_seg', default = 1024, type = int, help = 'Size of the patch') + parser.add_argument('--stride_seg', default = 1024 - 256, type = int, help = 'Stride of the patch') + + parser.add_argument('--ext_mt_build_box', default = 0, type = int, help = 'Extra meter to enlarge building boxes') + + parser.add_argument('--road_width_mt', default = 5, type = int, help = 'Width of the road') + + #Efficient SAM + parser.add_argument('--ESAM_root', default = './models/EfficientSAM', help = 'Root of the efficient sam model') + parser.add_argument('--ESAM_num_parall_queries', default = 
5, type = int, help = 'Set the number of paraller queries to be processed') + parser.add_argument('--out_dir_root', default = "./output/tiff/prova_write_canvas", help='output directory root') + + args = parser.parse_args() + + print("Selected Event: ", events_names[args.event_ix]) + + # check if there is cuda, otherwise use cpu + if not torch.cuda.is_available(): + args.device_det = 'cpu' + args.device_seg = 'cpu' + + det_config = DetectConfig( + GD_batch_size = args.GD_bs, + DF_batch_size = args.DF_bs, + size = args.size_det, + stride = args.stride_det, + device = args.device_det, + GD_root = args.GD_root, + GD_config_file = args.GD_config_file, + GD_weights = args.GD_weights, + TEXT_PROMPT = args.text_prompt, + max_area_GD_boxes_mt2 = args.max_area_GD_boxes_mt2, + min_ratio_GD_boxes_edges = args.min_ratio_GD_boxes_edges, + perc_reduce_tree_boxes = args.perc_reduce_tree_boxes, + ) + + seg_config = SegmentConfig(batch_size = args.bs_seg, + size = args.size_seg, + stride = args.stride_seg, + device = args.device_seg, + road_width_mt=args.road_width_mt, + ext_mt_build_box=args.ext_mt_build_box, + ESAM_root = args.ESAM_root, + ESAM_num_parall_queries = args.ESAM_num_parall_queries, + use_separate_detect_config=True, + clean_masks_bool= True + ) + + event = holders.Event(events_names[args.event_ix], + seg_config = seg_config, + det_config = det_config, + when=args.when) + + all_mosaics_names = event.all_mosaics_names + m0 = event.mosaics[all_mosaics_names[0]] + tile_path = '/nfs/projects/overwatch/maxar-segmentation/maxar-open-data/Gambia-flooding-8-11-2022/pre/105001002BD68F00/033133031213.tif' + + event.seg_all_mosaics(out_dir_root=args.out_dir_root) + # m0.segment_all_tiles(out_dir_root=args.out_dir_root) + # m0.segment_tile(tile_path, args.out_dir_root, separate_masks = False) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/maxarseg/main_seg_tile_w_config.py b/src/maxarseg/main_seg_tile_w_config.py new file mode 100644 index 0000000..0204146 --- /dev/null +++ b/src/maxarseg/main_seg_tile_w_config.py @@ -0,0 +1,52 @@ +import argparse +import torch + +from maxarseg.assemble import names +from maxarseg.assemble import holders +from maxarseg.configs import Config + +torch.set_float32_matmul_precision('medium') + +def main(): + parser = argparse.ArgumentParser(description='Segment Maxar Tiles') + parser.add_argument('--config', required= True, type = str, help='Path to the custom configuration file') + parser.add_argument('--event_ix', type = int, help='Index of the event in the list events_names') + parser.add_argument('--out_dir_root', help='output directory root') + + args = parser.parse_args() + + cfg = Config(config_path = args.config) + + if args.event_ix is not None: + cfg.set('event/ix', args.event_ix) + + if args.out_dir_root is not None: + cfg.set('output/out_dir_root', args.out_dir_root) + + # check if there is cuda, otherwise use cpu + if not torch.cuda.is_available(): + cfg.set('models/gd/device', 'cpu') + cfg.set('models/df/device', 'cpu') + cfg.set('models/esam/device', 'cpu') + + + print(cfg._data) + events_names = names.get_all_events() + event = holders.Event(events_names[cfg.get('event/ix')], cfg = cfg) + print("Selected Event: ", event.name) + + all_mosaics_names = event.all_mosaics_names + m0 = event.mosaics[all_mosaics_names[0]] #bay of bengal + #m0 = event.mosaics[all_mosaics_names[0]] + print("Selected Mosaic: ", m0.name) + + land_and_water_tile_path = 
'/nfs/projects/overwatch/maxar-segmentation/maxar-open-data/Gambia-flooding-8-11-2022/pre/105001002BD68F00/033133031231.tif' + only_water_tile_path = '/nfs/projects/overwatch/maxar-segmentation/maxar-open-data/Gambia-flooding-8-11-2022/pre/105001002BD68F00/033133031303.tif' + + tile_path = m0.tiles_paths[22] #bay of bengal + + print("Selected Tile: ", tile_path) + m0.segment_tile(tile_path, args.out_dir_root, separate_masks = False, overwrite = True) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/maxarseg/output.py b/src/maxarseg/output.py new file mode 100644 index 0000000..4a500c4 --- /dev/null +++ b/src/maxarseg/output.py @@ -0,0 +1,114 @@ +import rasterio +from pathlib import Path +import numpy as np +from maxarseg.ESAM_segment import segment_utils +import pandas as pd +from maxarseg.polygonize import polygonize_with_values +import geopandas as gpd + +def single_mask2Tif(tile_path, mask, out_name, out_dir_root = './output/tiff'): + """ + Converts a binary mask to a GeoTIFF file. + + Args: + tile_path (str): The path to the input tile file. + mask (numpy.ndarray): The binary mask array. + out_name (str): The name of the output GeoTIFF file. + out_dir_root (str, optional): The root directory for the output GeoTIFF file. Defaults to './output/tiff'. + + Returns: + None + """ + with rasterio.open(tile_path) as src: + out_meta = src.meta.copy() + + out_meta.update({"driver": "GTiff", + "dtype": "uint8", + "count": 1}) + out_path = Path(out_dir_root) / out_name + + with rasterio.open(out_path, 'w', **out_meta) as dest: + dest.write(mask, 1) + + print(f"Mask written in {out_path}") + +def masks2Tifs(tile_path , masks: np.ndarray, out_names: list, separate_masks: bool, out_dir_root = './output/tiff'): + if not separate_masks: #merge the masks + mask = segment_utils.merge_masks(masks) + masks = np.expand_dims(mask, axis=0) + + with rasterio.open(tile_path) as src: + out_meta = src.meta.copy() + + out_meta.update({"driver": "GTiff", + "dtype": "uint8", + "count": 1}) + masks = masks.astype(np.uint8) + for j, out_name in enumerate(out_names): + out_path = Path(out_dir_root) / out_name + with rasterio.open(out_path, 'w', **out_meta) as dest: + dest.write(masks[j], 1) + print(f"Mask written in {out_path}") + + return masks + +def gen_names(tile_path, separate_masks=False): + """ + Generate output file names based on the given tile path. + + Args: + tile_path (Path): The path to the tile file. + divide_masks (bool, optional): Whether to divide masks into separate files. Defaults to False. + + Returns: + list: A list of output file names. 
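+
+    Example (illustrative; paths follow the event/when/mosaic/tile layout used in this repo):
+        gen_names(Path('Gambia-flooding-8-11-2022/pre/105001002BD68F00/033133031231.tif'))
+        -> [Path('Gambia-flooding-8-11-2022/pre/105001002BD68F00/033133031231.tif')]
+        With separate_masks=True one name per class is produced, e.g.
+        '033133031231_road.tif', '033133031231_tree.tif', '033133031231_building.tif'.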
+    """
+    ev_name, tl_when, mos_name, tl_name = tile_path.parts[-4:]
+    masks_names = ['road', 'tree', 'building']
+
+    if separate_masks:
+        out_names = [Path(ev_name) / tl_when / mos_name / (tl_name.split('.')[0] + '_' + mask_name + '.tif') for mask_name in masks_names]
+    else:
+        out_names = [Path(ev_name) / tl_when / mos_name / (tl_name.split('.')[0] + '.tif')]
+
+    return out_names
+
+def masks2parquet(tile_path, tree_build_masks: np.ndarray, road_series: pd.Series, out_names: list, out_dir_root = './output/tiff'):
+    # simplification tolerances and minimum pixel areas for the tree (index 0)
+    # and building (index 1) mask channels
+    tolerances = [0.001, 0.005]
+    pixel_thresholds = [20, 20]
+    # polygonization
+    with rasterio.open(tile_path) as src:
+        out_meta = src.meta.copy()
+    gdf_list = []
+    # convert pd.Series to gpd.GeoDataFrame
+    road_gdf = gpd.GeoDataFrame(road_series)
+    # set road_gdf class_id to 0
+    road_gdf['class_id'] = 0
+    # rename the columns
+    road_gdf.columns = ['geometry', 'class_id']
+    gdf_list.append(road_gdf)
+    # cycling over the mask channels
+    for i in range(tree_build_masks.shape[0]):
+        if tree_build_masks[i].sum() != 0:
+            gdf = polygonize_with_values(tree_build_masks[i], class_id=i+1, tolerance=tolerances[i], transform=out_meta['transform'], crs=out_meta['crs'], pixel_threshold=pixel_thresholds[i])
+            gdf_list.append(gdf)
+    crs = out_meta['crs']
+    # Set the CRS of all GeoDataFrames to the same CRS
+    for gdf in gdf_list:
+        gdf.set_geometry('geometry', inplace=True)
+        gdf.crs = crs
+    gdf_list = [gdf.to_crs(crs) for gdf in gdf_list if 'geometry' in gdf.columns and gdf['geometry'].notna().all()]
+    assert len(out_names) == 1, "Only one output name is allowed for a parquet file"
+    # build the output path from out_dir_root and swap '.tif' for '.parquet'
+    out_path = Path(out_dir_root) / out_names[0]
+    out_path = out_path.with_suffix('.parquet')
+    # concatenate everything into a single GeoDataFrame and write it
+    gdf = gpd.GeoDataFrame(pd.concat(gdf_list, ignore_index=True))
+    gdf.to_parquet(out_path)
+    print('Parquet file created at:', out_path)
+    return gdf
\ No newline at end of file
diff --git a/src/maxarseg/plotting_utils.py b/src/maxarseg/plotting_utils.py
new file mode 100644
index 0000000..4d50549
--- /dev/null
+++ b/src/maxarseg/plotting_utils.py
@@ -0,0 +1,94 @@
+import matplotlib.pyplot as plt
+from shapely.geometry import LineString, Point
+from typing import Union, List
+import numpy as np
+
+def show_mask(mask: np.array, ax = None, rgb_color=[30, 144, 255], alpha = 0.6, random_color = False):
+    """
+    Take a mask that is a 2D array and show it on the axis ax
+    """
+    if ax is None:
+        fig, ax = plt.subplots()
+
+    if random_color:
+        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+    else:
+        color = np.array([rgb_color[0]/255, rgb_color[1]/255, rgb_color[2]/255, alpha])
+    h, w = mask.shape[-2:]
+    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+    ax.imshow(mask_image)
+
+def show_Linestrings(lines: Union[List[LineString], LineString], ax, color = 'red', linewidth = 1):
+    """
+    Plots a single or a list of shapely linestrings
+    """
+    if not isinstance(lines, list):
+        lines = [lines]
+
+    for line in lines:
+        x_s, y_s = line.coords.xy
+        ax.plot(x_s, y_s, color=color, linewidth=linewidth)
+
+def show_box(boxes, ax, color='r', lw = 0.5):
+    """
+    Plot a single or list of boxes.
Where the single box is in the format [x0, y0, x1, y1] + """ + if not isinstance(boxes, list) and not isinstance(boxes, np.ndarray): + boxes = [boxes] + for box in boxes: + x0, y0 = box[0], box[1] + w, h = box[2] - box[0], box[3] - box[1] + ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=color, facecolor=(0,0,0,0), lw=lw)) + +def show_points(coords: np.array, labels: np.array, ax, marker_size=75): + """ + Plot an array of points. + Inputs: + coords: a np array of shape (n, 2) containing the coordinates of the points + labels: a np array of shape (n,) containing the labels of the points + ax: the axis on which to plot the points + marker_size: the size of the markers + """ + if labels is None: + ax.scatter(coords[:, 0], coords[:, 1], color='b', marker='.', s=marker_size, edgecolor='white', linewidth=0.25) + else: + pos_points = coords[labels==1] + neg_points = coords[labels==0] + ax.scatter(pos_points[:, 0], pos_points[:, 1], color='blue', marker='.', s=marker_size, edgecolor='white', linewidth=0.25) + ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='.', s=marker_size, edgecolor='white', linewidth=0.25) + +def plot_comparison(img, masks, alpha = 0.6): + """ + Plot a comparison between the original image and the image with the masks. + Inputs: + img: the original image (np.array of dim (h, w, 3)) + masks: a np.array of masks (dim: (#masks, h, w) or (h, w)) + alpha: the opacity of the masks + """ + if len(masks.shape) != 3: + masks = np.expand_dims(masks, axis = 0) + + fig = plt.figure(figsize=(15, 15)) + ax1 = fig.add_subplot(1, 2, 1) + ax1.imshow(img) + + ax2 = fig.add_subplot(1, 2, 2, sharex=ax1, sharey=ax1) + ax2.imshow(img) + for i in range(masks.shape[0]): + show_mask(masks[i], ax2, random_color=True) + + #ax2.set_xlim([0, img.shape[1]]) + #ax2.set_ylim([img.shape[0], 0]) + + ax1.axis('off') + ax2.axis('off') + +def show_img(img, ax = None): + """ + Show an image on the axis ax + """ + if ax is None: + fig, ax = plt.subplots() + ax.imshow(img) + ax.axis('off') \ No newline at end of file diff --git a/src/maxarseg/polygonize.py b/src/maxarseg/polygonize.py new file mode 100644 index 0000000..7d1dfc7 --- /dev/null +++ b/src/maxarseg/polygonize.py @@ -0,0 +1,424 @@ +import rasterio as rio +import numpy as np +import geopandas as gpd +from shapely.geometry import shape, Polygon +from collections import defaultdict +import cv2 +import matplotlib.pyplot as plt +from typing import List, Tuple +from matplotlib import collections as mc + + +def onehot(components: np.ndarray) -> np.ndarray: + oh = np.zeros((components.max() + 1, *components.shape), dtype=np.uint8) + for i in range(oh.shape[0]): + oh[i][components == i] = 1 + if 0 in np.unique(components): + oh = oh[1:] + return oh + + +def apply_transform(polygon: Polygon, transform: rio.Affine) -> Polygon: + return Polygon([transform * c for c in polygon.exterior.coords]) + + +def polygonize_raster(raster: np.ndarray, tolerance: float = 0.1, transform: rio.transform.Affine = None, crs: str = None, pixel_threshold: int = 100) -> gpd.GeoDataFrame: + data = defaultdict(list) + onehot_raster = onehot(raster) + for i in range(onehot_raster.shape[0]): + mask = onehot_raster[i] + if mask.sum() < pixel_threshold: + continue + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + for contour in contours: + perimeter = cv2.arcLength(contour, True) + approx = cv2.approxPolyDP(contour, tolerance * perimeter, True) + contour = approx.squeeze() + if contour.shape[0] < 3: + continue + poly = shape({"type": 
"Polygon", "coordinates": [contour]}) + if transform is not None: + poly = apply_transform(poly, transform) + data["geometry"].append(poly) + data["component"].append(i) + return gpd.GeoDataFrame(data, crs=crs) + + +# version of polygonize_raster, which returns the gdf with an additional field, containing the value of the pixels contained in each polygon +def polygonize_with_values(raster: np.ndarray, class_id: int, tolerance: float = 0.1, transform: rio.transform.Affine = None, crs: str = None, pixel_threshold: int = 100) -> gpd.GeoDataFrame: + # pixel_threshold is the minimum number of pixels that a polygon must contain to be considered + data = defaultdict(list) + onehot_raster = onehot(raster) + for i in range(onehot_raster.shape[0]): + mask = onehot_raster[i] + if mask.sum() < pixel_threshold: + continue + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + for contour in contours: + perimeter = cv2.arcLength(contour, True) + approx = cv2.approxPolyDP(contour, tolerance * perimeter, True) + contour = approx.squeeze() + if contour.shape[0] < 3: + continue + poly = shape({"type": "Polygon", "coordinates": [contour]}) + if transform is not None: + poly = apply_transform(poly, transform) + data["geometry"].append(poly) + data["component"].append(i) + data["class_id"].append(class_id) + # if(raster[contour[:, 1], contour[:, 0]].mean() > 1): + # print(raster[contour[:, 1], contour[:, 0]]) + # print(raster[contour[:, 1], contour[:, 0]].mean()) + return gpd.GeoDataFrame(data, crs=crs) + + +def angle_between_lines(line1, line2): + a1, b1, _ = line1 + a2, b2, _ = line2 + angle = np.arctan2(a1*b2 - a2*b1, a1*a2 + b1*b2) + return angle + + +def minimum_angle(line1: np.ndarray, line2: np.ndarray) -> float: + """return the minimum angle between two lines""" + angle = angle_between_lines(line1, line2) + abs_angle = np.abs(angle) + sign = np.sign(angle) + if abs_angle > np.pi / 2: + angle = (np.pi - abs_angle) * -sign + return angle + + +def line_from_points(point1: tuple, point2: tuple) -> np.ndarray: + """Given two points, return the line parameters in the Hesse normal form""" + x1, y1 = point1 + x2, y2 = point2 + a = y1 - y2 + b = x2 - x1 + c = x1 * y2 - x2 * y1 + return np.array([a, b, c]) + + +def rotate_point(p: tuple, center: tuple, angle: float) -> tuple: + """rotate a point around a center by a given angle""" + x, y = p + cx, cy = center + x -= cx + y -= cy + x_new = x * np.cos(angle) - y * np.sin(angle) + y_new = x * np.sin(angle) + y * np.cos(angle) + x_new += cx + y_new += cy + return x_new, y_new + + +def middle_point(p1: tuple, p2: tuple) -> tuple: + """return the middle point between two points""" + return (p1[0] + p2[0]) / 2, (p1[1] + p2[1]) / 2 + + +def find_longest_edge(polygon: Polygon) -> Tuple[float, float]: + max_length = 0 + max_edge = None + max_index = -1 + for i in range(len(polygon.exterior.coords) - 1): + a = polygon.exterior.coords[i] + b = polygon.exterior.coords[i + 1] + length = np.linalg.norm(np.array(a) - np.array(b)) + if length > max_length: + max_length = length + max_edge = (a, b) + max_index = i + return max_edge, max_length, max_index + + +def remove_angles(polygon: Polygon, min_angle: float, max_angle: float) -> Polygon: + coords = list(polygon.exterior.coords) + i = 0 + while i < len(coords) - 2: + a = coords[i] + b = coords[i + 1] + c = coords[i + 2] + l1 = line_from_points(a, b) + l2 = line_from_points(b, c) + rads = angle_between_lines(l1, l2) + angle = abs(np.rad2deg(rads)) + if angle < min_angle or max_angle < angle: + 
coords.pop(i + 1) + else: + i += 1 + return Polygon(coords) + + +def rotate_segment(point1, point2, angle): + # Find the middle point + x1, y1 = point1 + x2, y2 = point2 + middle_x = (x1 + x2) / 2 + middle_y = (y1 + y2) / 2 + # Convert angle to radians + cos_angle = np.cos(angle) + sin_angle = np.sin(angle) + + # Rotate the edge by a specific angle around the middle point + new_x1 = (x1 - middle_x) * cos_angle - (y1 - middle_y) * sin_angle + middle_x + new_y1 = (x1 - middle_x) * sin_angle + (y1 - middle_y) * cos_angle + middle_y + new_x2 = (x2 - middle_x) * cos_angle - (y2 - middle_y) * sin_angle + middle_x + new_y2 = (x2 - middle_x) * sin_angle + (y2 - middle_y) * cos_angle + middle_y + + # Return the new coordinates of the rotated edge + return ((new_x1, new_y1), (new_x2, new_y2)) + + +def are_lines_parallel(line1, line2, tolerance = 1e-3): + angle = angle_between_lines(line1, line2) + return abs(angle) < tolerance or abs(angle - np.pi) < tolerance + + +def parallel_line(line, point): + a, b, c = line + x0, y0 = point + c_parallel = -a*x0 - b*y0 + parallel = np.array([a, b, c_parallel]) + return parallel + + +def longest_segment(coords: List[Tuple[float, float]]) -> Tuple[tuple, float]: + """Returns the longest segment in a list of coordinates""" + max_length = 0 + max_segment = None + for i in range(len(coords) - 1): + a = coords[i] + b = coords[i + 1] + length = np.linalg.norm(np.array(a) - np.array(b)) + if length > max_length: + max_length = length + max_segment = (a, b) + return max_segment, max_length + + +def align_polygon(polygon: Polygon) -> Polygon: + """ + Aligns the each edge of the polygon to the closest direction among the two + orientations of its minimum rotated rectangle. + """ + min_rect = polygon.minimum_rotated_rectangle + rect_coords = list(min_rect.exterior.coords) + a, b, c = rect_coords[0], rect_coords[1], rect_coords[2] + + l1 = line_from_points(rect_coords[0], rect_coords[1]) + l2 = line_from_points(rect_coords[1], rect_coords[2]) + + # iterate over polygon edges + new_coords = [] + coords = list(polygon.exterior.coords) + for i in range(len(coords) - 1): + a = coords[i] + b = coords[i + 1] + + # Find the angle between the segment and the line + l = line_from_points(a, b) + angle_1 = minimum_angle(l, l1) + angle_2 = minimum_angle(l, l2) + + if abs(angle_1) < abs(angle_2): + angle = angle_1 + else: + angle = angle_2 + # adjust the edge + ar, br = rotate_segment(a, b, angle=angle) + new_coords.append([ar, br]) + return new_coords + + +def plot_line(line: np.ndarray, xlim: tuple, ylim: tuple, **kwargs): + a, b, c = line + x1, x2 = xlim + y1, y2 = ylim + if abs(b) < 1e-3: + x = -c/a + plt.axvline(x, color='r') + elif abs(a) < 1e-3: + y = -c/b + plt.axhline(y, color='r') + else: + x = np.linspace(x1, x2, 50) + y = (-c - a*x) / b + # filter out points outside the plot + #values = [(x, y) for x, y in zip(x, y) if x1 <= x <= x2 and y1 <= y <= y2] + indices = [i for i in range(len(x)) if x1 <= x[i] <= x2 and y1 <= y[i] <= y2] + plt.plot(x[indices], y[indices], **kwargs) + + +def format_point(point: tuple) -> str: + return f"{point[0]:.2f},{point[1]:.2f}" + + +def orthogonal_line(line, point): + a, b, c = line + x0, y0 = point + a_ortho, b_ortho = -b, a + c_ortho = -a_ortho*x0 - b_ortho*y0 + ortho = np.array([a_ortho, b_ortho, c_ortho]) + return ortho + + +def filter_segments(coords: List[tuple], centroid: tuple, threshold: float = 0.1) -> List[tuple]: + result = [] + _, max_length = longest_segment(coords) + + print(f"max length: {max_length * threshold}") + 
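+    # Editorial sketch of the loop below: walk consecutive segment pairs; when
+    # two consecutive segments are (near-)parallel, drop a perpendicular from
+    # the join onto the second line and, if that connecting edge is shorter
+    # than threshold * max_length, keep only the longer of the two segments.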
plt.figure(figsize=(15, 15)) + # plt.xlim(100, 500) + # plt.ylim(100, 500) + + for i in range(len(coords)): + a, b = coords[i] + c, d = coords[(i + 1) % len(coords)] + l1 = line_from_points(a, b) + l2 = line_from_points(c, d) + + # plot_line(l1, start=min(a[0] ,b[0]) - 10, end=max(a[0] ,b[0]) + 10) + # plot_line(l2, start=min(c[0] ,d[0]) - 10, end=max(c[0] ,d[0]) + 10) + x, y = zip(*[a, b, c, d]) + plt.scatter(x, y, color="red", marker="x", s=5, zorder=5) + if are_lines_parallel(l1, l2): + l1 = orthogonal_line(l1, b) + e = intersection_point(l1, l2) + # plot_line(l1, start=min(a[0] ,b[0]) - 10, end=max(a[0] ,b[0]) + 10, color="magenta") + x, y = zip(*[e]) + plt.scatter(x, y, color="magenta", marker="x", s=150, zorder=10) + + length_a = np.linalg.norm(np.array(a) - np.array(b)) + length_b = np.linalg.norm(np.array(c) - np.array(d)) + length_c = np.linalg.norm(np.array(e) - np.array(b)) + if length_c < threshold * max_length: + if length_a > length_b: + edge = (a, b) + else: + edge = (c, d) + result.append(edge) + plt.gca().add_collection(mc.LineCollection([edge], colors="blue", linewidths=2)) + continue + + plt.gca().add_collection(mc.LineCollection([(a, b)], colors="blue", linewidths=2)) + result.append((a, b)) + + return result + + +def perpendicular_distance(point: tuple, line: np.ndarray) -> float: + """Computes the perpendicular distance between a point and a line.""" + a, b, c = line + x, y = point + return abs(a*x + b*y + c) / np.sqrt(a**2 + b**2) + + +def average(points: List[tuple]) -> tuple: + """Computes the average of a list of points.""" + x = sum([p[0] for p in points]) / len(points) + y = sum([p[1] for p in points]) / len(points) + return x, y + + +def intersection_point(line1: np.ndarray, line2: np.ndarray) -> tuple: + """Computes the intersection point of two lines, given in Hesse normal form.""" + a1, b1, c1 = line1 + a2, b2, c2 = line2 + det = a1*b2 - a2*b1 + if abs(det) < 1e-3: + raise ValueError("Lines are parallel") + x = (b2*c1 - b1*c2) / det + y = (a1*c2 - a2*c1) / det + return abs(x), abs(y) + + +def filterv2(coords: List[tuple], length_thresh: float, dist_thresh: float) -> List[tuple]: + result = [] + skip = set() + _, max_length = longest_segment(coords) + for i in range(len(coords) - 1): + if i in skip: + continue + a, b = coords[i] + # discard it immediately if it is too short + length = np.linalg.norm(np.array(a) - np.array(b)) + if length < length_thresh * max_length: + continue + # if the lines are parallel, check if the orthogonal distance is too small + # if so, compute a middle point and discard the segment + j = (i + 1) % len(coords) + should_stop = False + new_as = list() + while not should_stop: + c, d = coords[j] + l1 = line_from_points(a, b) + l2 = line_from_points(c, d) + # if they are not parallel, we can stop + # we store this to use it later, if we need to stop because of the distance + are_parallel = are_lines_parallel(l1, l2) + if not are_parallel: + should_stop = True + else: + # otherwise, check the perpendicular distance + dist = perpendicular_distance(c, l1) + if dist < dist_thresh * max_length: + a2 = intersection_point(orthogonal_line(l1, a), l2) + new_as.append(a2) + skip.add(j) + j = (j + 1) % len(coords) + continue + else: + should_stop = True + # if we should stop, we can compute new_b and add the segment + if should_stop and j not in skip: + l = l2 if not are_parallel else orthogonal_line(l2, c) + new_a = average(new_as) if len(new_as) > 0 else a + new_b = intersection_point(l, parallel_line(l1, new_a)) + result.append((new_a, new_b)) 
+ + return result + + +def merge_segments(coords: List[tuple], threshold: float = 0.1) -> List[tuple]: + result = [] + i = 0 + j = i + while i < len(coords): + a, b = coords[i] + c, d = coords[(i + 1) % len(coords)] + if are_lines_parallel(line_from_points(a, b), line_from_points(c, d)): + length_a = np.linalg.norm(np.array(a) - np.array(b)) + length_b = np.linalg.norm(np.array(c) - np.array(d)) + length_c = np.linalg.norm(np.array(b) - np.array(c)) + if length_c < threshold * (length_a + length_b): + print("merging") + mid_c = middle_point(b, c) + line_a = line_from_points(a, b) + line_b = line_from_points(c, d) + lin_c = parallel_line(line_a, mid_c) + new_a = intersection_point(orthogonal_line(line_a, a), lin_c) + new_b = intersection_point(orthogonal_line(line_b, d), lin_c) + result.append((new_a, new_b)) + i += 2 + continue + result.append((a, b)) + i += 1 + return result + + +def link_segments(coords: List[tuple]) -> Polygon: + poly_coords = [] + for i in range(len(coords)): + a, b = coords[i] + c, d = coords[(i + 1) % len(coords)] + l1 = line_from_points(a, b) + l2 = line_from_points(c, d) + + if are_lines_parallel(l1, l2): + l1 = orthogonal_line(l1, b) + poly_coords.append(b) + + poly_coords.append(intersection_point(l1, l2)) + poly_coords.insert(0, poly_coords[-1]) + return Polygon(poly_coords) \ No newline at end of file diff --git a/src/maxarseg/samplers/samplers.py b/src/maxarseg/samplers/samplers.py new file mode 100644 index 0000000..9e5754e --- /dev/null +++ b/src/maxarseg/samplers/samplers.py @@ -0,0 +1,430 @@ +from torchgeo.samplers.utils import get_random_bounding_box, tile_to_chips +from torchgeo.samplers.single import RandomGeoSampler, GridGeoSampler +from torchgeo.datasets import GeoDataset, BoundingBox +from maxarseg.samplers.samplers_utils import path_2_tile_aoi, boundingBox_2_Polygon, boundingBox_2_centralPoint +from torchgeo.samplers.constants import Units +from typing import Optional, Union +from collections.abc import Iterator +import torch +import math +from maxarseg.geo_datasets import geoDatasets + +class SinglePatchSampler: + """ + To be used with SingleTileDataset + Sample a single patch from a dataset. + """ + def __init__(self, dataset, patch_size, stride): + self.dataset = dataset + assert patch_size > 0 + self.patch_size = patch_size #pxl + if dataset.transform[0] != -dataset.transform[4]: + raise ValueError("The pixel scale in x and y directions are different.") + self.patch_size_meters = patch_size * dataset.transform[0] + + assert stride > 0 + self.stride = stride #pxl + if self.stride is None: + self.stride = self.patch_size + self.stride_meters = self.stride * dataset.transform[0] + + def tile_to_chips(self) -> tuple[int, int]: + rows = math.ceil((self.dataset.height - self.patch_size) / self.stride) + 1 + cols = math.ceil((self.dataset.width - self.patch_size) / self.stride) + 1 + return rows, cols + + def __iter__(self): + rows, cols = self.tile_to_chips() + discarder_chips = 0 + for i in range(rows): + miny = self.dataset.bounds.bottom + i * self.stride_meters + maxy = miny + self.patch_size_meters + + # For each column... 
+            for j in range(cols):
+                minx = self.dataset.bounds.left + j * self.stride_meters
+                maxx = minx + self.patch_size_meters
+                selected_bbox = geoDatasets.noTBoundingBox(minx, maxx, miny, maxy)
+                selected_bbox_polygon = boundingBox_2_Polygon(selected_bbox)
+                if self.dataset.tile_aoi_gdf.intersects(selected_bbox_polygon).any():
+                    yield (minx, maxx, miny, maxy)
+                else:
+                    discarder_chips += 1
+                    continue
+        print('Discarded empty chips: ', discarder_chips)
+        print('Number of non-empty chips: ', len(self) - discarder_chips)
+
+    def __len__(self) -> int:
+        return self.tile_to_chips()[0]*self.tile_to_chips()[1]
+
+# Samplers for Base Datasets
+class MyRandomGeoSampler(RandomGeoSampler):
+    """
+    Sample a single random bounding box from a dataset (does NOT support batches).
+    Check that the random bounding box is inside the tile's polygon.
+    """
+    def __init__(
+        self,
+        dataset: GeoDataset,
+        size: Union[tuple[float, float], float],
+        length: Optional[int],
+        roi: Optional[BoundingBox] = None,
+        units: Units = Units.PIXELS,
+        verbose: bool = False
+    ) -> None:
+
+        super().__init__(dataset, size, length, roi, units)
+        self.verbose = verbose
+
+    def __iter__(self) -> Iterator[BoundingBox]:
+        """Return the index of a dataset.
+
+        Returns:
+            (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
+        """
+        i = 0
+        while i < len(self):
+            # Choose a random tile, weighted by area
+            idx = torch.multinomial(self.areas, 1)
+            hit = self.hits[idx]
+
+            tile_path = hit.object
+            tile_polyg = path_2_tile_aoi(tile_path)
+
+            bounds = BoundingBox(*hit.bounds) #TODO: shrink the bounds using the geojson bbox
+            # Choose a random index within that tile
+            bounding_box = get_random_bounding_box(bounds, self.size, self.res)
+            rnd_central_point = boundingBox_2_centralPoint(bounding_box)
+
+            # if the central point of the random bbox falls inside the tile's polygon (defined by the geojson)
+            if rnd_central_point.intersects(tile_polyg):
+                if self.verbose: #TODO: consider removing verbose in the future to speed things up
+                    print('In sampler')
+                    print('tile_polyg', tile_polyg)
+                    print()
+                i += 1
+                yield bounding_box
+            else:
+                continue
+
+class MyGridGeoSampler(GridGeoSampler):
+    """
+    Sample a single bounding box in a grid fashion from a dataset (does NOT support batches).
+    Check that the bounding box is inside the tile's polygon.
+    """
+    def __init__(self,
+                 dataset: GeoDataset,
+                 size: Union[tuple[float, float], float],
+                 stride: Union[tuple[float, float], float],
+                 roi: Optional[BoundingBox] = None,
+                 units: Units = Units.PIXELS,
+                 ) -> None:
+
+        super().__init__(dataset, size, stride, roi, units)
+
+    def __iter__(self) -> Iterator[BoundingBox]:
+        """Return the index of a dataset.
+
+        Returns:
+            (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
+        """
+        # For each tile...
+        for hit in self.hits: #These hits are all the tiles that intersect the roi (region of interest). If roi not specified then hits = all the tiles
+            tile_path = hit.object
+            tile_polygon = path_2_tile_aoi(tile_path)
+
+            bounds = BoundingBox(*hit.bounds)
+            rows, cols = tile_to_chips(bounds, self.size, self.stride)
+            mint = bounds.mint
+            maxt = bounds.maxt
+
+            # For each row...
+            for i in range(rows):
+                miny = bounds.miny + i * self.stride[0]
+                maxy = miny + self.size[0]
+
+                # For each column...
+ for j in range(cols): + minx = bounds.minx + j * self.stride[1] + maxx = minx + self.size[1] + selected_bbox = BoundingBox(minx, maxx, miny, maxy, mint, maxt) + selected_bbox_polygon = boundingBox_2_Polygon(selected_bbox) + if selected_bbox_polygon.intersects(tile_polygon): + #print("selected_bbox_polygon", selected_bbox_polygon) + yield selected_bbox + else: + continue + +class WholeTifGridGeoSampler(GridGeoSampler): + """ + Sample a batch of bounding boxes from a dataset. + Returns all possible patches even if they are empty. + """ + def __init__(self, + dataset: GeoDataset, + size: Union[tuple[float, float], float], + batch_size: int, + stride: Union[tuple[float, float], float], + roi: Optional[BoundingBox] = None, + units: Units = Units.PIXELS, + ) -> None: + + super().__init__(dataset, size, stride, roi, units) + self.batch_size = batch_size + + def __iter__(self) -> Iterator[BoundingBox]: + """Return the index of a dataset. + + Returns: + (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset + """ + batch = [] + # For each tile... + for k, hit in enumerate(self.hits): #These hits are all the tiles that intersect the roi (region of interest). If roi not specified then hits = all the tiles + tile_path = hit.object + tile_polygon = path_2_tile_aoi(tile_path) + + bounds = BoundingBox(*hit.bounds) + rows, cols = tile_to_chips(bounds, self.size, self.stride) + mint = bounds.mint + maxt = bounds.maxt + + empty_chips = 0 + valid_chips = 0 + + # For each row... + for i in range(rows): + miny = bounds.miny + i * self.stride[0] + maxy = miny + self.size[0] + + # For each column... + for j in range(cols): + minx = bounds.minx + j * self.stride[1] + maxx = minx + self.size[1] + selected_bbox = BoundingBox(minx, maxx, miny, maxy, mint, maxt) + batch.append(selected_bbox) + + #Check if the selected_bbox intersects the tile_polygon (to avoid all black images) + selected_bbox_polygon = boundingBox_2_Polygon(selected_bbox) + + if selected_bbox_polygon.intersects(tile_polygon): + valid_chips += 1 + else: + empty_chips += 1 + + is_last_batch = k == len(self.hits) - 1 and i == rows - 1 and j == cols - 1 + + if len(batch) == self.batch_size or is_last_batch: + if is_last_batch and len(batch) < self.batch_size: + #print('Last batch not full. Only', len(batch), 'chips') + #pad the last batch with the last selected_bbox if it is not full + batch.extend([selected_bbox] * (self.batch_size - len(batch))) + + yield batch + batch = [] + print('Valid patches: ', valid_chips) + print('Empty patches: ', empty_chips) + + def __len__(self) -> int: + """Return the number of batches in a single epoch. + + Returns: + number of batches in an epoch + """ + return math.ceil(self.length / self.batch_size) + + + def get_num_rows_cols(self): + hit = self.hits[0] #get the first and only tile + bounds = BoundingBox(*hit.bounds) #get its bounds + return tile_to_chips(bounds, self.size, self.stride) + +class BatchGridGeoSampler(GridGeoSampler): + """ + Sample a batch of bounding boxes from a dataset in a grid fashion. + Check if the bounding box is inside the tile's polygon. + Discard empty patches. + This should be used with a dataset with only ONE tile. 
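+
+    Example (sketch; the dataset construction is illustrative, sizes are in pixels):
+        ds = geoDatasets.MxrSingleTile('path/to/tile.tif')
+        sampler = BatchGridGeoSampler(ds, size=1024, batch_size=4, stride=768)
+        for batch_of_bboxes in sampler:
+            patches = [ds[bbox] for bbox in batch_of_bboxes]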
+    """
+    def __init__(self,
+                 dataset: GeoDataset,
+                 size: Union[tuple[float, float], float],
+                 batch_size: int,
+                 stride: Union[tuple[float, float], float],
+                 roi: Optional[BoundingBox] = None,
+                 units: Units = Units.PIXELS,
+                 ) -> None:
+
+        super().__init__(dataset, size, stride, roi, units)
+        self.batch_size = batch_size
+        self.dataset = dataset
+
+    def __iter__(self) -> Iterator[BoundingBox]:
+        """Return the index of a dataset.
+
+        Returns:
+            (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
+        """
+        batch = []
+        # For each tile...
+        for k, hit in enumerate(self.hits): #These hits are all the tiles that intersect the roi (region of interest). If roi not specified then hits = all the tiles
+
+            bounds = BoundingBox(*hit.bounds)
+            rows, cols = tile_to_chips(bounds, self.size, self.stride)
+            mint = bounds.mint
+            maxt = bounds.maxt
+
+            discarded_chips = 0
+            # For each row...
+            for i in range(rows):
+                miny = bounds.miny + i * self.stride[0]
+                maxy = miny + self.size[0]
+
+                # For each column...
+                for j in range(cols):
+                    minx = bounds.minx + j * self.stride[1]
+                    maxx = minx + self.size[1]
+                    selected_bbox = BoundingBox(minx, maxx, miny, maxy, mint, maxt)
+                    #Check if the selected_bbox intersects the tile_polygon (to avoid all black images)
+                    selected_bbox_polygon = boundingBox_2_Polygon(selected_bbox)
+
+                    #TODO: here we could potentially discard all the patches that contain no buildings or roads
+                    #here tile_aoi must be in proj crs
+                    if self.dataset.tile_aoi_gdf.intersects(selected_bbox_polygon).any():
+                        batch.append(selected_bbox)
+                    else:
+                        discarded_chips += 1
+                        continue
+
+                    is_last_batch = k == len(self.hits) - 1 and i == rows - 1 and j == cols - 1
+
+                    if len(batch) == self.batch_size or is_last_batch:
+                        if is_last_batch and len(batch) < self.batch_size:
+                            #pad the last batch with the last selected_bbox if it is not full
+                            batch.extend([selected_bbox] * (self.batch_size - len(batch)))
+
+                        yield batch
+                        batch = []
+        print('Discarded empty chips: ', discarded_chips)
+        print('Actual number of batches: ', len(self) - discarded_chips / self.batch_size)
+
+    def __len__(self) -> int:
+        """Return the number of batches in a single epoch.
+
+        Returns:
+            number of batches in an epoch
+        """
+        return math.ceil(self.length / self.batch_size)
+
+
+# Samplers for Intersection Datasets
+class MyIntersectionRandomGeoSampler(RandomGeoSampler):
+    def __init__(
+        self,
+        dataset: GeoDataset,
+        size: Union[tuple[float, float], float],
+        length: Optional[int],
+        roi: Optional[BoundingBox] = None,
+        units: Units = Units.PIXELS,
+    ) -> None:
+
+        super().__init__(dataset, size, length, roi, units)
+
+    def __iter__(self) -> Iterator[BoundingBox]:
+        """Return the index of a dataset.
+
+        Returns:
+            (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
+        """
+        i = 0
+        while i < len(self):
+            # Choose a random tile, weighted by area
+            idx = torch.multinomial(self.areas, 1)
+            hit = self.hits[idx]
+
+            tile_path1 = hit.object[0]
+            tile_path2 = hit.object[1]
+
+            tile_polyg1 = path_2_tile_aoi(tile_path1)
+            tile_polyg2 = path_2_tile_aoi(tile_path2)
+
+            bounds = BoundingBox(*hit.bounds) #TODO: shrink the bounds using the geojson bbox
+            # Choose a random index within that tile
+            bounding_box = get_random_bounding_box(bounds, self.size, self.res)
+            rnd_bbox_polyg = boundingBox_2_Polygon(bounding_box)
+            rnd_central_point = boundingBox_2_centralPoint(bounding_box)
+
+            # if the center of the bounding_box falls inside the polygon of tile1 and in that of tile2
+            # (both computed from the geojson), then the bounding_box is valid
+            if rnd_central_point.intersects(tile_polyg1) and rnd_central_point.intersects(tile_polyg2):
+                print('In sampler')
+                print('tile_polyg1', tile_polyg1)
+                print('tile_polyg2', tile_polyg2)
+                print()
+                i += 1
+                yield bounding_box
+
+            else:
+                continue
+
+
+class MyIntersectionGridGeoSampler(GridGeoSampler):
+    def __init__(
+        self,
+        dataset: GeoDataset,
+        size: Union[tuple[float, float], float],
+        stride: Union[tuple[float, float], float],
+        roi: Optional[BoundingBox] = None,
+        units: Units = Units.PIXELS,
+    ) -> None:
+
+        super().__init__(dataset, size, stride, roi, units)
+
+    def __iter__(self) -> Iterator[BoundingBox]:
+        """Return the index of a dataset.
+
+        Returns:
+            (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
+        """
+        # For each tile...
+        for hit in self.hits:
+            path_tile_1 = hit.object[0]
+            path_tile_2 = hit.object[1]
+            polyg_tile_1 = path_2_tile_aoi(path_tile_1)
+            polyg_tile_2 = path_2_tile_aoi(path_tile_2)
+
+            print('In sampler')
+            print('tile_polygon 1: ', polyg_tile_1)
+            print('tile_polygon 2: ', polyg_tile_2)
+
+            bounds = BoundingBox(*hit.bounds)
+            rows, cols = tile_to_chips(bounds, self.size, self.stride)
+            mint = bounds.mint
+            maxt = bounds.maxt
+
+            # For each row...
+            for i in range(rows):
+                miny = bounds.miny + i * self.stride[0]
+                maxy = miny + self.size[0]
+
+                # For each column...
+ for j in range(cols): + minx = bounds.minx + j * self.stride[1] + maxx = minx + self.size[1] + selected_bbox = BoundingBox(minx, maxx, miny, maxy, mint, maxt) + selected_bbox_polygon = boundingBox_2_Polygon(selected_bbox) + if selected_bbox_polygon.intersects(polyg_tile_1) and selected_bbox_polygon.intersects(polyg_tile_2): + print("selected_bbox_polygon", selected_bbox_polygon) + yield selected_bbox + else: + continue \ No newline at end of file diff --git a/src/maxarseg/samplers/samplers_utils.py b/src/maxarseg/samplers/samplers_utils.py new file mode 100644 index 0000000..29373a0 --- /dev/null +++ b/src/maxarseg/samplers/samplers_utils.py @@ -0,0 +1,256 @@ +import os +import json +import shapely +from typing import List +from pathlib import Path +import numpy as np +import geopandas as gpd +from maxarseg.geo_datasets import geoDatasets +from shapely.geometry.polygon import Polygon +from shapely import geometry +import rasterio +import pandas as pd +import glob +import scipy + +#from maxarseg import segment + +def path_2_tile_aoi(tile_path, root = './metadata/from_github_maxar_metadata/datasets' ): + """ + Create a shapely Polygon from a tile_path + Example of a tile_path: '../Gambia-flooding-8-11-2022/pre/10300100CFC9A500/033133031213.tif' + """ + if isinstance(tile_path, str): + event = tile_path.split('/')[-4] + child = tile_path.split('/')[-2] + tile = tile_path.split('/')[-1].replace(".tif", "") + elif isinstance(tile_path, Path): + event = tile_path.parts[-4] + child = tile_path.parts[-2] + tile = tile_path.parts[-1].replace(".tif", "") + else: + raise TypeError("tile_path must be a string or a Path object") + + try: + path_2_child_geojson = os.path.join(root, event, child +'.geojson') + with open(path_2_child_geojson, 'r') as f: + child_geojson = json.load(f) + except: + file_pattern = str(os.path.join(root, event, child + '*inv.geojson')) + file_list = glob.glob(f"{file_pattern}") + assert len(file_list) == 1, f"Found {len(file_list)} files with pattern {file_pattern}. Expected 1 file." + path_2_child_geojson = file_list[0] + with open(path_2_child_geojson, 'r') as f: + child_geojson = json.load(f) + + + j = [el["properties"]["proj:geometry"] for el in child_geojson['features'] if el['properties']['quadkey'] == tile][0] + tile_polyg = shapely.geometry.shape(j) + return tile_polyg + +def path_2_tile_aoi_no_water(tile_path, land_gdf = None, root = './metadata/from_github_maxar_metadata/datasets' ): + """ + Create a shapely Polygon from a tile_path + Example of a tile_path: '../Gambia-flooding-8-11-2022/pre/10300100CFC9A500/033133031213.tif' + """ + if isinstance(tile_path, str): + event = tile_path.split('/')[-4] + child = tile_path.split('/')[-2] + tile = tile_path.split('/')[-1].replace(".tif", "") + elif isinstance(tile_path, Path): + event = tile_path.parts[-4] + child = tile_path.parts[-2] + tile = tile_path.parts[-1].replace(".tif", "") + else: + raise TypeError("tile_path must be a string or a Path object") + + try: + path_2_child_geojson = os.path.join(root, event, child +'.geojson') + with open(path_2_child_geojson, 'r') as f: + child_geojson = json.load(f) + except: + file_pattern = str(os.path.join(root, event, child + '*inv.geojson')) + file_list = glob.glob(f"{file_pattern}") + assert len(file_list) == 1, f"Found {len(file_list)} files with pattern {file_pattern}. Expected 1 file." 
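+        # Editorial note: the glob above covers events whose per-mosaic metadata
+        # file is named '<child>*inv.geojson' rather than '<child>.geojson';
+        # exactly one match is required.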
+        path_2_child_geojson = file_list[0]
+        with open(path_2_child_geojson, 'r') as f:
+            child_geojson = json.load(f)
+
+    prj_crs = [el['properties']['proj:epsg'] for el in child_geojson['features'] if el['properties']['quadkey'] == tile][0]
+    j = [el["geometry"] for el in child_geojson['features'] if el['properties']['quadkey'] == tile][0]
+    tile_polyg = shapely.geometry.shape(j)
+
+    tile_adj_aois = []
+    if land_gdf is None:  # case in which the whole event does not intersect the water-land boundary
+        tile_adj_aois.append(tile_polyg)
+    else:
+        intersection_gdf = land_gdf.intersection(tile_polyg).loc[lambda x: ~x.is_empty]
+        if len(intersection_gdf) == 0:
+            print('Tile does not intersect land. Only sea. Empty mask')
+            tile_adj_aois.append(Polygon())
+        else:
+            if land_gdf.contains(tile_polyg).any():
+                print('Completely contained in land. No mod to tile_aoi')
+                tile_adj_aois.append(tile_polyg)
+            else:
+                print('Tile intersects the water-land boundary')
+                for geom in intersection_gdf:
+                    tile_adj_aois.append(geom)
+
+    return gpd.GeoDataFrame(geometry = tile_adj_aois, crs="EPSG:4326").to_crs(prj_crs)
+
+def boundingBox_2_Polygon(bounding_box):
+    """
+    Create a shapely Polygon from a BoundingBox
+    """
+    minx, miny, maxx, maxy = bounding_box.minx, bounding_box.miny, bounding_box.maxx, bounding_box.maxy
+    vertices = [(minx, miny), (maxx, miny), (maxx, maxy), (minx, maxy), (minx, miny)]
+    bbox_polyg = shapely.geometry.Polygon(vertices)
+    return bbox_polyg
+
+def xyxy_2_Polygon(xyxy_box):
+    """
+    Create a shapely Polygon from a xyxy box
+    """
+    if not len(xyxy_box) == 4:  # allow for a tuple of 2 points, e.g. ((minx, miny), (maxx, maxy))
+        minx, miny = xyxy_box[0]
+        maxx, maxy = xyxy_box[1]
+    else:
+        minx, miny, maxx, maxy = xyxy_box
+    vertices = [(minx, miny), (maxx, miny), (maxx, maxy), (minx, maxy), (minx, miny)]
+    return shapely.geometry.Polygon(vertices)
+
+def xyxyBox2Polygon(xyxy_box):
+    """
+    Create a shapely Polygon from a xyxy box
+    """
+    minx, miny, maxx, maxy = xyxy_box
+    vertices = [(minx, miny), (maxx, miny), (maxx, maxy), (minx, maxy), (minx, miny)]
+    bbox_polyg = shapely.geometry.Polygon(vertices)
+    return bbox_polyg
+
+def boundingBox_2_centralPoint(bounding_box):
+    """
+    Create a shapely Point from a BoundingBox
+    """
+    minx, miny, maxx, maxy = bounding_box.minx, bounding_box.miny, bounding_box.maxx, bounding_box.maxy
+    return shapely.geometry.Point((minx + maxx)/2, (miny + maxy)/2)
+
+def align_bbox(bbox: Polygon):
+    """
+    Turn the polygon into an axis-aligned bbox
+    """
+    minx, miny, maxx, maxy = bbox.bounds
+    return minx, miny, maxx, maxy
+
+def rel_bbox_coords(geodf: gpd.GeoDataFrame,
+                    ref_coords: tuple,
+                    res,
+                    ext_mt = None):
+    """
+    Returns the relative coordinates of a bbox w.r.t. a reference bbox in the 'geometry' column.
+    Goes from absolute geo coords to relative coords in the image.
+
+    Inputs:
+        geodf: dataframe with bboxes
+        ref_coords: a tuple in the format (minx, miny, maxx, maxy)
+        res: resolution of the image
+        ext_mt: meters to add to each edge of the box (the center remains fixed)
+    Returns:
+        a list of tuples with the relative coordinates of the bboxes [(minx, miny, maxx, maxy), ...]
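+    Note: each side is moved by ext_mt/2, so the box grows by ext_mt in total along each axis.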
+    """
+    result = []
+    ref_minx, ref_maxy = ref_coords[0], ref_coords[3]  # coords of the top left corner of the patch sample extracted from the tile
+    #print('\nref_coords top left: ', ref_minx, ref_maxy)
+    for geom in geodf.geometry:
+        minx, miny, maxx, maxy = align_bbox(geom)
+        if ext_mt is not None and ext_mt != 0:  # was 'or', which is always true and crashes when ext_mt is None
+            minx -= (ext_mt / 2)
+            miny -= (ext_mt / 2)
+            maxx += (ext_mt / 2)
+            maxy += (ext_mt / 2)
+
+        rel_bbox_coords = list(np.array([minx - ref_minx, ref_maxy - maxy, maxx - ref_minx, ref_maxy - miny]) / res)
+        result.append(rel_bbox_coords)
+
+    return result
+
+def tile_sizes(dataset: geoDatasets.MxrSingleTile):
+    """
+    Returns the sizes (in pixels) of the tile.
+    It uses the bounds and the resolution of the dataset.
+    """
+    bounds = dataset.bounds
+    x_size_pxl = (bounds.maxy - bounds.miny) / dataset.res
+    y_size_pxl = (bounds.maxx - bounds.minx) / dataset.res
+
+    if x_size_pxl % 1 != 0 or y_size_pxl % 1 != 0:
+        raise ValueError("The sizes of the tile are not integers")
+
+    return (int(x_size_pxl), int(y_size_pxl))
+
+def tile_path_2_tile_size(tile_path):
+    """
+    Returns the sizes of the tile given the path
+    """
+    with rasterio.open(tile_path) as src:
+        return src.width, src.height
+
+def double_tuple_box_2_shapely_box(double_tuple_box):
+    """
+    Create a shapely Polygon from a double tuple box
+    """
+    minx, miny = double_tuple_box[0]
+    maxx, maxy = double_tuple_box[1]
+    return geometry.box(minx, miny, maxx, maxy)
+
+def road_gdf_vs_aois_gdf(proj_road_gdf, aois_gdf):
+    # Could be useful but not used
+    num_roads = len(proj_road_gdf)
+    num_hits = np.array([0]*num_roads)
+    in_aoi_roads_gdf = gpd.GeoSeries()
+    for geom in aois_gdf.geometry:
+        intersec_geom = proj_road_gdf.intersection(geom)
+        valid_gdf = intersec_geom[~intersec_geom.is_empty]
+        num_hits = num_hits + (~intersec_geom.is_empty.values)
+        in_aoi_roads_gdf = gpd.GeoSeries(pd.concat([valid_gdf, in_aoi_roads_gdf], ignore_index=True))
+
+    if any(num_hits > 1):
+        raise NotImplementedError("Error: case in which a road is located in more than one area of interest. Not implemented.")
+    else:
+        return in_aoi_roads_gdf
+
+def filter_road_gdf_vs_aois_gdf(proj_road_gdf, aois_gdf):
+    num_roads = len(proj_road_gdf)
+    num_hits = np.array([0]*num_roads)
+    for geom in aois_gdf.geometry:
+        hits = proj_road_gdf.intersects(geom)
+        num_hits = num_hits + hits.values
+    return proj_road_gdf[num_hits >= 1]
+
+def intersection_road_gdf_vs_aois_gdf(proj_road_gdf, aois_gdf):
+    intersected_roads = gpd.GeoSeries()
+    num_roads = len(proj_road_gdf)
+    for geom in aois_gdf.geometry:
+        intersec_geom = proj_road_gdf.intersection(geom)
+        valid_gdf = intersec_geom[~intersec_geom.is_empty]
+        intersected_roads = gpd.GeoSeries(pd.concat([valid_gdf, intersected_roads], ignore_index=True))
+
+    return intersected_roads
+
+def entropy_from_lbl(lbl):
+    # class frequencies over the four label values: road (0), tree (1), building (2), background (255)
+    flat_array = lbl.flatten()
+    class_imp = []
+    for i in [0, 1, 2, 255]:
+        class_imp.append(np.sum(flat_array == i))
+    return scipy.stats.entropy(class_imp, base = 2)
+
+def compute_entropy_matrix(img, size = 1024):
+    # per-patch label entropy over a grid of size x size patches
+    if len(img.shape) == 3:
+        img = img.squeeze()
+    entropy_matrix = np.zeros((int(img.shape[0]/size), int(img.shape[1]/size)))
+    for i in range(int(img.shape[0]/size)):
+        for j in range(int(img.shape[1]/size)):
+            patch = img[i*size:(i+1)*size, j*size:(j+1)*size]
+            entropy_matrix[i, j] = entropy_from_lbl(patch)
+    return entropy_matrix
\ No newline at end of file
diff --git a/src/maxarseg/scripts/count_build_single_event.py b/src/maxarseg/scripts/count_build_single_event.py
new file mode 100644
index 0000000..5c78002
--- /dev/null
+++ b/src/maxarseg/scripts/count_build_single_event.py
@@ -0,0 +1,180 @@
+#%%
+from maxarseg.assemble import holders
+import os
+from pathlib import Path
+from maxarseg.assemble import delimiters, names
+import geopandas as gpd
+import numpy as np
+from maxarseg.samplers import samplers_utils
+import time
+import pandas as pd
+import sys
+import rasterio
+
+def list_directories(path):
+    return [d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]
+
+def list_tif_files(path):
+    return [f for f in os.listdir(path) if f.endswith('.tif') and os.path.isfile(os.path.join(path, f))]
+
+def list_parquet_files(path):
+    return [f for f in os.listdir(path) if f.endswith('.parquet') and os.path.isfile(os.path.join(path, f))]
+
+def filter_gdf_vs_aois_gdf(proj_gdf, aois_gdf):
+    num_hits = np.array([0]*len(proj_gdf))
+    for geom in aois_gdf.geometry:
+        hits = proj_gdf.intersects(geom)
+        num_hits = num_hits + hits.values
+    return proj_gdf[num_hits >= 1]
+
+class Event_light:
+    def __init__(self,
+                 name,
+                 maxar_root = '/nfs/projects/overwatch/maxar-segmentation/maxar-open-data',
+                 maxar_metadata_path = './metadata/from_github_maxar_metadata/datasets',
+                 region = 'infer'):
+
+        #Paths
+        self.maxar_root = Path(maxar_root)
+        self.buildings_ds_links_path = Path('./metadata/buildings_dataset_links.csv')
+        self.maxar_metadata_path = Path(maxar_metadata_path)
+
+        #Event
+        self.name = name
+        self.when = 'pre'
+        self.region_name = names.get_region_name(self.name) if region == 'infer' else region
+        self.bbox = delimiters.get_event_bbox(self.name, extra_mt=1000)
+        self.all_mosaics_names = names.get_mosaics_names(self.name, self.maxar_root, self.when)
+
+        self.wlb_gdf = gpd.read_file('./metadata/eventi_confini_complete.gpkg')
+        self.filtered_wlb_gdf = self.wlb_gdf[self.wlb_gdf['event names'] == self.name]
+        if self.filtered_wlb_gdf.iloc[0].geometry is None:
+            print('Event entirely on land')
+            self.cross_wlb = False
+            self.filtered_wlb_gdf = None
+        else:
+            print('Event on the water-land boundary')
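+            # cross_wlb flags events whose footprint crosses the water-land boundary,
+            # so tile AOIs must later be clipped to land via path_2_tile_aoi_no_water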
+            self.cross_wlb = True
+
+        print(f'Creating event: {self.name}\nRegion: {self.region_name}\nMosaics: {self.all_mosaics_names}')
+        #Roads
+        self.road_gdf = None
+
+        #Mosaics
+        self.mosaics = {}
+
+        #Init mosaics
+        for m_name in self.all_mosaics_names:
+            self.mosaics[m_name] = holders.Mosaic(m_name, self)
+
+        self.total_tiles = sum([mosaic.tiles_num for mosaic in self.mosaics.values()])
+
+    def __str__(self) -> str:
+        res = f'\n_______________________________________________________\nEvent: {self.name}\nMosaics: {self.all_mosaics_names}\nTotal tiles: {self.total_tiles}\n_______________________________________________________\n'
+        return res
+
+#%%
+def main():
+    print('Starting...', flush=True)
+    event_idx = sys.argv[1]
+    lbl_root_folder = '/nfs/projects/overwatch/maxar-segmentation/outputs/04_05/train'  # label folder
+    ev_name = names.get_all_events()[int(event_idx)]
+    cols = {'event': [],
+            'mosaic': [],
+            'tile': [],
+            'num_ms_build_aoi': [],
+            'num_ms_build_aoi_no_water_for': [],
+            'num_ms_build_aoi_no_water_sjoin': [],
+            'parquet_build': [],
+            'parquet_tree': [],
+            'bg_pxl': [],
+            'road_pxl': [],
+            'tree_pxl': [],
+            'build_pxl': [],
+            'entropy': []
+            }
+    high_res_entropies = []
+
+    mos_names = list_directories(os.path.join(lbl_root_folder, ev_name, 'pre'))
+    mos_names = sorted(mos_names)
+    if len(mos_names) > 0:
+        event = Event_light(ev_name)  # this event holds all the mosaics present in the label folder
+        print(event, flush=True)
+
+        current_mos = 0
+        for mos_name in mos_names:  # only mosaics with labels
+            current_mos += 1
+            print(f'\n{ev_name}/{mos_name}', flush=True)
+            mos = event.mosaics[mos_name]
+            try:
+                mos.set_build_gdf()
+            except Exception:
+                print(f'No buildings in {ev_name}/{mos_name}', file=sys.stderr)
+                continue
+            print(f'len build gdf {len(mos.build_gdf):,}', flush=True)
+            tif_names = list_tif_files(os.path.join(lbl_root_folder, ev_name, 'pre', mos_name))
+            tif_names = sorted(tif_names)
+            parquets_names = list_parquet_files(os.path.join(lbl_root_folder, ev_name, 'pre', mos_name))
+            if len(tif_names) != len(parquets_names):
+                print(f'{ev_name}/{mos_name}. Mismatch between tifs and parquets: lists are not of equal length.', file=sys.stderr)
+
+            print('proc_tifs and parquet:', len(tif_names), flush=True)
+            print()
+            current_tif = 0
+            for tile_name in tif_names:
+                current_tif += 1
+                print(f'{ev_name}/pre/{mos_name}/{tile_name}, tif:({current_tif}/{len(tif_names)}), mos:({current_mos}/{len(mos_names)})', flush=True)
+                cols['event'].append(ev_name)
+                cols['mosaic'].append(mos_name)
+                cols['tile'].append(tile_name)
+
+                tile_path = os.path.join(lbl_root_folder, ev_name, 'pre', mos_name, tile_name)
+                parquet_path = tile_path[:-4] + '.parquet'
+                #tile_aoi = gpd.GeoDataFrame({'geometry': [samplers_utils.path_2_tile_aoi(tile_path)]})
+
+                num_aoi_build = len(mos.proj_build_gdf.iloc[mos.sindex_proj_build_gdf.query(samplers_utils.path_2_tile_aoi(tile_path))])
+                cols['num_ms_build_aoi'].append(num_aoi_build)
+                #print('tile_aoi builds', num_aoi_build)
+
+                tile_aoi_no_water = samplers_utils.path_2_tile_aoi_no_water(tile_path, event.filtered_wlb_gdf)
+                num_ms_build_aoi_no_water_for = len(filter_gdf_vs_aois_gdf(mos.proj_build_gdf, tile_aoi_no_water))
+                cols['num_ms_build_aoi_no_water_for'].append(num_ms_build_aoi_no_water_for)
+                #print('tile_aoi builds no water filter with for', num_ms_build_aoi_no_water_for)
+
+                num_ms_build_aoi_no_water_sjoin = len(gpd.sjoin(mos.proj_build_gdf, tile_aoi_no_water, how='inner', op='intersects'))
+                cols['num_ms_build_aoi_no_water_sjoin'].append(num_ms_build_aoi_no_water_sjoin)
+                #print('tile_aoi builds no water filter with sjoin', num_ms_build_aoi_no_water_sjoin)
+
+                if os.path.exists(parquet_path):
+                    parquet_df = pd.read_parquet(parquet_path, engine='pyarrow')  # read once instead of twice
+                    parquet_build = sum(parquet_df.class_id == 2)
+                    parquet_trees = sum(parquet_df.class_id == 1)
+                else:
+                    print(f'{parquet_path} does not exist', file=sys.stderr)
+                    parquet_build = -1
+                    parquet_trees = -1
+                cols['parquet_build'].append(parquet_build)
+                cols['parquet_tree'].append(parquet_trees)
+
+                with rasterio.open(tile_path) as src:  # read the label raster
+                    lbl = src.read()
+                tot_pxl = lbl.shape[-1]**2  # assume every image is square
+                cols['bg_pxl'].append(np.sum(lbl == 255)/tot_pxl)
+                cols['road_pxl'].append(np.sum(lbl == 0)/tot_pxl)
+                cols['tree_pxl'].append(np.sum(lbl == 1)/tot_pxl)
+                cols['build_pxl'].append(np.sum(lbl == 2)/tot_pxl)
+                cols['entropy'].append(samplers_utils.entropy_from_lbl(lbl))
+
+                high_res_entropies.append(samplers_utils.compute_entropy_matrix(lbl))
+
+    high_res_entropies_np = np.stack(high_res_entropies, axis=0)
+    np.save(f'{ev_name}_high_res_entropies.npy', high_res_entropies_np)
+    print(f'Saved {ev_name}_high_res_entropies.npy', flush=True)
+    res_df = pd.DataFrame(cols)
+    res_df.to_csv(f'{ev_name}_lbl_stats.csv', index = True)
+    print(f'Saved {ev_name}_lbl_stats.csv', flush=True)
+
+if __name__ == "__main__":
+    main()
diff --git a/src/maxarseg/scripts/count_buildings_and_roads.py b/src/maxarseg/scripts/count_buildings_and_roads.py
new file mode 100644
index 0000000..e6bff3d
--- /dev/null
+++ b/src/maxarseg/scripts/count_buildings_and_roads.py
@@ -0,0 +1,43 @@
+from maxarseg import build
+from tqdm import tqdm
+import pandas as pd
+
+def main():
+    df = pd.read_csv('./output/stats_roadnBuild.csv')
+    print(df)
+    tot_buildings = 0
+    tot_roads = 0
+    config = build.SegmentConfig(batch_size = 4, device='cpu')
+    for event_name in tqdm(build.get_all_events()):
+        if event_name in df['event_name'].values:
+            continue
+        try:
+            event_build_num = 0
+            event_road_num = 0
+            print('\n', event_name)
+            evento = build.Event(event_name, seg_config = config, when='pre')
+            evento.set_all_mos_road_gdf()
+            evento.set_build_gdf_all_mos()
+            for _, mos in evento.mosaics.items():
+                event_build_num += mos.build_num
+                event_road_num += mos.road_num
+
+            new_row = pd.DataFrame({'event_name': [event_name], 'num_road': [event_road_num], 'num_build': [event_build_num]})
+            df = pd.concat([df, new_row], ignore_index=True)
+            tot_buildings += event_build_num
+            tot_roads += event_road_num
+            print(f'Event: {event_name}, Total buildings: {event_build_num}, Total roads: {event_road_num}')
+        except Exception as e:
+            print(f'Error in {event_name}')
+            df.to_csv('./output/stats_roadnBuild.csv', index=False)
+            print(f"Caught an exception: {e}")
+            return
+
+    print('Total buildings: ', tot_buildings)
+    print('Total roads: ', tot_roads)
+    df.to_csv('./output/stats_roadnBuild.csv', index=False)
+    return
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/src/maxarseg/scripts/count_builds.py b/src/maxarseg/scripts/count_builds.py
new file mode 100644
index 0000000..c2d8a04
--- /dev/null
+++ b/src/maxarseg/scripts/count_builds.py
@@ -0,0 +1,186 @@
+#%%
+from maxarseg.assemble import holders
+import os
+from pathlib import Path
+from maxarseg.assemble import delimiters, names
+import geopandas as gpd
+import numpy as np
+from maxarseg.samplers import samplers_utils
+import time
+import pandas as pd
+import sys
+import rasterio
+
+def list_directories(path):
+    return [d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]
+
+def list_tif_files(path):
+    return [f for f in os.listdir(path) if f.endswith('.tif') and os.path.isfile(os.path.join(path, f))]
+
+def list_parquet_files(path):
+    return [f for f in os.listdir(path) if f.endswith('.parquet') and os.path.isfile(os.path.join(path, f))]
+
+def filter_gdf_vs_aois_gdf(proj_gdf, aois_gdf):
+    num_hits = np.array([0]*len(proj_gdf))
+    for geom in aois_gdf.geometry:
+        hits = proj_gdf.intersects(geom)
+        num_hits = num_hits + hits.values
+    return proj_gdf[num_hits >= 1]
+
+class Event_light:
+    def __init__(self,
+                 name,
+                 maxar_root = '/nfs/projects/overwatch/maxar-segmentation/maxar-open-data',
+                 maxar_metadata_path = './metadata/from_github_maxar_metadata/datasets',
+                 region = 'infer'):
+
+        #Paths
+        self.maxar_root = Path(maxar_root)
+        self.buildings_ds_links_path = Path('./metadata/buildings_dataset_links.csv')
+        self.maxar_metadata_path = Path(maxar_metadata_path)
+
+        #Event
+        self.name = name
+        self.when = 'pre'
+        self.region_name = names.get_region_name(self.name) if region == 'infer' else region
+        self.bbox = delimiters.get_event_bbox(self.name, extra_mt=1000)
+        self.all_mosaics_names = names.get_mosaics_names(self.name, self.maxar_root, self.when)
+
+        self.wlb_gdf = gpd.read_file('./metadata/eventi_confini_complete.gpkg')
+        self.filtered_wlb_gdf = self.wlb_gdf[self.wlb_gdf['event names'] == self.name]
+        if self.filtered_wlb_gdf.iloc[0].geometry is None:
+            print('Event entirely on land')
+            self.cross_wlb = False
+            self.filtered_wlb_gdf = None
+        else:
+            print('Event on the water-land boundary')
+            self.cross_wlb = True
+
+        print(f'Creating event: {self.name}\nRegion: {self.region_name}\nMosaics: {self.all_mosaics_names}')
+        #Roads
+        self.road_gdf = None
+
+        #Mosaics
+        self.mosaics = {}
+
+        #Init mosaics
+        for m_name in self.all_mosaics_names:
+            self.mosaics[m_name] = holders.Mosaic(m_name, self)
+
+        self.total_tiles = sum([mosaic.tiles_num for mosaic in self.mosaics.values()])
+
+    def __str__(self) -> str:
+        res = f'\n_______________________________________________________\nEvent: {self.name}\nMosaics: {self.all_mosaics_names}\nTotal tiles: {self.total_tiles}\n_______________________________________________________\n'
+        return res
+
+#%%
+def main():
+    print('Starting...', flush=True)
+    lbl_root_folder = '/nfs/projects/overwatch/maxar-segmentation/outputs/04_05/train'  # label folder
+    ev_names = list_directories(lbl_root_folder)
+    ev_names = sorted(ev_names)
+    cols = {'event': [],
+            'mosaic': [],
+            'tile': [],
+            'num_ms_build_aoi': [],
+            'num_ms_build_aoi_no_water_for': [],
+            'num_ms_build_aoi_no_water_sjoin': [],
+            'parquet_build': [],
+            'parquet_tree': [],
+            'bg_pxl': [],
+            'road_pxl': [],
+            'tree_pxl': [],
+            'build_pxl': [],
+            'entropy': []
+            }
+    high_res_entropies = []
+    len_ev = len(ev_names)
+    current_ev = 0
+    for ev_name in ev_names:
+        if ev_name == 'Morocco-Earthquake-Sept-2023' or ev_name == 'Hurricane-Ian-9-26-2022':
+            continue
+        current_ev += 1
+        mos_names = list_directories(os.path.join(lbl_root_folder, ev_name, 'pre'))
+        mos_names = sorted(mos_names)
+        if len(mos_names) > 0:
+            event = Event_light(ev_name)  # this event holds all the mosaics present in the label folder
+            print(event, flush=True)
+
+            current_mos = 0
+            for mos_name in mos_names:  # only mosaics with labels
+                current_mos += 1
+                print(f'\n{ev_name}/{mos_name}', flush=True)
+                mos = event.mosaics[mos_name]
+                try:
+                    mos.set_build_gdf()
+                    print(f'len build gdf {len(mos.build_gdf):,}', flush=True)
+                except Exception:
+                    print(f'No buildings in {ev_name}/{mos_name}', file=sys.stderr)
+                    continue
+                tif_names = list_tif_files(os.path.join(lbl_root_folder, ev_name, 'pre', mos_name))
+                tif_names = sorted(tif_names)
+                parquets_names = list_parquet_files(os.path.join(lbl_root_folder, ev_name, 'pre', mos_name))
+                if len(tif_names) != len(parquets_names):
+                    print(f'{ev_name}/{mos_name}. Mismatch between tifs and parquets: lists are not of equal length.', file=sys.stderr)
+
+                print('proc_tifs and parquet:', len(tif_names), flush=True)
+                print()
+                current_tif = 0
+                for tile_name in tif_names:
+                    current_tif += 1
+                    print(f'{ev_name}/pre/{mos_name}/{tile_name}, tif:({current_tif}/{len(tif_names)}), mos:({current_mos}/{len(mos_names)}), ev:({current_ev}/{len_ev})', flush=True)
+                    cols['event'].append(ev_name)
+                    cols['mosaic'].append(mos_name)
+                    cols['tile'].append(tile_name)
+
+                    tile_path = os.path.join(lbl_root_folder, ev_name, 'pre', mos_name, tile_name)
+                    parquet_path = tile_path[:-4] + '.parquet'
+                    #tile_aoi = gpd.GeoDataFrame({'geometry': [samplers_utils.path_2_tile_aoi(tile_path)]})
+
+                    num_aoi_build = len(mos.proj_build_gdf.iloc[mos.sindex_proj_build_gdf.query(samplers_utils.path_2_tile_aoi(tile_path))])
+                    cols['num_ms_build_aoi'].append(num_aoi_build)
+                    #print('tile_aoi builds', num_aoi_build)
+
+                    tile_aoi_no_water = samplers_utils.path_2_tile_aoi_no_water(tile_path, event.filtered_wlb_gdf)
+                    num_ms_build_aoi_no_water_for = len(filter_gdf_vs_aois_gdf(mos.proj_build_gdf, tile_aoi_no_water))
+                    cols['num_ms_build_aoi_no_water_for'].append(num_ms_build_aoi_no_water_for)
+                    #print('tile_aoi builds no water filter with for', num_ms_build_aoi_no_water_for)
+
+                    num_ms_build_aoi_no_water_sjoin = len(gpd.sjoin(mos.proj_build_gdf, tile_aoi_no_water, how='inner', op='intersects'))
+                    cols['num_ms_build_aoi_no_water_sjoin'].append(num_ms_build_aoi_no_water_sjoin)
+                    #print('tile_aoi builds no water filter with sjoin', num_ms_build_aoi_no_water_sjoin)
+
+                    if os.path.exists(parquet_path):
+                        parquet_df = pd.read_parquet(parquet_path, engine='pyarrow')  # read once instead of twice
+                        parquet_build = sum(parquet_df.class_id == 2)
+                        parquet_trees = sum(parquet_df.class_id == 1)
+                    else:
+                        print(f'{parquet_path} does not exist', file=sys.stderr)
+                        parquet_build = -1
+                        parquet_trees = -1
+                    cols['parquet_build'].append(parquet_build)
+                    cols['parquet_tree'].append(parquet_trees)
+
+                    with rasterio.open(tile_path) as src:  # read the label raster
+                        lbl = src.read()
+                    tot_pxl = lbl.shape[-1]**2  # assume every image is square
+                    cols['bg_pxl'].append(np.sum(lbl == 255)/tot_pxl)
+                    cols['road_pxl'].append(np.sum(lbl == 0)/tot_pxl)
+                    cols['tree_pxl'].append(np.sum(lbl == 1)/tot_pxl)
+                    cols['build_pxl'].append(np.sum(lbl == 2)/tot_pxl)
+                    cols['entropy'].append(samplers_utils.entropy_from_lbl(lbl))
+
+                    high_res_entropies.append(samplers_utils.compute_entropy_matrix(lbl))
+
+        #if current_ev == 2:
+        #    break
+    high_res_entropies_np = np.stack(high_res_entropies, axis=0)
+    np.save('high_res_entropies.npy', high_res_entropies_np)
+    print('Saved high_res_entropies.npy', flush=True)
+    res_df = pd.DataFrame(cols)
+    res_df.to_csv('lbl_stats.csv', index = True)
+    print('Saved lbl_stats.csv', flush=True)
+
+if __name__ == "__main__":
+    main()
diff --git a/src/maxarseg/scripts/downloadMaxar.py b/src/maxarseg/scripts/downloadMaxar.py
new file mode 100644
index 0000000..fc29f2f
--- /dev/null
+++ b/src/maxarseg/scripts/downloadMaxar.py
@@ -0,0 +1,86 @@
+import leafmap
+from tqdm import tqdm
+import geopandas as gpd
+import pandas as pd
+import os
+import argparse
+from pathlib import Path
+
+if Path.cwd().name != 'src':
+    os.chdir('./src')
+
+# Read the csv file
+events_df = pd.read_csv('../metadata/dateEventi.csv', sep=';')
+# Create a dictionary with the event name as key and the date as value
+event2date = events_df.set_index('Aligned name')['date'].to_dict()
+
+
+def get_pre_post_gdf_local(collection_id, event2date = event2date, local_gdf = True):
+
+    # Retrieve the event date
+    try:
+        event_date = event2date[collection_id]
+    except KeyError:
+        print("ERROR: Event date not found!!!")
+        return None, None
+
+    # Create the geodataframe
+    geojson_path = '../metadata/from_github_maxar_metadata/datasets'
+    if local_gdf:
+        gdf = gpd.read_file(os.path.join(geojson_path, collection_id + '.geojson'))
+    else:
+        gdf = gpd.GeoDataFrame()
+        for child_id in tqdm(leafmap.maxar_child_collections(collection_id)):
+            current_gdf = leafmap.maxar_items(
+                collection_id = collection_id,
+                child_id = child_id,
+                return_gdf=True,
+                assets=['visual'],
+            )
+            gdf = pd.concat([gdf, current_gdf])
+
+    # Split the geodataframe into pre- and post-event items
+    pre_gdf = gdf[gdf['datetime'] < event_date]
+    post_gdf = gdf[gdf['datetime'] >= event_date]
+
+    print('Collection_id:', collection_id, '\nEvent date:', event_date)
+
+    if pre_gdf.shape[0] + post_gdf.shape[0] == gdf.shape[0]:
+        print("OK: All items are accounted for\n")
+    else:
+        print("ERROR: Some items are missing!!!\n")
+
+    print("pre_gdf", pre_gdf.shape)
+    print("post_gdf", post_gdf.shape)
+
+    return pre_gdf, post_gdf
+
+def download_event(collection_id, out_dir_root = "/nfs/projects/overwatch/maxar-segmentation/maxar-open-data/"):
+
+    pre_gdf, post_gdf = get_pre_post_gdf_local(collection_id)
+    if pre_gdf is None or post_gdf is None:
+        return
+
+    leafmap.maxar_download(pre_gdf['visual'].to_list(), out_dir = os.path.join(out_dir_root, collection_id, 'pre', ""))
+    leafmap.maxar_download(post_gdf['visual'].to_list(), out_dir = os.path.join(out_dir_root, collection_id, 'post', ""))
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Download Maxar images')
+    parser.add_argument('--c_id', help='single or list of collection id you want to download')
+    parser.add_argument('--out_dir', default = "/nfs/projects/overwatch/maxar-segmentation/maxar-open-data/", help='output directory')
+
+    args = parser.parse_args()
+
+    if args.c_id is None:
+        collection_ids = leafmap.maxar_collections()
+    else:
+        collection_ids = [args.c_id]
+
+    for collection_id in collection_ids:
+        download_event(collection_id, args.out_dir)
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/src/maxarseg/scripts/downloadRoads.py b/src/maxarseg/scripts/downloadRoads.py
new file mode 100644
index 0000000..aeaecba
--- /dev/null
+++ b/src/maxarseg/scripts/downloadRoads.py
@@ -0,0 +1,28 @@
+import pandas as pd
+import os
+import requests
+import zipfile
+import io
+from pathlib import Path
+import argparse
+
+def download_roads(meta_root, output_folder):
+    meta_root = Path(meta_root)
+    road_links_df = pd.read_csv(meta_root / 'roads_links.csv')
+    output_folder = Path(output_folder)
+
+    for i, row in road_links_df.iterrows():
+        url = row['link']
+        filename = url.split("/")[-1].replace("zip", "tsv")
+        print(filename)
+        if not os.path.exists(output_folder / filename):
+            response = requests.get(url)
+            zip_file = zipfile.ZipFile(io.BytesIO(response.content))
+            zip_file.extractall(path=output_folder)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--meta_root', type=str, default='./metadata')
+    parser.add_argument('--output_folder', type=str, default='/nfs/projects/overwatch/maxar-segmentation/microsoft-roads')
+    args = parser.parse_args()
+    download_roads(args.meta_root, args.output_folder)
\ No newline at end of file
diff --git a/src/maxarseg/scripts/make-gis-friendly.py b/src/maxarseg/scripts/make-gis-friendly.py
new file mode 100644
index 0000000..a1a8ba6
--- /dev/null
+++ b/src/maxarseg/scripts/make-gis-friendly.py
@@ -0,0 +1,29 @@
+"""
+This snippet demonstrates how to access and convert the buildings
+data from .csv.gz to geojson for use in common GIS tools. You will
+need to install pandas, geopandas, and shapely.
+From https://github.com/microsoft/GlobalMLBuildingFootprints/blob/main/scripts/make-gis-friendly.py
+"""
+
+import pandas as pd
+import geopandas as gpd
+from shapely.geometry import shape
+from tqdm import tqdm
+
+def main():
+    # this is the name of the geography you want to retrieve. update to meet your needs
+    location = 'Morocco'
+
+    dataset_links = pd.read_csv("/nfs/home/vaschetti/maxarSrc/metadata/buildings_dataset_links_24_05_08.csv")
+    print("/nfs/home/vaschetti/maxarSrc/metadata/buildings_dataset_links_24_05_08.csv")
+    location_links = dataset_links[dataset_links.Location == location]  # renamed from 'greece_links', a leftover of the upstream example
+    for _, row in tqdm(location_links.iterrows()):
+        df = pd.read_json(row.Url, lines=True)
+        df['geometry'] = df['geometry'].apply(shape)
+        gdf = gpd.GeoDataFrame(df, crs=4326)
+        gdf.to_file(f"/nfs/home/vaschetti/maxarSrc/morocco_builds_24_05_08/{row.QuadKey}.geojson", driver="GeoJSON")
+    print("Done")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
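
For reference, a minimal sketch of how the samplers_utils helpers compose, not part of the patch. The tile path reuses the hypothetical example from the docstrings, and filter_gdf_vs_aois_gdf is the script-local helper defined in the counting scripts above:

    # Usage sketch (assumes the metadata layout above; the tile path is hypothetical)
    from maxarseg.samplers import samplers_utils

    tile_path = '/nfs/projects/overwatch/maxar-segmentation/maxar-open-data/Gambia-flooding-8-11-2022/pre/10300100CFC9A500/033133031213.tif'

    # full tile AOI as a shapely Polygon in EPSG:4326
    tile_aoi = samplers_utils.path_2_tile_aoi(tile_path)

    # land-only AOIs as a GeoDataFrame in the tile's projected CRS;
    # land_gdf=None means the event never crosses the water-land boundary
    land_aois = samplers_utils.path_2_tile_aoi_no_water(tile_path, land_gdf=None)

    # counting buildings on land only would then use the script-local helper,
    # given a projected buildings GeoDataFrame proj_build_gdf (hypothetical):
    # n_builds = len(filter_gdf_vs_aois_gdf(proj_build_gdf, land_aois))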