Skip to content

Commit

Permalink
update example.py
Browse files Browse the repository at this point in the history
  • Loading branch information
2320sharon committed Feb 21, 2024
1 parent bdadc8b commit 00452e6
Showing 1 changed file with 71 additions and 10 deletions.
81 changes: 71 additions & 10 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

# load modules
import os
import json
import numpy as np
import pickle
import warnings
Expand All @@ -31,14 +32,18 @@
# can also be loaded from a .kml polygon
# kml_polygon = os.path.join(os.getcwd(), 'examples', 'NARRA_polygon.kml')
# polygon = SDS_tools.polygon_from_kml(kml_polygon)
# or read from geojson polygon (create it from https://geojson.io/)
# geojson_polygon = os.path.join(os.getcwd(), 'examples', 'NARRA_polygon.geojson')
# polygon = SDS_tools.polygon_from_geojson(geojson_polygon)
# convert polygon to a smallest rectangle (sides parallel to coordinate axes)
polygon = SDS_tools.smallest_rectangle(polygon)

# date range
dates = ['1984-01-01', '2022-01-01']

# dates = ['1984-01-01', '2022-01-01']
dates = ['2023-12-01', '2024-01-31']
# satellite missions
sat_list = ['L5','L7','L8']
# sat_list = ['L5','L7','L8']
sat_list = ['L9']
collection = 'C02' # choose Landsat collection 'C01' or 'C02'
# name of the site
sitename = 'NARRA'
Expand Down Expand Up @@ -88,11 +93,17 @@
'cloud_mask_issue': False, # switch this parameter to True if sand pixels are masked (in black) on many images
'sand_color': 'default', # 'default', 'latest', 'dark' (for grey/black sand beaches) or 'bright' (for white sand beaches)
'pan_off': False, # True to switch pansharpening off for Landsat 7/8/9 imagery
's2cloudless_prob': 40, # threshold to identify cloud pixels in the s2cloudless probability mask
# add the inputs defined previously
'inputs': inputs,
'create_plot': True, #True create a matplotlib plot of the image with the datetime as the title. False save as a standard JPG
}

# fn_animation = os.path.join(inputs['filepath'],inputs['sitename'], '%s_animation_RGB.mp4'%inputs['sitename'])
# fp_images = os.path.join(inputs['filepath'], inputs['sitename'], 'jpg_files', 'preprocessed')
# fps = 4 # frames per second in animation
# SDS_tools.make_animation_mp4(fp_images, fps, fn_animation)

# [OPTIONAL] preprocess images (cloud masking, pansharpening/down-sampling)
# SDS_preprocess.save_jpg(metadata, settings)

Expand All @@ -119,6 +130,13 @@
gdf.to_file(os.path.join(inputs['filepath'], inputs['sitename'], '%s_output_%s.geojson'%(sitename,geomtype)),
driver='GeoJSON', encoding='utf-8')

# # create MP4 timelapse animation
# fn_animation = os.path.join(inputs['filepath'],inputs['sitename'], '%s_animation_shorelines.mp4'%inputs['sitename'])
# fp_images = os.path.join(inputs['filepath'], inputs['sitename'], 'jpg_files', 'detection')
# fps = 4 # frames per second in animation
# SDS_tools.make_animation_mp4(fp_images, fps, fn_animation)


# plot the mapped shorelines
plt.ion()
fig = plt.figure(figsize=[15,8], tight_layout=True)
Expand All @@ -134,11 +152,54 @@
fig.savefig(os.path.join(inputs['filepath'], inputs['sitename'], 'mapped_shorelines.jpg'),dpi=200)

#%% 4. Shoreline analysis
def load_data_from_json(filepath: str) -> dict:
    """Load shoreline output data from a JSON file.

    Reads the JSON file at ``filepath`` and post-processes two keys,
    if present, via a ``json.load`` object hook:

    - ``"dates"``: ISO-format date strings are converted to
      ``datetime`` objects.
    - ``"shorelines"``: nested coordinate lists are converted to
      ``np.ndarray``; empty entries become empty ``(0, 2)`` arrays so
      downstream code can always index two coordinate columns.

    Args:
        filepath (str): Path to the JSON file.

    Returns:
        dict: Data read from the JSON file as a dictionary.
    """

    def _decode_datetime(obj: dict) -> dict:
        """Object hook: decode datetime and shoreline entries in-place."""
        # json.load calls this for every decoded JSON object; only the
        # two keys below need special handling, everything else passes
        # through unchanged.
        if "dates" in obj:
            obj["dates"] = [datetime.fromisoformat(d) for d in obj["dates"]]
        if "shorelines" in obj:
            obj["shorelines"] = [
                np.array(s) if len(s) > 0 else np.empty((0, 2))
                for s in obj["shorelines"]
            ]
        return obj

    # JSON is UTF-8 by specification (RFC 8259); be explicit rather than
    # relying on the platform default encoding.
    with open(filepath, "r", encoding="utf-8") as fp:
        return json.load(fp, object_hook=_decode_datetime)


# if you have already mapped the shorelines, load the output.pkl file
filepath = os.path.join(inputs['filepath'], sitename)
with open(os.path.join(filepath, sitename + '_output' + '.pkl'), 'rb') as f:
output = pickle.load(f)
output = load_data_from_json(os.path.join(filepath, sitename + '_output' + '.json'))
# remove duplicates (images taken on the same date by the same satellite)
output = SDS_tools.remove_duplicates(output)
# remove inaccurate georeferencing (set threshold to 10 m)
Expand Down Expand Up @@ -193,7 +254,7 @@
'max_range': 30, # max range for points around transect
'min_chainage': -100, # largest negative value along transect (landwards of transect origin)
'multiple_inter': 'auto', # mode for removing outliers ('auto', 'nan', 'max')
'prc_multiple': 0.1, # percentage to use in 'auto' mode to switch from 'nan' to 'max'
'auto_prc': 0.1, # percentage to use in 'auto' mode to switch from 'nan' to 'max'
}
cross_distance = SDS_transects.compute_intersection_QC(output, transects, settings_transects)

Expand Down Expand Up @@ -239,7 +300,7 @@
# load the measured tide data
filepath = os.path.join(os.getcwd(),'examples','NARRA_tides.csv')
tide_data = pd.read_csv(filepath, parse_dates=['dates'])
dates_ts = [_.to_pydatetime() for _ in tide_data['dates']]
dates_ts = [pd.to_datetime(_).to_pydatetime() for _ in tide_data['dates']]
tides_ts = np.array(tide_data['tide'])

# get tide levels corresponding to the time of image acquisition
Expand Down Expand Up @@ -358,7 +419,7 @@

# plot seasonal averages
fig,ax=plt.subplots(1,1,figsize=[14,4],tight_layout=True)
ax.grid(b=True,which='major', linestyle=':', color='0.5')
ax.grid(which='major', linestyle=':', color='0.5')
ax.set_title('Time-series at %s'%key, x=0, ha='left')
ax.set(ylabel='distance [m]')
ax.plot(dates_nonan, chainage,'+', lw=1, color='k', mfc='w', ms=4, alpha=0.5,label='raw datapoints')
Expand All @@ -384,7 +445,7 @@

# plot seasonal averages
fig,ax=plt.subplots(1,1,figsize=[14,4],tight_layout=True)
ax.grid(b=True,which='major', linestyle=':', color='0.5')
ax.grid(which='major', linestyle=':', color='0.5')
ax.set_title('Time-series at %s'%key, x=0, ha='left')
ax.set(ylabel='distance [m]')
ax.plot(dates_nonan, chainage,'+', lw=1, color='k', mfc='w', ms=4, alpha=0.5,label='raw datapoints')
Expand Down Expand Up @@ -627,4 +688,4 @@
ax[1].text(j+1,median_data[j]+1, '%.1f' % median_data[j], horizontalalignment='center', fontsize=14)
ax[1].text(j+1+0.35,median_data[j]+1, ('n=%.d' % int(n_data[j])), ha='center', va='center', fontsize=12, rotation='vertical')
ax[1].set(ylabel='error [m]', ylim=sett['lims']);
fig.savefig(os.path.join(os.getcwd(),'examples','comparison_all_transects.jpg'), dpi=150)
fig.savefig(os.path.join(os.getcwd(),'examples','comparison_all_transects.jpg'), dpi=150)

0 comments on commit 00452e6

Please sign in to comment.