diff --git a/.github/workflows/check_formatting.yml b/.github/workflows/check_formatting.yml index b6f8ec27..bca4a182 100644 --- a/.github/workflows/check_formatting.yml +++ b/.github/workflows/check_formatting.yml @@ -1,23 +1,19 @@ -name: check_formatting +name: Check Python File Formatting with Black on: [push, pull_request] jobs: formatting_job: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up conda - uses: conda-incubator/setup-miniconda@v2 + uses: mamba-org/setup-micromamba@v1 with: - miniforge-version: latest - miniforge-variant: mambaforge - channel-priority: strict - channels: conda-forge - show-channel-urls: true - use-only-tar-bz2: true - - - name: Install dependencies and check formatting - shell: bash -l {0} + environment-file: environment-ci.yml + generate-run-shell: true + cache-environment: true + cache-downloads: true + - name: Check formatting + shell: micromamba-shell {0} run: - mamba install --quiet --yes --file requirements.txt black && black --version && black tobac --check --diff diff --git a/.github/workflows/check_json.yml b/.github/workflows/check_json.yml index 60d61175..f5f45fd5 100644 --- a/.github/workflows/check_json.yml +++ b/.github/workflows/check_json.yml @@ -8,13 +8,12 @@ jobs: shell: bash -el {0} steps: - name: check out repository code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: set up conda environment - uses: conda-incubator/setup-miniconda@v2 + uses: actions/setup-python@v5 with: - auto-update-conda: true - auto-activate-base: false - activate-environment: checkjson-env + python-version: '3.12' + cache: 'pip' # caching pip dependencies - name: Install check-jsonschema run: | pip install check-jsonschema diff --git a/environment-ci.yml b/environment-ci.yml index 4f0b954b..964e7155 100644 --- a/environment-ci.yml +++ b/environment-ci.yml @@ -1,4 +1,4 @@ -name: pyart-dev +name: tobac-dev channels: - conda-forge dependencies: @@ -14,3 +14,4 @@ dependencies: - trackpy - pytest - typing_extensions + - black diff --git a/tobac/feature_detection.py b/tobac/feature_detection.py index a866e897..72955bb2 100644 --- a/tobac/feature_detection.py +++ b/tobac/feature_detection.py @@ -1015,7 +1015,6 @@ def feature_detection_multithreshold_timestep( raise ValueError( "Please provide the input parameter statistic to determine what statistics to calculate." ) - track_data = gaussian_filter( track_data, sigma=sigma_threshold @@ -1136,14 +1135,14 @@ def feature_detection_multithreshold_timestep( labels.ravel()[regions_old[key]] = key # apply function to get statistics based on labeled regions and functions provided by the user # the feature dataframe is updated by appending a column for each metric - + # select which data to use according to statistics_unsmoothed option stats_data = data_i.core_data() if statistics_unsmoothed else track_data - + features_thresholds = get_statistics( features_thresholds, labels, - stats_data, + stats_data, statistic=statistic, index=np.unique(labels[labels > 0]), id_column="idx", diff --git a/tobac/tests/test_utils_bulk_statistics.py b/tobac/tests/test_utils_bulk_statistics.py index 2026f9c4..d50bf539 100644 --- a/tobac/tests/test_utils_bulk_statistics.py +++ b/tobac/tests/test_utils_bulk_statistics.py @@ -153,51 +153,51 @@ def test_bulk_statistics_missing_segments(): ### Test 2D data with time dimension test_data = tb_test.make_simple_sample_data_2D().core_data() common_dset_opts = { - "in_arr": test_data, - "data_type": "iris",} - + "in_arr": test_data, + "data_type": "iris", + } + test_data_iris = tb_test.make_dataset_from_arr( - time_dim_num=0, y_dim_num=1, x_dim_num=2, **common_dset_opts) + time_dim_num=0, y_dim_num=1, x_dim_num=2, **common_dset_opts + ) # detect features threshold = 7 # test_data_iris = testing.make_dataset_from_arr(test_data, data_type="iris") fd_output = tobac.feature_detection.feature_detection_multithreshold( - test_data_iris, - dxy=1000, - threshold=[threshold], - n_min_threshold=100, - target="maximum",) + test_data_iris, + dxy=1000, + threshold=[threshold], + n_min_threshold=100, + target="maximum", + ) # perform segmentation with bulk statistics stats = { - "segment_max": np.max, - "segment_min": min, - "percentiles": (np.percentile, {"q": 95}),} + "segment_max": np.max, + "segment_min": min, + "percentiles": (np.percentile, {"q": 95}), + } out_seg_mask, out_df = tobac.segmentation.segmentation_2D( - fd_output, test_data_iris, dxy=1000, threshold=threshold) + fd_output, test_data_iris, dxy=1000, threshold=threshold + ) - # specify some timesteps we set to zero + # specify some timesteps we set to zero timesteps_to_zero = [1, 3, 10] # 0-based indexing - modified_data = out_seg_mask.data.copy() + modified_data = out_seg_mask.data.copy() # Set values to zero for the specified timesteps for timestep in timesteps_to_zero: modified_data[timestep, :, :] = 0 # Set all values for this timestep to zero # assure that bulk statistics in postprocessing give same result out_segmentation = tb_utils.get_statistics_from_mask( - out_df, out_seg_mask, test_data_iris, statistic=stats) + out_df, out_seg_mask, test_data_iris, statistic=stats + ) assert out_df.time.unique().size == out_segmentation.time.unique().size - - - - - - def test_bulk_statistics_multiple_fields(): """ Test that multiple field input to bulk_statistics works as intended diff --git a/tobac/utils/bulk_statistics.py b/tobac/utils/bulk_statistics.py index f9266dd8..daa430eb 100644 --- a/tobac/utils/bulk_statistics.py +++ b/tobac/utils/bulk_statistics.py @@ -301,9 +301,15 @@ def get_statistics_from_mask( for tt in pd.to_datetime(segmentation_mask.time): # select specific timestep - segmentation_mask_t = segmentation_mask.sel(time=tt, method = 'nearest').data + segmentation_mask_t = segmentation_mask.sel(time=tt, method="nearest").data fields_t = ( - field.sel(time=tt, method = 'nearest', tolerance = np.timedelta64(1000, 'us')).values if "time" in field.coords else field.values + ( + field.sel( + time=tt, method="nearest", tolerance=np.timedelta64(1000, "us") + ).values + if "time" in field.coords + else field.values + ) for field in fields )