-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_sgrids.py
138 lines (93 loc) · 4.33 KB
/
main_sgrids.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import glob
import logging
import os
from itertools import product

import altair as alt
import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
import xarray as xr
from PIL import Image
from streamlit_option_menu import option_menu
# Identifiers of the four S-grid methods. Order matters: plot_main_plt lays
# them out row by row in a 2x2 grid, and read_data opens the first one as the
# base Dataset and adds the others to it.
METHODS = [
    'pvc_unconstrained',
    'pvc_constrained',
    'pwcv_basic',
    'pwcv_robust',
]
# Human-readable panel titles, keyed by method identifier.
TITLES = {
    'pvc_unconstrained': 'Unconstrained V-curve',
    'pvc_constrained': 'Constrained V-curve',
    'pwcv_basic': 'Basic WCV',
    'pwcv_robust': 'Robust WCV',
}
def plot_main_altair(df):
    """Draw one Altair scatter facet per S-grid method and show it in Streamlit.

    ``df`` must contain a ``spatial_ref`` column (it is dropped). After
    ``reset_index()`` it must have ``longitude`` and ``latitude`` columns.
    Every remaining column must be a key of ``TITLES``; an unknown name raises
    ``KeyError``. Presumably this is ``Dataset.to_dataframe()`` of read_data's
    output. TODO confirm.

    The frame is reshaped to long form with columns
    ``longitude, latitude, method, sopt``. Points are coloured by ``sopt`` and
    laid out two facets per row.
    """
    long_df = (
        df.drop(columns='spatial_ref')
        .reset_index()
        .melt(id_vars=['longitude', 'latitude'])
    )
    long_df.columns = ['longitude', 'latitude', 'method', 'sopt']
    # Replace method identifiers with their display titles (KeyError if unknown).
    long_df['method'] = long_df['method'].map(TITLES.__getitem__)

    points = alt.Chart(long_df).mark_point()
    encoded = points.encode(x='longitude:Q', y='latitude:Q', color='sopt:Q')
    sized = encoded.properties(width=180, height=180)
    faceted = sized.facet(facet='method:N', columns=2)
    st.altair_chart(faceted, use_container_width=True)
def plot_main_plt(tile, index):
    """Plot the four S-grid methods of one tile as a 2x2 grid of images.

    Parameters
    ----------
    tile : str
        Tile name, forwarded to ``read_data``.
    index : str
        Index label (``main`` passes ``'NDVI'``, ``'TDA'`` or ``'TNA'``),
        forwarded to ``read_data``.

    All panels share their x/y axes and a fixed colour range of [-4, 4], with
    one horizontal colour bar. If ``read_data`` could not load the tile, a
    Streamlit warning is shown instead of a figure.
    """
    ds = read_data(tile, index)
    # read_data signals failure by returning a dict of None values rather than
    # raising; without this check the loop below dies on ``None.plot``.
    if any(ds[m] is None for m in METHODS):
        st.warning(f'Could not load S-grids for tile {tile!r} and index {index!r}.')
        return

    fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)
    try:
        # axs.flat walks the panels row by row, i.e. in the order of METHODS.
        for ax, method in zip(axs.flat, METHODS):
            im = ds[method].plot.imshow(ax=ax, vmin=-4, vmax=4, add_colorbar=False)
            ax.set_title(TITLES[method])
            ax.set_xlabel('', fontsize=0)
            ax.set_ylabel('', fontsize=0)
        fig.subplots_adjust(right=0.875)
        # [left, bottom, width, height] in figure coordinates. The negative
        # bottom places the colour bar below the panels. This presumably relies
        # on st.pyplot saving with a tight bounding box; confirm.
        cbar_ax = fig.add_axes([0.1, -0.1, .8, .05])
        # Every panel uses the same vmin/vmax, so the last image stands in for all.
        fig.colorbar(im, cax=cbar_ax, orientation="horizontal")
        st.pyplot(fig, use_container_width=True)
    finally:
        # plt.subplots registers the figure with pyplot, which keeps it alive
        # until it is closed. Without this, every call leaks one figure.
        plt.close(fig)
def read_data(tile, ind):
    """Open all S-grid methods for one tile/index pair as a single Dataset.

    Each method in ``METHODS`` is read from the zarr store
    ``data/sgrids_<ind lower-cased>/<method>/<tile>``. Its ``sg`` variable is
    renamed to the method name. The first method's Dataset is used as the base,
    and the other methods are added to it as variables.

    Parameters
    ----------
    tile : str
        Tile directory name (``main`` passes a name from ``get_data_by_state``).
    ind : str
        Index label such as ``'NDVI'``; lower-cased to build the path.

    Returns
    -------
    xarray.Dataset or dict
        The combined Dataset on success. If anything fails while opening
        (missing store, unreadable data, bad ``ind``), the function returns a
        dict mapping every name in ``METHODS`` to ``None``. The error is logged,
        not raised.
    """
    try:
        root = f'data/sgrids_{ind.lower()}'
        base = METHODS[0]
        ds = xr.open_mfdataset(f'{root}/{base}/{tile}', engine='zarr').rename({'sg': base})
        for m in METHODS[1:]:
            ds[m] = xr.open_mfdataset(f'{root}/{m}/{tile}', engine='zarr').rename({'sg': m})[m]
        return ds
    except Exception:
        # Best effort by design, but catch only ordinary errors. A bare
        # ``except:`` would also swallow KeyboardInterrupt/SystemExit. The
        # traceback is logged so failures are not invisible.
        logging.getLogger(__name__).warning(
            'Could not read S-grids for tile %r, index %r', tile, ind, exc_info=True
        )
        # Same shape as before ({method: None}), but always in sync with METHODS.
        return dict.fromkeys(METHODS)
@st.cache_data  # No TTL needed: the data on disk are static.
def get_data_by_state():
    """Return the sorted names of the tiles available for the NDVI S-grids.

    A tile is any entry of ``data/sgrids_ndvi/pwcv_basic/``. Only the names
    are returned. The datasets themselves are opened on demand by
    ``read_data`` when ``plot_main_plt`` draws a tile.

    Returns
    -------
    list of str
        Tile directory names (no leading path), sorted.
    """
    paths = glob.glob('data/sgrids_ndvi/pwcv_basic/*')
    # os.path.basename understands the platform's separators (both / and \ on
    # Windows). Splitting on a literal backslash only strips the directory on
    # Windows and leaves the full path on POSIX.
    return sorted(os.path.basename(p) for p in paths)
def main():
    """Build the S-grids inspector page: sidebar controls plus the main panel.

    The sidebar offers a toggle for the MODIS sinusoidal grid image, a tile
    selector fed by ``get_data_by_state`` and an index menu (NDVI/TDA/TNA).
    The main area shows either the grid image or the 2x2 S-grid plot of the
    selected tile and index.
    """
    # =============================================================================
    # Layout
    # =============================================================================
    st.set_page_config(layout='wide')
    # =============================================================================
    # Data
    # =============================================================================
    tiles = get_data_by_state()
    # =============================================================================
    # Widgets inputs
    # =============================================================================
    with st.sidebar:
        st.title("S grids inspector")
        grid = st.checkbox('Show MODIS sinusoidal grid', value=False)
        tile = st.selectbox('Tile', tiles)
        st.markdown('------------')
        index = option_menu("Choose index", ["NDVI", "TDA", "TNA"],
                            icons=['tree', 'thermometer-half', 'thermometer-half'],
                            menu_icon="app-indicator", default_index=0,
                            orientation='vertical')
    # =============================================================================
    # Main plot
    # =============================================================================
    if grid:
        # Close the image file once st.image has returned instead of leaving
        # the handle to the garbage collector.
        with Image.open('data/MODIS_sinusoidal_grid1.gif') as img:
            st.image(img, use_column_width=True)
    elif not tiles:
        # With no tiles on disk there is nothing to select or plot.
        st.error('No S-grid tiles found in data/sgrids_ndvi/pwcv_basic/.')
    else:
        plot_main_plt(tile, index)
# Build the page only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()