#179 add new crs parameter to test_merge_sessions.py
2320sharon committed Dec 21, 2023
1 parent 0a783ac commit e6619e7
Showing 1 changed file with 75 additions and 26 deletions.
101 changes: 75 additions & 26 deletions scripts/test_merge_sessions.py
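Before the diff, a minimal sketch of the command-line option these tests appear to exercise. The short flags (-i, -n, -s, -c) and the example value "epsg:32610" are taken from the test arguments in this diff; the long option names, help text, defaults, and whether --crs is required are assumptions and may differ from the real parse_arguments in merge_sessions.py.

import argparse

def parse_arguments(argv=None):
    parser = argparse.ArgumentParser(description="Merge overlapping CoastSeg sessions")
    parser.add_argument("-i", "--session_locations", nargs="+", required=True,
                        help="paths of the sessions to merge")
    parser.add_argument("-n", "--merged_session_name", required=True,
                        help="name of the merged session to create")
    parser.add_argument("-s", "--save_location", default=".",
                        help="directory in which the merged session is saved")
    # New in this commit: the coordinate reference system for the merged outputs,
    # e.g. "epsg:32610". Marked required here, but that is an assumption.
    parser.add_argument("-c", "--crs", required=True,
                        help="CRS for the merged session outputs")
    # The numeric tuning flags (-ad, -mp, -ms, -mr, -pm, -mi, -mc) are omitted
    # for brevity; the tests in the diff below show the values they exercise.
    return parser.parse_args(argv)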
@@ -12,14 +12,16 @@
TEST_DATA_LOCATION = r"C:\development\doodleverse\coastseg\CoastSeg\test_data"
SAVE_LOCATION = os.path.join(TEST_DATA_LOCATION, "merged_sessions")


# helper functions
# --------------
def get_unique_dates_from_geojson_files(gdfs:list[gpd.GeoDataFrame]):
def get_unique_dates_from_geojson_files(gdfs: list[gpd.GeoDataFrame]):
unique_dates = set()
for gdf in gdfs:
unique_dates.update(gdf.date)
return unique_dates


def check_geojson_files(gdf1, gdf2, columns: list[str]):
"""
Check if the specified columns exist in both GeoDataFrame objects and if their values match.
@@ -35,10 +37,13 @@ def check_geojson_files(gdf1, gdf2, columns: list[str]):
if isinstance(columns, str):
columns = [columns]
for column in columns:
assert (column in gdf1.columns) and (column in gdf2.columns), f"{column} column missing in one of the files"
assert (column in gdf1.columns) and (
column in gdf2.columns
), f"{column} column missing in one of the files"
assert set(gdf1[column]) == set(gdf2[column]), f"{column} do not match"

def assert_all_files_exist(dest:str):

def assert_all_files_exist(dest: str):
"""
Check if all the required files exist in the specified destination directory.
@@ -55,7 +60,8 @@ def assert_all_files_exist(dest:str):
assert os.path.exists(os.path.join(dest, "merged_config.geojson"))
assert os.path.exists(os.path.join(dest, "transect_time_series.csv"))

def verify_merged_session(dest:str):

def verify_merged_session(dest: str):
"""
Verify the consistency of a merged session.
@@ -69,31 +75,50 @@ def verify_merged_session(dest:str):
None
"""
# 1. read in extracted_shorelines_points.geojson from merged session
shoreline_points_gdf = merge_utils.read_first_geojson_file(dest, "extracted_shorelines_points.geojson")
shoreline_points_gdf = merge_utils.read_first_geojson_file(
dest, "extracted_shorelines_points.geojson"
)
# 2. read in extracted_shorelines_lines.geojson from merged session
shoreline_lines_gdf = merge_utils.read_first_geojson_file(dest, "extracted_shorelines_lines.geojson")
shoreline_lines_gdf = merge_utils.read_first_geojson_file(
dest, "extracted_shorelines_lines.geojson"
)
# 3.verify that the 'date' and 'satname' columns are present and consistent in the geojson files
check_geojson_files(shoreline_points_gdf, shoreline_lines_gdf, ['date', 'satname'])
check_geojson_files(shoreline_points_gdf, shoreline_lines_gdf, ["date", "satname"])
# 4. read in the extracted_shoreline_dict.json from merged session
extracted_shorelines_dict=file_utilities.load_data_from_json(os.path.join(dest, "extracted_shorelines_dict.json"))
extracted_shorelines_dict = file_utilities.load_data_from_json(
os.path.join(dest, "extracted_shorelines_dict.json")
)
# 5. Check if all the dates & satellites in the geojson files are present in the dictionary
columns = ['dates', 'satname']
columns = ["dates", "satname"]
for column in columns:
if column == 'dates':
assert shoreline_points_gdf['date'].isin(extracted_shorelines_dict.get(column)).all()
if column == "dates":
assert (
shoreline_points_gdf["date"]
.isin(extracted_shorelines_dict.get(column))
.all()
)
else:
assert shoreline_points_gdf[column].isin(extracted_shorelines_dict.get(column)).all()
assert (
shoreline_points_gdf[column]
.isin(extracted_shorelines_dict.get(column))
.all()
)
# 6. Read in the transects_time_series.csv from merged session
transect_time_series= pd.read_csv(os.path.join(dest, "transect_time_series.csv"))
transect_time_series = pd.read_csv(os.path.join(dest, "transect_time_series.csv"))
# 8. Check if all the dates in the geojson files are present in the csv file
assert shoreline_points_gdf['date'].isin(transect_time_series['dates']).all()
assert shoreline_points_gdf["date"].isin(transect_time_series["dates"]).all()
# 9. Check if the length of dates in the dictionary is the same as the number of dates in the csv file
assert len(extracted_shorelines_dict.get('dates')) == len(transect_time_series['dates'])
assert len(extracted_shorelines_dict.get("dates")) == len(
transect_time_series["dates"]
)
# 10. Check if the length of all the values in the dictionary is the same as the number of dates in the csv file
for key in extracted_shorelines_dict.keys():
assert len(extracted_shorelines_dict.get(key)) == len(transect_time_series['dates'])

def validate_dates(session_locations:list[str], dest:str):
assert len(extracted_shorelines_dict.get(key)) == len(
transect_time_series["dates"]
)


def validate_dates(session_locations: list[str], dest: str):
"""
Validates that the dates before and after merging shoreline geojson files are the same.
@@ -104,7 +129,7 @@ def validate_dates(session_locations:list[str], dest:str):
Raises:
AssertionError: If the dates before and after merging are not the same.
"""
# get the dates from the shoreline geojson files located in the original sessions
# get the dates from the shoreline geojson files located in the original sessions
gdfs = merge_utils.process_geojson_files(
session_locations,
["extracted_shorelines_points.geojson", "extracted_shorelines.geojson"],
@@ -118,7 +143,7 @@ def validate_dates(session_locations:list[str], dest:str):
merged_dates = set(merged_shorelines["date"])
# check that the dates are the same before and after merging
assert unique_dates == merged_dates

# read the merged geojson file and check if all the dates are present
# get the dates from the shoreline geojson files located in the new merged session
merged_gdfs = merge_utils.process_geojson_files(
@@ -133,6 +158,7 @@ def validate_dates(session_locations:list[str], dest:str):
assert unique_dates == result_unique_dates
assert merged_dates == result_unique_dates


def clear_directory(directory):
"""
Deletes all the contents of the specified directory.
@@ -149,8 +175,10 @@ def clear_directory(directory):
except Exception as e:
print(f"Failed to delete {file_path}. Reason: {e}")


# ----------------------------


def test_with_all_arguments(monkeypatch):
# Test case 1: Required arguments provided
test_args = [
@@ -162,6 +190,8 @@ def test_with_all_arguments(monkeypatch):
"merged_session",
"-s",
"save_location",
"-c",
"epsg:32610",
"-ad",
"30",
"-mp",
@@ -170,6 +200,12 @@ def test_with_all_arguments(monkeypatch):
"20",
"-mr",
"20",
"-pm",
"0.2",
"-mi",
"nan",
"-mc",
"-110",
]
monkeypatch.setattr(sys, "argv", test_args)
args = parse_arguments()
@@ -180,13 +216,22 @@ def test_with_all_arguments(monkeypatch):
assert args.min_points == 5
assert args.max_std == 20
assert args.max_range == 20
assert args.min_chainage == -100
assert args.multiple_inter == "auto"
assert args.prc_multiple == 0.1
assert args.min_chainage == -110
assert args.multiple_inter == "nan"
assert args.prc_multiple == 0.2


def test_with_mandatory_arguments_only(monkeypatch):
test_args = ["program_name", "-i", "session1", "session2", "-n", "merged_session"]
test_args = [
"program_name",
"-i",
"session1",
"session2",
"-n",
"merged_session",
"-c",
"epsg:32610",
]
monkeypatch.setattr(sys, "argv", test_args)
args = parse_arguments()
assert args.session_locations == ["session1", "session2"]
@@ -221,6 +266,7 @@ def test_main_with_overlapping():
session_locations=session_locations,
save_location=SAVE_LOCATION,
merged_session_name=merged_session_name,
crs="EPSG:32610",
along_dist=25,
min_points=3,
max_std=15,
@@ -240,7 +286,7 @@ def test_main_with_overlapping():


def test_main_with_same_rois():
# Create a Namespace object with your arguments
# Read the test data from the test_case1_same_rois folder
source_dest = os.path.join(TEST_DATA_LOCATION, "test_case1_same_rois")
session_locations = [
os.path.join(source_dest, session) for session in os.listdir(source_dest)
@@ -253,11 +299,12 @@ def test_main_with_same_rois():
dest = os.path.join(SAVE_LOCATION, merged_session_name)
if os.path.exists(dest):
clear_directory(dest)

# Create a Namespace object with your arguments
mock_args = argparse.Namespace(
session_locations=session_locations,
save_location=SAVE_LOCATION,
merged_session_name=merged_session_name,
crs="EPSG:32610",
along_dist=25,
min_points=3,
max_std=15,
@@ -295,6 +342,7 @@ def test_main_with_different_rois():
session_locations=session_locations,
save_location=SAVE_LOCATION,
merged_session_name=merged_session_name,
crs="EPSG:32611",
along_dist=25,
min_points=3,
max_std=15,
@@ -332,6 +380,7 @@ def test_main_with_overlapping_dates():
session_locations=session_locations,
save_location=SAVE_LOCATION,
merged_session_name=merged_session_name,
crs="EPSG:32610",
along_dist=25,
min_points=3,
max_std=15,

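For context on what the new parameter typically controls, here is an illustration of applying a CRS string such as "EPSG:32610" with geopandas. This is not taken from merge_sessions.py; the function name is hypothetical and only shows the kind of reprojection such a parameter would usually govern.

import geopandas as gpd

def reproject_outputs(gdf: gpd.GeoDataFrame, crs: str) -> gpd.GeoDataFrame:
    """Reproject a merged GeoDataFrame to the requested CRS (e.g. "EPSG:32610")."""
    # to_crs assumes the GeoDataFrame already carries a CRS; if it does not,
    # set_crs(crs) would assign one without transforming the coordinates.
    return gdf.to_crs(crs)

# Usage sketch (assumes merged_shorelines.geojson exists and has a CRS defined):
# shorelines = gpd.read_file("merged_shorelines.geojson")
# shorelines_utm = reproject_outputs(shorelines, "EPSG:32610")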