Skip to content

Commit

Permalink
Replace argparse type bool with strtobool (#65)
Browse files Browse the repository at this point in the history
* replace argparse type bool with strtobool

* replace deprecated distutils strtobool with reimplementation
  • Loading branch information
rgutzen authored May 22, 2024
1 parent 13ebed9 commit 52d6686
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 46 deletions.
4 changes: 2 additions & 2 deletions cobrawap/pipeline/stage02_processing/scripts/roi_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import neo
import os
from utils.io_utils import load_neo, write_neo, save_plot
from utils.parse import none_or_str
from utils.parse import none_or_str, str_to_bool
from utils.neo_utils import analogsignal_to_imagesequence, imagesequence_to_analogsignal

CLI = argparse.ArgumentParser()
Expand All @@ -22,7 +22,7 @@
help="path of output image", default=None)
CLI.add_argument("--intensity_threshold", nargs='?', type=float,
help="threshold for mask [0,1]", default=0.5)
CLI.add_argument("--crop_to_selection", nargs='?', type=bool,
CLI.add_argument("--crop_to_selection", nargs='?', type=str_to_bool,
help="discard frame outside of ROI", default=True)

def calculate_contour(img, contour_limit):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import neo
import argparse
import quantities as pq
from distutils.util import strtobool
from utils.io_utils import load_neo, write_neo
from utils.parse import str_to_bool

CLI = argparse.ArgumentParser()
CLI.add_argument("--data", nargs='?', type=str, required=True,
Expand All @@ -19,7 +19,7 @@
help="minimum duration of UP states in seconds")
CLI.add_argument("--min_down_duration", nargs='?', type=float, default=0.005,
help="minimum duration of DOWN states in seconds")
CLI.add_argument("--remove_down_first", nargs='?', type=strtobool, default=True,
CLI.add_argument("--remove_down_first", nargs='?', type=str_to_bool, default=True,
help="If True, remove short down states first")


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
import numpy as np
from copy import copy
import matplotlib.pyplot as plt
from distutils.util import strtobool
from utils.io_utils import load_neo, write_neo, save_plot
from utils.parse import none_or_str
from utils.parse import none_or_str, str_to_bool
from utils.neo_utils import imagesequence_to_analogsignal, analogsignal_to_imagesequence
from utils.convolve import phase_conv2d, get_kernel, conv, norm_angle

Expand All @@ -33,7 +32,7 @@
help='sigma of gaussian filter in each dimension')
CLI.add_argument("--derivative_filter", nargs='?', type=none_or_str, default=None,
help='Filter kernel to use for calculating spatial derivatives')
CLI.add_argument("--use_phases", nargs='?', type=strtobool, default=False,
CLI.add_argument("--use_phases", nargs='?', type=str_to_bool, default=False,
help='whether to use signal phase instead of amplitude')

def horn_schunck_step(frame, next_frame, alpha, max_Niter, convergence_limit,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
"""

import argparse
from warnings import warn

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from warnings import warn
from scipy.interpolate import RBFInterpolator
from utils.convolve import get_kernel, nan_conv2d
from utils.io_utils import load_neo, save_plot
from utils.parse import none_or_str
from utils.convolve import nan_conv2d, get_kernel
from utils.parse import none_or_str, str_to_bool

CLI = argparse.ArgumentParser()
CLI.add_argument("--data", nargs='?', type=str, required=True,
Expand All @@ -26,7 +27,7 @@
help="derivative kernel")
CLI.add_argument("--event_name", "--EVENT_NAME", nargs='?', type=str, default='wavefronts',
help="name of neo.Event to analyze (must contain waves)")
CLI.add_argument("--interpolate", "--INTERPOLATE", nargs='?', type=bool, default=False,
CLI.add_argument("--interpolate", "--INTERPOLATE", nargs='?', type=str_to_bool, default=False,
help="whether to thin-plate-spline interpolate the wave patterns before derivation")
CLI.add_argument("--smoothing", "--SMOOTHING", nargs='?', type=float, default=0,
help="smoothing factor for the interpolation")
Expand Down
94 changes: 60 additions & 34 deletions cobrawap/pipeline/utils/parse.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import numpy as np
import warnings
import re
from pathlib import Path
import sys
import warnings
from pathlib import Path

import numpy as np

from .io_utils import load_neo


def get_base_type(datatype):
if hasattr(datatype, 'dtype'):
if hasattr(datatype, "dtype"):
datatype = datatype.dtype
elif not type(datatype) == type:
datatype = type(datatype)
Expand All @@ -16,25 +18,35 @@ def get_base_type(datatype):
warnings.warn("List don't have a defined type!")

if np.issubdtype(datatype, np.integer):
return 'int'
return "int"
elif np.issubdtype(datatype, float):
return 'float'
return "float"
elif np.issubdtype(datatype, str):
return 'str'
return "str"
elif np.issubdtype(datatype, complex):
return 'complex'
return "complex"
elif np.issubdtype(datatype, bool):
return 'bool'
return "bool"
else:
warnings.warn(f"Did not recognize type {datatype}! returning 'object'")
return 'object'
return "object"


nan_values = {
"int": -1,
"float": np.nan,
"bool": False,
"a5": "None",
"str": "None",
"complex": np.nan + 1j * np.nan,
"object": None,
}

nan_values = {'int': -1, 'float': np.nan, 'bool': False, 'a5': 'None',
'str': 'None', 'complex': np.nan+1j*np.nan, 'object': None}

def get_nan_value(type_string):
return nan_values[type_string]


def guess_type(string):
try:
out = int(string)
Expand All @@ -43,11 +55,11 @@ def guess_type(string):
out = float(string)
except:
out = str(string)
if out == 'None':
if out == "None":
out = None
elif out == 'True':
elif out == "True":
out = True
elif out == 'False':
elif out == "False":
out = False
return out

Expand All @@ -56,26 +68,25 @@ def str2dict(string):
"""
Transforms a str(dict) back to dict
"""
if string[0] == '{':
if string[0] == "{":
string = string[1:]
if string[-1] == '}':
if string[-1] == "}":
string = string[:-1]
my_dict = {}
# list or tuple values
brackets = [delimiter for delimiter in ['[',']','(',')']
if delimiter in string]
brackets = [delimiter for delimiter in ["[", "]", "(", ")"] if delimiter in string]
if len(brackets):
for kv in string.split("{},".format(brackets[1])):
k,v = kv.split(":")
v = v.replace(brackets[0], '').replace(brackets[1], '')
values = [guess_type(val) for val in v.split(',')]
k, v = kv.split(":")
v = v.replace(brackets[0], "").replace(brackets[1], "")
values = [guess_type(val) for val in v.split(",")]
if len(values) == 1:
values = values[0]
my_dict[k.strip()] = values
# scalar values
else:
for kv in string.split(','):
k,v = kv.split(":")
for kv in string.split(","):
k, v = kv.split(":")
my_dict[k.strip()] = guess_type(v.strip())
return my_dict

Expand All @@ -87,7 +98,7 @@ def parse_string2dict(kwargs_str, **kwargs):
elif len(kwargs_str) == 1:
kwargs = kwargs_str[0]
else:
kwargs = ''.join(kwargs_str)[1:-1]
kwargs = "".join(kwargs_str)[1:-1]
else:
kwargs = str(kwargs_str)
if guess_type(kwargs) is None:
Expand All @@ -100,7 +111,7 @@ def parse_string2dict(kwargs_str, **kwargs):
nested_dict_name, nested_dict = match.split(":{")
nested_dict = nested_dict[:-1]
my_dict[nested_dict_name] = str2dict(nested_dict)
kwargs = kwargs.replace(match, '')
kwargs = kwargs.replace(match, "")
# match entries with word value, list value, or tuple value
pattern = re.compile("[\w\s]+:(?:[\w\.\s\/\-\&\+]+|\[[^\]]+\]|\([^\)]+\))")
for match in pattern.findall(kwargs):
Expand All @@ -117,19 +128,35 @@ def parse_string2dict(kwargs_str, **kwargs):
# else:
# return input_dict

_true_set = {"yes", "true", "t", "y", "1"}
_false_set = {"no", "false", "f", "n", "0"}


def str_to_bool(value, raise_exc=False):
if isinstance(value, str):
value = value.lower()
if value in _true_set:
return True
if value in _false_set:
return False
if raise_exc:
raise ValueError('Expected "%s"' % '", "'.join(_true_set | _false_set))
return None


def none_or_X(value, dtype):
if value is None or not bool(value) or value == 'None':
if value is None or not bool(value) or value == "None":
return None
try:
return dtype(value)
except ValueError:
return None


none_or_int = lambda v: none_or_X(v, int)
none_or_float = lambda v: none_or_X(v, float)
none_or_str = lambda v: none_or_X(v, str)
str_list = lambda v: v.split(',')
str_list = lambda v: v.split(",")


def parse_plot_channels(channels, input_file):
Expand All @@ -139,18 +166,17 @@ def parse_plot_channels(channels, input_file):
# * check is channel exists, even when there is no None
# * use annotation channel ids instead of array indices
if None in channels:
dim_t, channel_num = load_neo(input_file, object='analogsignal',
lazy=True).shape
dim_t, channel_num = load_neo(
input_file, object="analogsignal", lazy=True
).shape
for i, channel in enumerate(channels):
if channel is None or channel >= channel_num:
channels[i] = np.random.randint(0,channel_num)
channels[i] = np.random.randint(0, channel_num)
return channels


def determine_spatial_scale(coords):
coords = np.array(coords)
dists = np.diff(coords[:,0])
dists = np.diff(coords[:, 0])
dists = dists[np.nonzero(dists)]
return np.min(dists)


0 comments on commit 52d6686

Please sign in to comment.