Fix JSON and Black checks #457

Merged: 11 commits, Oct 18, 2024
22 changes: 9 additions & 13 deletions .github/workflows/check_formatting.yml
@@ -1,23 +1,19 @@
-name: check_formatting
+name: Check Python File Formatting with Black
 on: [push, pull_request]
 jobs:
   formatting_job:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v4
       - name: Set up conda
-        uses: conda-incubator/setup-miniconda@v2
+        uses: mamba-org/setup-micromamba@v1
         with:
-          miniforge-version: latest
-          miniforge-variant: mambaforge
-          channel-priority: strict
-          channels: conda-forge
-          show-channel-urls: true
-          use-only-tar-bz2: true
+          environment-file: environment-ci.yml
+          generate-run-shell: true
+          cache-environment: true
+          cache-downloads: true

-      - name: Install dependencies and check formatting
-        shell: bash -l {0}
+      - name: Check formatting
+        shell: micromamba-shell {0}
         run:
-          mamba install --quiet --yes --file requirements.txt black &&
-          black --version &&
           black tobac --check --diff
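
For reference, a minimal sketch of what the "black tobac --check --diff" step verifies, written against Black's Python API (black.format_str); the sample source string is made up for illustration:

    # Checks whether a piece of source is already in Black's canonical style,
    # which is exactly what "black --check" asserts for each file.
    import black

    src = 'x = {  "a": 1 }\n'
    formatted = black.format_str(src, mode=black.Mode())
    print(formatted, end="")   # x = {"a": 1}
    print(src == formatted)    # False: "black --check" would flag this source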
9 changes: 4 additions & 5 deletions .github/workflows/check_json.yml
@@ -8,13 +8,12 @@ jobs:
       shell: bash -el {0}
     steps:
       - name: check out repository code
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
       - name: set up conda environment
-        uses: conda-incubator/setup-miniconda@v2
+        uses: actions/setup-python@v5
         with:
-          auto-update-conda: true
-          auto-activate-base: false
-          activate-environment: checkjson-env
+          python-version: '3.12'
+          cache: 'pip' # caching pip dependencies
       - name: Install check-jsonschema
         run: |
           pip install check-jsonschema
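
The step above validates JSON files from the command line. A rough Python equivalent using the jsonschema package (an assumption for illustration only; the workflow itself relies on the check-jsonschema CLI):

    # Validate a JSON instance against a schema, the core of what the
    # check-jsonschema step does. Schema and instance are invented examples.
    from jsonschema import ValidationError, validate  # pip install jsonschema

    schema = {
        "type": "object",
        "properties": {"name": {"type": "string"}},
        "required": ["name"],
    }

    try:
        validate(instance={"name": "tobac"}, schema=schema)
        print("valid against schema")
    except ValidationError as err:
        print(f"schema violation: {err.message}")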
3 changes: 2 additions & 1 deletion environment-ci.yml
@@ -1,4 +1,4 @@
-name: pyart-dev
+name: tobac-dev
 channels:
   - conda-forge
 dependencies:
@@ -14,3 +14,4 @@ dependencies:
   - trackpy
   - pytest
   - typing_extensions
+  - black
7 changes: 3 additions & 4 deletions tobac/feature_detection.py
@@ -1015,7 +1015,6 @@ def feature_detection_multithreshold_timestep(
         raise ValueError(
             "Please provide the input parameter statistic to determine what statistics to calculate."
         )
-

     track_data = gaussian_filter(
         track_data, sigma=sigma_threshold
@@ -1136,14 +1135,14 @@ def feature_detection_multithreshold_timestep(
                 labels.ravel()[regions_old[key]] = key
     # apply function to get statistics based on labeled regions and functions provided by the user
     # the feature dataframe is updated by appending a column for each metric

+    # select which data to use according to statistics_unsmoothed option
+    stats_data = data_i.core_data() if statistics_unsmoothed else track_data
+
     features_thresholds = get_statistics(
         features_thresholds,
         labels,
         stats_data,
         statistic=statistic,
         index=np.unique(labels[labels > 0]),
         id_column="idx",
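
The new stats_data line switches the statistics input between the raw data and the Gaussian-smoothed copy. A self-contained sketch of that toggle (variable names are illustrative, not the tobac internals):

    # Compute statistics on either the raw field or its smoothed copy,
    # mirroring the statistics_unsmoothed option in the diff above.
    import numpy as np
    from scipy.ndimage import gaussian_filter

    rng = np.random.default_rng(42)
    raw_field = rng.random((50, 50))
    smoothed_field = gaussian_filter(raw_field, sigma=0.5)

    statistics_unsmoothed = True
    stats_data = raw_field if statistics_unsmoothed else smoothed_field
    print(stats_data.max())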
44 changes: 22 additions & 22 deletions tobac/tests/test_utils_bulk_statistics.py
@@ -153,51 +153,51 @@ def test_bulk_statistics_missing_segments():
     ### Test 2D data with time dimension
     test_data = tb_test.make_simple_sample_data_2D().core_data()
     common_dset_opts = {
         "in_arr": test_data,
-        "data_type": "iris",}
+        "data_type": "iris",
+    }

     test_data_iris = tb_test.make_dataset_from_arr(
-        time_dim_num=0, y_dim_num=1, x_dim_num=2, **common_dset_opts)
+        time_dim_num=0, y_dim_num=1, x_dim_num=2, **common_dset_opts
+    )

     # detect features
     threshold = 7
     # test_data_iris = testing.make_dataset_from_arr(test_data, data_type="iris")
     fd_output = tobac.feature_detection.feature_detection_multithreshold(
         test_data_iris,
         dxy=1000,
         threshold=[threshold],
         n_min_threshold=100,
-        target="maximum",)
+        target="maximum",
+    )

     # perform segmentation with bulk statistics
     stats = {
         "segment_max": np.max,
         "segment_min": min,
-        "percentiles": (np.percentile, {"q": 95}),}
+        "percentiles": (np.percentile, {"q": 95}),
+    }

     out_seg_mask, out_df = tobac.segmentation.segmentation_2D(
-        fd_output, test_data_iris, dxy=1000, threshold=threshold)
+        fd_output, test_data_iris, dxy=1000, threshold=threshold
+    )

     # specify some timesteps we set to zero
     timesteps_to_zero = [1, 3, 10]  # 0-based indexing
     modified_data = out_seg_mask.data.copy()
     # Set values to zero for the specified timesteps
     for timestep in timesteps_to_zero:
         modified_data[timestep, :, :] = 0  # Set all values for this timestep to zero

     # assure that bulk statistics in postprocessing give same result
     out_segmentation = tb_utils.get_statistics_from_mask(
-        out_df, out_seg_mask, test_data_iris, statistic=stats)
+        out_df, out_seg_mask, test_data_iris, statistic=stats
+    )

     assert out_df.time.unique().size == out_segmentation.time.unique().size

-
-
-
-
-
-
 def test_bulk_statistics_multiple_fields():
     """
     Test that multiple field input to bulk_statistics works as intended
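
The stats mapping in this test accepts either a bare callable or a (callable, kwargs) tuple per statistic. A small sketch of how such a mapping can be applied to the values of one labeled segment (illustrative only; tobac's get_statistics handles this internally):

    # Apply each statistic to a segment's values, unpacking (callable, kwargs)
    # tuples such as (np.percentile, {"q": 95}).
    import numpy as np

    stats = {
        "segment_max": np.max,
        "segment_min": min,
        "percentiles": (np.percentile, {"q": 95}),
    }
    segment_values = np.array([3.0, 7.5, 1.2, 9.9])

    for name, func in stats.items():
        if isinstance(func, tuple):
            func, kwargs = func
            print(name, func(segment_values, **kwargs))
        else:
            print(name, func(segment_values))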
10 changes: 8 additions & 2 deletions tobac/utils/bulk_statistics.py
@@ -301,9 +301,15 @@ def get_statistics_from_mask(

     for tt in pd.to_datetime(segmentation_mask.time):
         # select specific timestep
-        segmentation_mask_t = segmentation_mask.sel(time=tt, method = 'nearest').data
+        segmentation_mask_t = segmentation_mask.sel(time=tt, method="nearest").data
         fields_t = (
-            field.sel(time=tt, method = 'nearest', tolerance = np.timedelta64(1000, 'us')).values if "time" in field.coords else field.values
+            (
+                field.sel(
+                    time=tt, method="nearest", tolerance=np.timedelta64(1000, "us")
+                ).values
+                if "time" in field.coords
+                else field.values
+            )
             for field in fields
         )
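
The reformatted selection relies on xarray's nearest-neighbour indexing with a time tolerance; a minimal demonstration on synthetic data:

    # .sel(..., method="nearest", tolerance=...) returns the closest timestamp
    # within the tolerance and raises a KeyError if none qualifies.
    import numpy as np
    import pandas as pd
    import xarray as xr

    times = pd.date_range("2024-01-01", periods=4, freq="h")
    field = xr.DataArray(np.arange(4.0), coords={"time": times}, dims="time")

    target = times[1] + pd.Timedelta(microseconds=500)
    value = field.sel(time=target, method="nearest", tolerance=np.timedelta64(1000, "us"))
    print(float(value))  # 1.0: the 01:00 sample, 500 microseconds away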