-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprediction.py
78 lines (56 loc) · 2.62 KB
/
prediction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import rasterio
import tensorflow as tf
import os
import numpy as np
from glob import glob
from datagen import CustomImageGeneratorPrediction
from tools import *
import yaml
from sklearn.preprocessing import MinMaxScaler
# Read the run settings (model choice, input folders, output folder) from the
# YAML config file expected in the current working directory.
config_file = "config_prediction.yaml"
if not os.path.exists(config_file):
    # Fail here with a clear message: without the config none of model_path,
    # sentinel_paths or output_folder would be defined and the script would
    # only die further down with an opaque NameError.
    raise FileNotFoundError("Config file not found: {}".format(config_file))

with open(config_file) as f:
    # NOTE(review): yaml.FullLoader accepts more than plain scalars, lists and
    # dicts. If this config can ever come from an untrusted source, switch to
    # yaml.safe_load.
    data = yaml.load(f, Loader=yaml.FullLoader)

model_cfg = data['model']
if model_cfg.get('Unet_Sen1_Sen2'):
    # Combined model: gather the .tif files of both the Sentinel-2 and the
    # Sentinel-1 folders into one sorted list.
    sentinel1_folder = data['data_source']['Sentinel1']
    sentinel2_folder = data['data_source']['Sentinel2']
    sentinel_paths = sorted(
        glob("{}/*.tif".format(sentinel2_folder))
        + glob("{}/*.tif".format(sentinel1_folder))
    )
    model_path = model_cfg['Unet_Sen1_Sen2']
elif model_cfg.get('Unet_Sen2'):
    # Sentinel-2 only model: gather the .tif files of the Sentinel-2 folder.
    sentinel2_folder = data['data_source']['Sentinel2']
    sentinel_paths = sorted(glob("{}/*.tif".format(sentinel2_folder)))
    model_path = model_cfg['Unet_Sen2']
else:
    # Neither entry holds a model path, so there is nothing to predict with.
    raise ValueError(
        "No model configured in {}: set model.Unet_Sen1_Sen2 or "
        "model.Unet_Sen2 to a model path".format(config_file)
    )

output_folder = data["output_folder"]
# Switch for the patch-creation stage that follows; it is hard-coded to on.
patching = True

# Load the saved Keras model without compiling it. 'dice_coef' is passed as a
# custom object so that a model saved with a reference to it can be restored.
model = tf.keras.models.load_model(
    model_path,
    custom_objects={'dice_coef': dice_coef},
    compile=False,
)

# The entry at index 1 of the model's input shape is used as the patch size.
patch_size = model.input_shape[1]
def _normalize_to_unit_range(array, lower_pct=0.1, upper_pct=99.9):
    """Stretch *array* linearly between two percentiles and clip to [0, 1].

    The value at ``lower_pct`` maps to 0 and the value at ``upper_pct`` maps
    to 1; everything outside that range is clipped. If both percentiles are
    equal (e.g. a constant band) an all-zero float array of the same shape is
    returned instead of dividing by zero.
    """
    low, high = np.percentile(array, [lower_pct, upper_pct])
    if high == low:
        return np.zeros(array.shape, dtype=float)
    return np.clip((array - low) / (high - low), 0, 1)


if patching:
    # Maps band name -> patches produced by patchifyRasterAsArray.
    bands_patches = {}

    for band in sentinel_paths:
        # Band name = text between the last "_" of the file name and the
        # first "." that follows it.
        band_name = os.path.basename(band).split("_")[-1].split(".")[0]
        print("Start patching with band: ", band_name)

        # Open the raster in a context manager so the file handle is released
        # as soon as the pixel size has been checked and the data read.
        with rasterio.open(band) as src:
            needs_resampling = src.transform[0] != 10
            if not needs_resampling:
                # Keep at most the first 10980 rows and columns of every band.
                r_array = src.read()[:, :10980, :10980]

        if needs_resampling:
            # Pixel size differs from 10: resample via the project helper and
            # add a leading axis so the layout matches the branch above.
            resampled = resampleRaster(band, 10)
            r_array = resampled.ReadAsArray()
            r_array = np.expand_dims(r_array, axis=0)

        # Move the leading axis to the end and replace NaNs with numbers
        # before computing the percentile stretch.
        r_array = np.moveaxis(r_array, 0, -1)
        r_array = np.nan_to_num(r_array)

        r_array_norm = _normalize_to_unit_range(r_array)
        bands_patches[band_name] = patchifyRasterAsArray(r_array_norm, patch_size)

    patches_path = savePatchesPredict(bands_patches, output_folder)
def _patch_number(path):
    """Return the integer between the last "_" of *path* and the next "."."""
    return int(path.split("_")[-1].split(".")[0])


# Collect the patch files under <output_folder>/Crops/img and order them by
# the number embedded in their file names.
patches_path = glob(r"{}/Crops/img/*.tif".format(output_folder))
patches_path.sort(key=_patch_number)

# Read the first patch to find the patch height/width and the band count
# (last axis) that the prediction generator is constructed with.
first_patch = load_img_as_array(patches_path[0])
patch_xy = first_patch.shape[0], first_patch.shape[1]
b_count = first_patch.shape[-1]

predict_datagen = CustomImageGeneratorPrediction(patches_path, patch_xy, b_count)

# NOTE(review): the index 4 is hard-coded, so at least five input rasters are
# required; presumably this raster serves as the reference for the output --
# confirm against predictPatches.
predictPatches(model, predict_datagen, sentinel_paths[4], output_folder)