diff --git a/sunkit_image/asda.py b/sunkit_image/asda.py index 32fab622..a7fdd121 100644 --- a/sunkit_image/asda.py +++ b/sunkit_image/asda.py @@ -1,391 +1,391 @@ -""" -This module contains an implementation of the Automated Swirl Detection -Algorithm (ASDA). -""" - -import warnings -from itertools import product - -import numpy as np -from skimage import measure - -from sunkit_image.utils import calculate_gamma, points_in_poly, reform2d, remove_duplicate - -__all__ = [ - "generate_velocity_field", - "calculate_gamma_values", - "get_vortex_edges", - "get_vortex_properties", - "get_vortex_meshgrid", - "get_rotational_velocity", - "get_radial_velocity", - "get_velocity_field", -] - - -def generate_velocity_field(vx, vy, i, j, r=3): - """ - Given a point ``[i, j]``, generate a velocity field which contains a region - with a size of ``(2r+1) x (2r+1)`` centered at ``[i, j]`` from the original - velocity field ``vx`` and ``vy``. - - Parameters - ---------- - vx : `numpy.ndarray` - Velocity field in the x direction. - vy : `numpy.ndarray` - Velocity field in the y direction. - i : `int` - first dimension of the pixel position of a target point. - j : `int` - second dimension of the pixel position of a target point. - r : `int`, optional - Maximum distance of neighbor points from target point. - Default value is 3. - - Returns - ------- - `numpy.ndarray` - The first dimension is a velocity field which contains a - region with a size of ``(2r+1) x (2r+1)`` centered at ``[i, j]`` from - the original velocity field ``vx`` and ``vy``. - the second dimension is similar as the first dimension, but - with the mean velocity field subtracted from the original - velocity field. - """ - if vx.shape != vy.shape: - msg = "Shape of velocity field's vx and vy do not match" - raise ValueError(msg) - if not isinstance(r, int): - msg = "Keyword 'r' must be an integer" - raise TypeError(msg) - vel = np.array( - [[vx[i + im, j + jm], vy[i + im, j + jm]] for im in np.arange(-r, r + 1) for jm in np.arange(-r, r + 1)], - ) - return np.array([vel, vel - vel.mean(axis=0)]) - - -def calculate_gamma_values(vx, vy, factor=1, r=3): - """ - Calculate ``gamma1`` and ``gamma2`` values of velocity field vx and vy. - - Parameters - ---------- - vx : `numpy.ndarray` - Velocity field in the x direction. - vy : `numpy.ndarray` - Velocity field in the y direction. - factor : `int`, optional - Magnify the original data to find sub-grid vortex center and boundary. - Default value is 1. - r : `int`, optional - Maximum distance of neighbor points from target point. - Default value is 3. - - Returns - ------- - `tuple` - A tuple in form of ``(gamma1, gamma2)``, where ``gamma1`` is useful in - finding vortex centers and ``gamma2`` is useful in finding vortex - edges. - """ - - if vx.shape != vy.shape: - msg = "Shape of velocity field's vx and vy do not match" - raise ValueError(msg) - if not isinstance(r, int): - msg = "Keyword 'r' must be an integer" - raise TypeError(msg) - if not isinstance(factor, int): - msg = "Keyword 'factor' must be an integer" - raise TypeError(msg) - - # This part of the code was written in (x, y) order - # but numpy is in (y, x) order so we need to transpose it - dshape = np.shape(vx) - vx = vx.T - vy = vy.T - if factor > 1: - vx = reform2d(vx, factor) - vy = reform2d(vy, factor) - gamma = np.array([np.zeros_like(vx), np.zeros_like(vy)]).T - # pm vectors, see equation (8) in Graftieaux et al. 2001 or Equation (1) in Liu et al. 2019 - pm = np.array( - [[i, j] for i in np.arange(-r, r + 1) for j in np.arange(-r, r + 1)], - dtype=float, - ) - # Mode of vector pm - pnorm = np.linalg.norm(pm, axis=1) - # Number of points in the concerned region - N = (2 * r + 1) ** 2 - - index = np.array( - [[i, j] for i in np.arange(r, dshape[0] - r) for j in np.arange(r, dshape[1] - r)], - ) - index = index.T - vel = generate_velocity_field(vx, vy, index[1], index[0], r) - for d, (i, j) in enumerate( - product(np.arange(r, dshape[0] - r, 1), np.arange(r, dshape[1] - r, 1)), - ): - gamma[i, j, 0], gamma[i, j, 1] = calculate_gamma(pm, vel[..., d], pnorm, N) - # Transpose back vx & vy - vx = vx.T - vy = vy.T - return gamma - - -def get_vortex_edges(gamma, rmin=4, gamma_min=0.89, factor=1): - """ - Find all swirls from ``gamma1``, and ``gamma2``. - - Parameters - ---------- - gamma : `tuple` - A tuple in form of ``(gamma1, gamma2)``, where ``gamma1`` is useful in - finding vortex centers and ``gamma2`` is useful in finding vortex - edges. - rmin : `int`, optional - Minimum radius of swirls, all swirls with radius less than ``rmin`` will be rejected. - Defaults to 4. - gamma_min : `float`, optional - Minimum value of ``gamma1``, all potential swirls with - peak ``gamma1`` values less than ``gamma_min`` will be rejected. - factor : `int`, optional - Magnify the original data to find sub-grid vortex center and boundary. - Default value is 1. - - Returns - ------- - `dict` - The keys and their meanings of the dictionary are: - ``center`` : Center locations of vortices, in the form of ``[x, y]``. - ``edge`` : Edge locations of vortices, in the form of ``[x, y]``. - ``points`` : All points within vortices, in the form of ``[x, y]``. - ``peak`` : Maximum/minimum gamma1 values in vortices. - ``radius`` : Equivalent radius of vortices. - All results are in pixel coordinates. - """ - if not isinstance(factor, int): - msg = "Keyword 'factor' must be an integer" - raise TypeError(msg) - - edge_prop = {"center": (), "edge": (), "points": (), "peak": (), "radius": ()} - cs = np.array(measure.find_contours(gamma[..., 1].T, -2 / np.pi), dtype=object) - cs_pos = np.array(measure.find_contours(gamma[..., 1].T, 2 / np.pi), dtype=object) - if len(cs) == 0: - cs = cs_pos - elif len(cs_pos) != 0: - cs = np.append(cs, cs_pos, 0) - for i in range(np.shape(cs)[0]): - v = np.rint(cs[i].astype(np.float32)) - v = remove_duplicate(v) - # Find all points in the contour - ps = points_in_poly(v) - dust = [gamma[..., 0][int(p[1]), int(p[0])] for p in ps] - # Determine swirl properties - if len(dust) > 1: - # Effective radius - re = np.sqrt(np.array(ps).shape[0] / np.pi) / factor - # Only consider swirls with re >= rmin and maximum gamma1 value greater than gamma_min - if np.max(np.fabs(dust)) >= gamma_min and re >= rmin: - # Extract the index, only first dimension - idx = np.where(np.fabs(dust) == np.max(np.fabs(dust)))[0][0] - edge_prop["center"] += (np.array(ps[idx]) / factor,) - edge_prop["edge"] += (np.array(v) / factor,) - edge_prop["points"] += (np.array(ps) / factor,) - edge_prop["peak"] += (dust[idx],) - edge_prop["radius"] += (re,) - return edge_prop - - -def get_vortex_properties(vx, vy, edge_prop, image=None): - """ - Calculate expanding, rotational speed, equivalent radius and average - intensity of given swirls. - - Parameters - ---------- - vx : `numpy.ndarray` - Velocity field in the x direction. - vy : `numpy.ndarray` - Velocity field in the y direction. - edge_prop : `dict` - The keys and their meanings of the dictionary are: - ``center`` : Center locations of vortices, in the form of ``[x, y]``. - ``edge`` : Edge locations of vortices, in the form of ``[x, y]``. - ``points`` : All points within vortices, in the form of ``[x, y]``. - ``peak`` : Maximum/minimum gamma1 values in vortices. - ``radius`` : Equivalent radius of vortices. - All results are in pixel coordinates. - image : `numpy.ndarray` - Has to have the same shape as ``vx`` observational image, - which will be used to calculate the average observational values of all swirls. - - Returns - ------- - `tuple` - The returned tuple has four components, which are: - - ``ve`` : expanding speed, in the same unit as ``vx`` or ``vy``. - ``vr`` : rotational speed, in the same unit as ``vx`` or ``vy``. - ``vc`` : velocity of the center, in the form of ``[vx, vy]``. - ``ia`` : average of the observational values within the vortices if the parameter image is given. - """ - if vx.shape != vy.shape: - msg = "Shape of velocity field's vx and vy do not match" - raise ValueError(msg) - - ve, vr, vc, ia = (), (), (), () - for i in range(len(edge_prop["center"])): - # Centre and edge of i-th swirl - cen = edge_prop["center"][i] - edg = edge_prop["edge"][i] - # Points of i-th swirl - pnt = np.array(edge_prop["points"][i], dtype=int) - # Calculate velocity of the center - vc += ( - [ - vx[int(round(cen[1])), int(round(cen[0]))], - vy[int(round(cen[1])), int(round(cen[0]))], - ], - ) - # Calculate average the observational values - if image is None: - ia += (None,) - else: - value = sum(image[pos[1], pos[0]] for pos in pnt) - ia += (value / pnt.shape[0],) - ve0, vr0 = [], [] - for j in range(edg.shape[0]): - # Edge position - idx = [edg[j][0], edg[j][1]] - # Eadial vector from swirl center to a point at its edge - pm = [idx[0] - cen[0], idx[1] - cen[1]] - # Tangential vector - tn = [cen[1] - idx[1], idx[0] - cen[0]] - # Velocity vector - v = [vx[int(idx[1]), int(idx[0])], vy[int(idx[1]), int(idx[0])]] - ve0.append(np.dot(v, pm) / np.linalg.norm(pm)) - vr0.append(np.dot(v, tn) / np.linalg.norm(tn)) - ve += (np.nanmean(ve0),) - vr += (np.nanmean(vr0),) - return ve, vr, vc, ia - - -def get_vortex_meshgrid(x_range, y_range): - """ - Returns a meshgrid of the coordinates of the vortex. - - Parameters - ---------- - x_range : `list` - Range of the x coordinates of the meshgrid. - y_range : `list` - Range of the y coordinates of the meshgrid. - - Return - ------ - `tuple` - Contains the meshgrids generated. - """ - xx, yy = np.meshgrid(np.arange(x_range[0], x_range[1]), np.arange(y_range[0], y_range[1])) - return xx, yy - - -def get_rotational_velocity(gamma, rcore, r=0): - """ - Calculate rotation speed at radius of ``r``. - - Parameters - ---------- - gamma : `float`, optional - A replacement for ``vmax`` and only used if both ``gamma`` and ``rcore`` are not `None`. - Defaults to `None`. - rcore : `float`, optional - A replacement for ``rmax`` and only used if both ``gamma`` and ``rcore`` are not `None`. - Defaults to `None`. - r : `float`, optional - Radius which defaults to 0. - - Return - ------ - `float` - Rotating speed at radius of ``r``. - """ - r = r + 1e-10 - return gamma * (1.0 - np.exp(0 - np.square(r) / np.square(rcore))) / (2 * np.pi * r) - - -def get_radial_velocity(gamma, rcore, ratio_vradial, r=0): - """ - Calculate radial (expanding or shrinking) speed at radius of ``r``. - - Parameters - ---------- - gamma : `float`, optional - A replacement for ``vmax`` and only used if both ``gamma`` and ``rcore`` are not `None`. - Defaults to `None`. - rcore : `float`, optional - A replacement for ``rmax`` and only used if both ``gamma`` and ``rcore`` are not `None`. - Defaults to `None`. - ratio_vradial : `float`, optional - Ratio between expanding/shrinking speed and rotating speed. - Defaults to 0. - r : `float`, optional - Radius which defaults to 0. - - Return - ------ - `float` - Radial speed at the radius of ``r``. - """ - r = r + 1e-10 - return get_rotational_velocity(gamma, rcore, r) * ratio_vradial - - -def get_velocity_field(gamma, rcore, ratio_vradial, x_range, y_range, x=None, y=None): - """ - Calculates the velocity field in a meshgrid generated with ``x_range`` and - ``y_range``. - - Parameters - ---------- - gamma : `float`, optional - A replacement for ``vmax`` and only used if both ``gamma`` and ``rcore`` are not `None`. - Defaults to `None`. - rcore : `float`, optional - A replacement for ``rmax`` and only used if both ``gamma`` and ``rcore`` are not `None`. - Defaults to `None`. - ratio_vradial : `float`, optional - Ratio between expanding/shrinking speed and rotating speed. - Defaults to 0. - x_range : `list` - Range of the x coordinates of the meshgrid. - y_range : `list` - range of the y coordinates of the meshgrid. - x, y : `numpy.meshgrid`, optional - If both are given, ``x_range`` and ``y_range`` will be ignored. - Defaults to None``. - - Return - ------ - `tuple` - The generated velocity field ``(vx, vy)``. - """ - if x is None or y is None: - # Check if one of the input parameters is None but the other one is not None - if x != y: - warnings.warn("One of the input parameters is missing, setting both to 'None'", stacklevel=3) - x, y = None, None - # Creating mesh grid - x, y = get_vortex_meshgrid(x_range=x_range, y_range=y_range) - # Calculate radius - r = np.sqrt(np.square(x) + np.square(y)) + 1e-10 - # Calculate velocity vector - vector = [ - 0 - get_rotational_velocity(gamma, rcore, r) * y + get_radial_velocity(gamma, rcore, ratio_vradial, r) * x, - get_rotational_velocity(gamma, rcore, r) * x + get_radial_velocity(gamma, rcore, ratio_vradial, r) * y, - ] - vx = vector[0] / r - vy = vector[1] / r - return vx, vy +""" +This module contains an implementation of the Automated Swirl Detection +Algorithm (ASDA). +""" + +import warnings +from itertools import product + +import numpy as np +from skimage import measure + +from sunkit_image.utils import calculate_gamma, points_in_poly, reform2d, remove_duplicate + +__all__ = [ + "generate_velocity_field", + "calculate_gamma_values", + "get_vortex_edges", + "get_vortex_properties", + "get_vortex_meshgrid", + "get_rotational_velocity", + "get_radial_velocity", + "get_velocity_field", +] + + +def generate_velocity_field(vx, vy, i, j, r=3): + """ + Given a point ``[i, j]``, generate a velocity field which contains a region + with a size of ``(2r+1) x (2r+1)`` centered at ``[i, j]`` from the original + velocity field ``vx`` and ``vy``. + + Parameters + ---------- + vx : `numpy.ndarray` + Velocity field in the x direction. + vy : `numpy.ndarray` + Velocity field in the y direction. + i : `int` + first dimension of the pixel position of a target point. + j : `int` + second dimension of the pixel position of a target point. + r : `int`, optional + Maximum distance of neighbor points from target point. + Default value is 3. + + Returns + ------- + `numpy.ndarray` + The first dimension is a velocity field which contains a + region with a size of ``(2r+1) x (2r+1)`` centered at ``[i, j]`` from + the original velocity field ``vx`` and ``vy``. + the second dimension is similar as the first dimension, but + with the mean velocity field subtracted from the original + velocity field. + """ + if vx.shape != vy.shape: + msg = "Shape of velocity field's vx and vy do not match" + raise ValueError(msg) + if not isinstance(r, int): + msg = "Keyword 'r' must be an integer" + raise TypeError(msg) + vel = np.array( + [[vx[i + im, j + jm], vy[i + im, j + jm]] for im in np.arange(-r, r + 1) for jm in np.arange(-r, r + 1)], + ) + return np.array([vel, vel - vel.mean(axis=0)]) + + +def calculate_gamma_values(vx, vy, factor=1, r=3): + """ + Calculate ``gamma1`` and ``gamma2`` values of velocity field vx and vy. + + Parameters + ---------- + vx : `numpy.ndarray` + Velocity field in the x direction. + vy : `numpy.ndarray` + Velocity field in the y direction. + factor : `int`, optional + Magnify the original data to find sub-grid vortex center and boundary. + Default value is 1. + r : `int`, optional + Maximum distance of neighbor points from target point. + Default value is 3. + + Returns + ------- + `tuple` + A tuple in form of ``(gamma1, gamma2)``, where ``gamma1`` is useful in + finding vortex centers and ``gamma2`` is useful in finding vortex + edges. + """ + + if vx.shape != vy.shape: + msg = "Shape of velocity field's vx and vy do not match" + raise ValueError(msg) + if not isinstance(r, int): + msg = "Keyword 'r' must be an integer" + raise TypeError(msg) + if not isinstance(factor, int): + msg = "Keyword 'factor' must be an integer" + raise TypeError(msg) + + # This part of the code was written in (x, y) order + # but numpy is in (y, x) order so we need to transpose it + dshape = np.shape(vx) + vx = vx.T + vy = vy.T + if factor > 1: + vx = reform2d(vx, factor) + vy = reform2d(vy, factor) + gamma = np.array([np.zeros_like(vx), np.zeros_like(vy)]).T + # pm vectors, see equation (8) in Graftieaux et al. 2001 or Equation (1) in Liu et al. 2019 + pm = np.array( + [[i, j] for i in np.arange(-r, r + 1) for j in np.arange(-r, r + 1)], + dtype=float, + ) + # Mode of vector pm + pnorm = np.linalg.norm(pm, axis=1) + # Number of points in the concerned region + N = (2 * r + 1) ** 2 + + index = np.array( + [[i, j] for i in np.arange(r, dshape[0] - r) for j in np.arange(r, dshape[1] - r)], + ) + index = index.T + vel = generate_velocity_field(vx, vy, index[1], index[0], r) + for d, (i, j) in enumerate( + product(np.arange(r, dshape[0] - r, 1), np.arange(r, dshape[1] - r, 1)), + ): + gamma[i, j, 0], gamma[i, j, 1] = calculate_gamma(pm, vel[..., d], pnorm, N) + # Transpose back vx & vy + vx = vx.T + vy = vy.T + return gamma + + +def get_vortex_edges(gamma, rmin=4, gamma_min=0.89, factor=1): + """ + Find all swirls from ``gamma1``, and ``gamma2``. + + Parameters + ---------- + gamma : `tuple` + A tuple in form of ``(gamma1, gamma2)``, where ``gamma1`` is useful in + finding vortex centers and ``gamma2`` is useful in finding vortex + edges. + rmin : `int`, optional + Minimum radius of swirls, all swirls with radius less than ``rmin`` will be rejected. + Defaults to 4. + gamma_min : `float`, optional + Minimum value of ``gamma1``, all potential swirls with + peak ``gamma1`` values less than ``gamma_min`` will be rejected. + factor : `int`, optional + Magnify the original data to find sub-grid vortex center and boundary. + Default value is 1. + + Returns + ------- + `dict` + The keys and their meanings of the dictionary are: + ``center`` : Center locations of vortices, in the form of ``[x, y]``. + ``edge`` : Edge locations of vortices, in the form of ``[x, y]``. + ``points`` : All points within vortices, in the form of ``[x, y]``. + ``peak`` : Maximum/minimum gamma1 values in vortices. + ``radius`` : Equivalent radius of vortices. + All results are in pixel coordinates. + """ + if not isinstance(factor, int): + msg = "Keyword 'factor' must be an integer" + raise TypeError(msg) + + edge_prop = {"center": (), "edge": (), "points": (), "peak": (), "radius": ()} + cs = np.array(measure.find_contours(gamma[..., 1].T, -2 / np.pi), dtype=object) + cs_pos = np.array(measure.find_contours(gamma[..., 1].T, 2 / np.pi), dtype=object) + if len(cs) == 0: + cs = cs_pos + elif len(cs_pos) != 0: + cs = np.append(cs, cs_pos, 0) + for i in range(np.shape(cs)[0]): + v = np.rint(cs[i].astype(np.float32)) + v = remove_duplicate(v) + # Find all points in the contour + ps = points_in_poly(v) + dust = [gamma[..., 0][int(p[1]), int(p[0])] for p in ps] + # Determine swirl properties + if len(dust) > 1: + # Effective radius + re = np.sqrt(np.array(ps).shape[0] / np.pi) / factor + # Only consider swirls with re >= rmin and maximum gamma1 value greater than gamma_min + if np.max(np.fabs(dust)) >= gamma_min and re >= rmin: + # Extract the index, only first dimension + idx = np.where(np.fabs(dust) == np.max(np.fabs(dust)))[0][0] + edge_prop["center"] += (np.array(ps[idx]) / factor,) + edge_prop["edge"] += (np.array(v) / factor,) + edge_prop["points"] += (np.array(ps) / factor,) + edge_prop["peak"] += (dust[idx],) + edge_prop["radius"] += (re,) + return edge_prop + + +def get_vortex_properties(vx, vy, edge_prop, image=None): + """ + Calculate expanding, rotational speed, equivalent radius and average + intensity of given swirls. + + Parameters + ---------- + vx : `numpy.ndarray` + Velocity field in the x direction. + vy : `numpy.ndarray` + Velocity field in the y direction. + edge_prop : `dict` + The keys and their meanings of the dictionary are: + ``center`` : Center locations of vortices, in the form of ``[x, y]``. + ``edge`` : Edge locations of vortices, in the form of ``[x, y]``. + ``points`` : All points within vortices, in the form of ``[x, y]``. + ``peak`` : Maximum/minimum gamma1 values in vortices. + ``radius`` : Equivalent radius of vortices. + All results are in pixel coordinates. + image : `numpy.ndarray` + Has to have the same shape as ``vx`` observational image, + which will be used to calculate the average observational values of all swirls. + + Returns + ------- + `tuple` + The returned tuple has four components, which are: + + ``ve`` : expanding speed, in the same unit as ``vx`` or ``vy``. + ``vr`` : rotational speed, in the same unit as ``vx`` or ``vy``. + ``vc`` : velocity of the center, in the form of ``[vx, vy]``. + ``ia`` : average of the observational values within the vortices if the parameter image is given. + """ + if vx.shape != vy.shape: + msg = "Shape of velocity field's vx and vy do not match" + raise ValueError(msg) + + ve, vr, vc, ia = (), (), (), () + for i in range(len(edge_prop["center"])): + # Centre and edge of i-th swirl + cen = edge_prop["center"][i] + edg = edge_prop["edge"][i] + # Points of i-th swirl + pnt = np.array(edge_prop["points"][i], dtype=int) + # Calculate velocity of the center + vc += ( + [ + vx[int(round(cen[1])), int(round(cen[0]))], + vy[int(round(cen[1])), int(round(cen[0]))], + ], + ) + # Calculate average the observational values + if image is None: + ia += (None,) + else: + value = sum(image[pos[1], pos[0]] for pos in pnt) + ia += (value / pnt.shape[0],) + ve0, vr0 = [], [] + for j in range(edg.shape[0]): + # Edge position + idx = [edg[j][0], edg[j][1]] + # Eadial vector from swirl center to a point at its edge + pm = [idx[0] - cen[0], idx[1] - cen[1]] + # Tangential vector + tn = [cen[1] - idx[1], idx[0] - cen[0]] + # Velocity vector + v = [vx[int(idx[1]), int(idx[0])], vy[int(idx[1]), int(idx[0])]] + ve0.append(np.dot(v, pm) / np.linalg.norm(pm)) + vr0.append(np.dot(v, tn) / np.linalg.norm(tn)) + ve += (np.nanmean(ve0),) + vr += (np.nanmean(vr0),) + return ve, vr, vc, ia + + +def get_vortex_meshgrid(x_range, y_range): + """ + Returns a meshgrid of the coordinates of the vortex. + + Parameters + ---------- + x_range : `list` + Range of the x coordinates of the meshgrid. + y_range : `list` + Range of the y coordinates of the meshgrid. + + Return + ------ + `tuple` + Contains the meshgrids generated. + """ + xx, yy = np.meshgrid(np.arange(x_range[0], x_range[1]), np.arange(y_range[0], y_range[1])) + return xx, yy + + +def get_rotational_velocity(gamma, rcore, r=0): + """ + Calculate rotation speed at radius of ``r``. + + Parameters + ---------- + gamma : `float`, optional + A replacement for ``vmax`` and only used if both ``gamma`` and ``rcore`` are not `None`. + Defaults to `None`. + rcore : `float`, optional + A replacement for ``rmax`` and only used if both ``gamma`` and ``rcore`` are not `None`. + Defaults to `None`. + r : `float`, optional + Radius which defaults to 0. + + Return + ------ + `float` + Rotating speed at radius of ``r``. + """ + r = r + 1e-10 + return gamma * (1.0 - np.exp(0 - np.square(r) / np.square(rcore))) / (2 * np.pi * r) + + +def get_radial_velocity(gamma, rcore, ratio_vradial, r=0): + """ + Calculate radial (expanding or shrinking) speed at radius of ``r``. + + Parameters + ---------- + gamma : `float`, optional + A replacement for ``vmax`` and only used if both ``gamma`` and ``rcore`` are not `None`. + Defaults to `None`. + rcore : `float`, optional + A replacement for ``rmax`` and only used if both ``gamma`` and ``rcore`` are not `None`. + Defaults to `None`. + ratio_vradial : `float`, optional + Ratio between expanding/shrinking speed and rotating speed. + Defaults to 0. + r : `float`, optional + Radius which defaults to 0. + + Return + ------ + `float` + Radial speed at the radius of ``r``. + """ + r = r + 1e-10 + return get_rotational_velocity(gamma, rcore, r) * ratio_vradial + + +def get_velocity_field(gamma, rcore, ratio_vradial, x_range, y_range, x=None, y=None): + """ + Calculates the velocity field in a meshgrid generated with ``x_range`` and + ``y_range``. + + Parameters + ---------- + gamma : `float`, optional + A replacement for ``vmax`` and only used if both ``gamma`` and ``rcore`` are not `None`. + Defaults to `None`. + rcore : `float`, optional + A replacement for ``rmax`` and only used if both ``gamma`` and ``rcore`` are not `None`. + Defaults to `None`. + ratio_vradial : `float`, optional + Ratio between expanding/shrinking speed and rotating speed. + Defaults to 0. + x_range : `list` + Range of the x coordinates of the meshgrid. + y_range : `list` + range of the y coordinates of the meshgrid. + x, y : `numpy.meshgrid`, optional + If both are given, ``x_range`` and ``y_range`` will be ignored. + Defaults to None``. + + Return + ------ + `tuple` + The generated velocity field ``(vx, vy)``. + """ + if x is None or y is None: + # Check if one of the input parameters is None but the other one is not None + if x != y: + warnings.warn("One of the input parameters is missing, setting both to 'None'", stacklevel=3) + x, y = None, None + # Creating mesh grid + x, y = get_vortex_meshgrid(x_range=x_range, y_range=y_range) + # Calculate radius + r = np.sqrt(np.square(x) + np.square(y)) + 1e-10 + # Calculate velocity vector + vector = [ + 0 - get_rotational_velocity(gamma, rcore, r) * y + get_radial_velocity(gamma, rcore, ratio_vradial, r) * x, + get_rotational_velocity(gamma, rcore, r) * x + get_radial_velocity(gamma, rcore, ratio_vradial, r) * y, + ] + vx = vector[0] / r + vy = vector[1] / r + return vx, vy diff --git a/sunkit_image/granule.py b/sunkit_image/granule.py index 76a5a5e0..af002b0a 100644 --- a/sunkit_image/granule.py +++ b/sunkit_image/granule.py @@ -1,282 +1,282 @@ -""" -This module contains functions that will segment images for granule detection. -""" - -import logging - -import matplotlib as mpl -import numpy as np -import scipy -import skimage - -import sunpy -import sunpy.map - -__all__ = ["segment", "segments_overlap_fraction"] - - -def segment(smap, *, skimage_method="li", mark_dim_centers=False, bp_min_flux=None): - """ - Segment an optical image of the solar photosphere into tri-value maps with: - - * 0 as intergranule - * 1 as granule - * 2 as brightpoint - - If mark_dim_centers is set to True, an additional label, 3, will be assigned to - dim granule centers. - - Parameters - ---------- - smap : `~sunpy.map.GenericMap` - `~sunpy.map.GenericMap` containing data to segment. Must have square pixels. - skimage_method : { "li", "otsu", "isodata", "mean", "minimum", "yen", "triangle" }, optional - scikit-image thresholding method, defaults to "li". - Depending on input data, one or more of these methods may be - significantly better or worse than the others. Typically, 'li', 'otsu', - 'mean', and 'isodata' are good choices, 'yen' and 'triangle' over- - identify intergranule material, and 'minimum' over identifies granules. - mark_dim_centers : `bool`, optional - Whether to mark dim granule centers as a separate category for future exploration. - bp_min_flux : `float`, optional - Minimum flux per pixel for a region to be considered a brightpoint. - Default is `None` which will use data mean + 0.5 * sigma. - - Returns - ------- - segmented_map : `~sunpy.map.GenericMap` - `~sunpy.map.GenericMap` containing a segmented image (with the original header). - """ - if not isinstance(smap, sunpy.map.mapbase.GenericMap): - msg = "Input must be an instance of a sunpy.map.GenericMap" - raise TypeError(msg) - if smap.scale[0].value == smap.scale[1].value: - resolution = smap.scale[0].value - else: - msg = "Currently only maps with square pixels are supported." - raise ValueError(msg) - # Obtain local histogram equalization of map. - # min-max normalization to [0, 1] - map_norm = (smap.data - np.nanmin(smap.data)) / (np.nanmax(smap.data) - np.nanmin(smap.data)) - map_he = skimage.filters.rank.equalize( - skimage.util.img_as_ubyte(map_norm), - footprint=skimage.morphology.disk(radius=100), - ) - # Apply initial skimage threshold. - median_filtered = scipy.ndimage.median_filter(map_he, size=3) - threshold = _get_threshold(median_filtered, skimage_method) - segmented_image = np.uint8(median_filtered > threshold) - # Fix the extra intergranule material bits in the middle of granules. - seg_im_fixed = _trim_intergranules(segmented_image, mark=mark_dim_centers) - # Mark brightpoint and get final granule and brightpoint count. - seg_im_markbp, brightpoint_count, granule_count = _mark_brightpoint( - seg_im_fixed, - smap.data, - map_he, - resolution, - bp_min_flux, - ) - logging.info(f"Segmentation has identified {granule_count} granules and {brightpoint_count} brightpoint") # NOQA: G004 - # Create output map using input wcs and adding colormap such that 0 (intergranules) = black, 1 (granule) = white, 2 (brightpoints) = yellow, 3 (dim_centers) = blue. - segmented_map = sunpy.map.Map(seg_im_markbp, smap.wcs) - cmap = mpl.colors.ListedColormap(["black", "white", "#ffc406", "blue"]) - norm = mpl.colors.BoundaryNorm(boundaries=[-0.5, 0.5, 1.5, 2.5, 3.5], ncolors=cmap.N) - segmented_map.plot_settings["cmap"] = cmap - segmented_map.plot_settings["norm"] = norm - return segmented_map - - -def _get_threshold(data, method): - """ - Get the threshold value using given skimage segmentation type. - - Parameters - ---------- - data : `numpy.ndarray` - Data to threshold. - method : { "li", "otsu", "isodata", "mean", "minimum", "yen", "triangle" } - scikit-image thresholding method. - - Returns - ------- - threshold : `float` - Threshold value. - """ - if not isinstance(data, np.ndarray): - msg = "Input data must be an instance of a np.ndarray" - raise TypeError(msg) - if len(data.flatten()) > 500**2: - logging.info( - "Input image is large (> 500**2), so threshold computation will be based on a random 500x500 sample of pixels", - ) - rng = np.random.default_rng() - # Computing threshold based on random sample works well and saves significant computational time - data = rng.choice( - data.flatten(), - (500, 500), - ) - method = method.lower() - method_funcs = { - "li": skimage.filters.threshold_li, - "otsu": skimage.filters.threshold_otsu, - "yen": skimage.filters.threshold_yen, - "mean": skimage.filters.threshold_mean, - "minimum": skimage.filters.threshold_minimum, - "triangle": skimage.filters.threshold_triangle, - "isodata": skimage.filters.threshold_isodata, - } - if method not in method_funcs: - raise ValueError("Method must be one of: " + ", ".join(list(method_funcs.keys()))) - return method_funcs[method](data) - - -def _trim_intergranules(segmented_image, *, mark=False): - """ - Remove the erroneous identification of intergranule material in the middle - of granules that the pure threshold segmentation produces. - - Parameters - ---------- - segmented_image : `numpy.ndarray` - The segmented image containing incorrect extra intergranules. - mark : `bool` - If `False` (the default), remove erroneous intergranules. - If `True`, mark them as 3 instead (for later examination). - - Returns - ------- - segmented_image_fixed : `numpy.ndarray` - The segmented image without incorrect extra intergranules. - """ - if len(np.unique(segmented_image)) > 2: - msg = "segmented_image must only have values of 1 and 0." - raise ValueError(msg) - # Float conversion for correct region labeling. - segmented_image_fixed = np.copy(segmented_image).astype(float) - # Add padding of intergranule around edges. - # Avoids the case where all edge pixels are granule, - # which will result in all dim centers as intergranules. - pad = int(np.shape(segmented_image)[0] / 200) - segmented_image_fixed[:, 0:pad] = 0 - segmented_image_fixed[0:pad, :] = 0 - segmented_image_fixed[:, -pad:] = 0 - segmented_image_fixed[-pad:, :] = 0 - labeled_seg = skimage.measure.label(segmented_image_fixed + 1, connectivity=2) - values = np.unique(labeled_seg) - # Find value of the large continuous 0-valued region. - size = 0 - for value in values: - if len(labeled_seg[labeled_seg == value]) > size and sum(segmented_image[labeled_seg == value] == 0): - real_IG_value = value - size = len(labeled_seg[labeled_seg == value]) - # Set all other 0 regions to mark value (3). - for value in values: - if np.sum(segmented_image[labeled_seg == value]) == 0 and value != real_IG_value: - segmented_image_fixed[labeled_seg == value] = 3 if mark else 1 - return segmented_image_fixed - - -def _mark_brightpoint(segmented_image, data, he_data, resolution, bp_min_flux=None): - """ - Mark brightpoints separately from granules - give them a value of 2. - - Parameters - ---------- - segmented_image : `numpy.ndarray` - The segmented image containing incorrect middles. - data : `numpy array` - The original image. - he_data : `numpy array` - Original image with local histogram equalization applied. - resolution : `float` - Spatial resolution (arcsec/pixel) of the data. - bp_min_flux : `float`, optional - Minimum flux per pixel for a region to be considered a brightpoint. - Default is `None` which will use data mean + 0.5 * sigma. - - Returns - ------- - segmented_image_fixed : `numpy.ndrray` - The segmented image with brightpoints marked as 2. - brightpoint_count: `int` - The number of brightpoints identified in the image. - granule_count: `int` - The number of granules identified, after re-classification of brightpoint. - """ - # General size limits - bp_size_limit = ( - 0.1 # Approximate max size of a photosphere bright point in square arcsec (see doi 10.3847/1538-4357/aab150) - ) - bp_pix_upper_limit = (bp_size_limit / resolution) ** 2 # Max area in pixels - bp_pix_lower_limit = 4 # Very small bright regions are likely artifacts - # General flux limit determined by visual inspection (set using equalized map) - if bp_min_flux is None: - stand_devs = 1.25 # General flux limit determined by visual inspection (set using equalized map) - bp_brightness_limit = np.nanmean(he_data) + stand_devs * np.nanstd(he_data) - else: - bp_brightness_limit = bp_min_flux - if len(np.unique(segmented_image)) > 3: - msg = "segmented_image must have only values of 1, 0 and 3 (if dim centers marked)" - raise ValueError(msg) - # Obtain gradient map and set threshold for gradient on BP edges - grad = np.abs(np.gradient(data)[0] + np.gradient(data)[1]) - bp_min_grad = np.quantile(grad, 0.95) - # Label all regions of flux greater than brightness limit (candidate regions) - bright_dim_seg = np.zeros_like(data) - bright_dim_seg[he_data > bp_brightness_limit] = 1 - labeled_bright_dim_seg = skimage.measure.label(bright_dim_seg + 1, connectivity=2) - values = np.unique(labeled_bright_dim_seg) - # From candidate regions, select those within pixel limit and gradient limit - segmented_image_fixed = np.copy(segmented_image.astype(float)) # Make type float to enable adding float values - bp_count = 0 - for value in values: - if (bright_dim_seg[labeled_bright_dim_seg == value])[0] == 1: # Check region is not the non-bp region - # check that region is within pixel limits. - region_size = len(labeled_bright_dim_seg[labeled_bright_dim_seg == value]) - if region_size < bp_pix_upper_limit and region_size > bp_pix_lower_limit: - # check that region has high average gradient (maybe try max gradient?) - region_mean_grad = np.mean(grad[labeled_bright_dim_seg == value]) - if region_mean_grad > bp_min_grad: - segmented_image_fixed[labeled_bright_dim_seg == value] = 2 - bp_count += 1 - gran_count = len(values) - 1 - bp_count # Subtract 1 for IG region. - return segmented_image_fixed, bp_count, gran_count - - -def segments_overlap_fraction(segment1, segment2): - """ - Compute the fraction of overlap between two segmented - `~sunpy.map.GenericMap`. - - Designed for comparing output Map from `segment` with other segmentation methods. - - Parameters - ---------- - segment1: `~sunpy.map.GenericMap` - Main `~sunpy.map.GenericMap` to compare against. Must have 0 = intergranule, 1 = granule. - segment2 :`~sunpy.map.GenericMap` - Comparison `~sunpy.map.GenericMap`. Must have 0 = intergranule, 1 = granule. - As an example, this could come from a simple segment using sklearn.cluster.KMeans - - Returns - ------- - confidence : `float` - The numeric confidence metric: 0 = no agreement and 1 = complete agreement. - """ - segment1 = np.array(segment1.data) - segment2 = np.array(segment2.data) - total_granules = np.count_nonzero(segment1 == 1) - total_intergranules = np.count_nonzero(segment1 == 0) - if total_granules == 0: - msg = "No granules in `segment1`. It is possible the clustering failed." - raise ValueError(msg) - if total_intergranules == 0: - msg = "No intergranules in `segment1`. It is possible the clustering failed." - raise ValueError(msg) - granule_agreement_count = 0 - intergranule_agreement_count = 0 - granule_agreement_count = ((segment1 == 1) * (segment2 == 1)).sum() - intergranule_agreement_count = ((segment1 == 0) * (segment2 == 0)).sum() - percentage_agreement_granules = granule_agreement_count / total_granules - percentage_agreement_intergranules = intergranule_agreement_count / total_intergranules - return np.mean([percentage_agreement_granules, percentage_agreement_intergranules]) +""" +This module contains functions that will segment images for granule detection. +""" + +import logging + +import matplotlib as mpl +import numpy as np +import scipy +import skimage + +import sunpy +import sunpy.map + +__all__ = ["segment", "segments_overlap_fraction"] + + +def segment(smap, *, skimage_method="li", mark_dim_centers=False, bp_min_flux=None): + """ + Segment an optical image of the solar photosphere into tri-value maps with: + + * 0 as intergranule + * 1 as granule + * 2 as brightpoint + + If mark_dim_centers is set to True, an additional label, 3, will be assigned to + dim granule centers. + + Parameters + ---------- + smap : `~sunpy.map.GenericMap` + `~sunpy.map.GenericMap` containing data to segment. Must have square pixels. + skimage_method : { "li", "otsu", "isodata", "mean", "minimum", "yen", "triangle" }, optional + scikit-image thresholding method, defaults to "li". + Depending on input data, one or more of these methods may be + significantly better or worse than the others. Typically, 'li', 'otsu', + 'mean', and 'isodata' are good choices, 'yen' and 'triangle' over- + identify intergranule material, and 'minimum' over identifies granules. + mark_dim_centers : `bool`, optional + Whether to mark dim granule centers as a separate category for future exploration. + bp_min_flux : `float`, optional + Minimum flux per pixel for a region to be considered a brightpoint. + Default is `None` which will use data mean + 0.5 * sigma. + + Returns + ------- + segmented_map : `~sunpy.map.GenericMap` + `~sunpy.map.GenericMap` containing a segmented image (with the original header). + """ + if not isinstance(smap, sunpy.map.mapbase.GenericMap): + msg = "Input must be an instance of a sunpy.map.GenericMap" + raise TypeError(msg) + if smap.scale[0].value == smap.scale[1].value: + resolution = smap.scale[0].value + else: + msg = "Currently only maps with square pixels are supported." + raise ValueError(msg) + # Obtain local histogram equalization of map. + # min-max normalization to [0, 1] + map_norm = (smap.data - np.nanmin(smap.data)) / (np.nanmax(smap.data) - np.nanmin(smap.data)) + map_he = skimage.filters.rank.equalize( + skimage.util.img_as_ubyte(map_norm), + footprint=skimage.morphology.disk(radius=100), + ) + # Apply initial skimage threshold. + median_filtered = scipy.ndimage.median_filter(map_he, size=3) + threshold = _get_threshold(median_filtered, skimage_method) + segmented_image = np.uint8(median_filtered > threshold) + # Fix the extra intergranule material bits in the middle of granules. + seg_im_fixed = _trim_intergranules(segmented_image, mark=mark_dim_centers) + # Mark brightpoint and get final granule and brightpoint count. + seg_im_markbp, brightpoint_count, granule_count = _mark_brightpoint( + seg_im_fixed, + smap.data, + map_he, + resolution, + bp_min_flux, + ) + logging.info(f"Segmentation has identified {granule_count} granules and {brightpoint_count} brightpoint") # NOQA: G004 + # Create output map using input wcs and adding colormap such that 0 (intergranules) = black, 1 (granule) = white, 2 (brightpoints) = yellow, 3 (dim_centers) = blue. + segmented_map = sunpy.map.Map(seg_im_markbp, smap.wcs) + cmap = mpl.colors.ListedColormap(["black", "white", "#ffc406", "blue"]) + norm = mpl.colors.BoundaryNorm(boundaries=[-0.5, 0.5, 1.5, 2.5, 3.5], ncolors=cmap.N) + segmented_map.plot_settings["cmap"] = cmap + segmented_map.plot_settings["norm"] = norm + return segmented_map + + +def _get_threshold(data, method): + """ + Get the threshold value using given skimage segmentation type. + + Parameters + ---------- + data : `numpy.ndarray` + Data to threshold. + method : { "li", "otsu", "isodata", "mean", "minimum", "yen", "triangle" } + scikit-image thresholding method. + + Returns + ------- + threshold : `float` + Threshold value. + """ + if not isinstance(data, np.ndarray): + msg = "Input data must be an instance of a np.ndarray" + raise TypeError(msg) + if len(data.flatten()) > 500**2: + logging.info( + "Input image is large (> 500**2), so threshold computation will be based on a random 500x500 sample of pixels", + ) + rng = np.random.default_rng() + # Computing threshold based on random sample works well and saves significant computational time + data = rng.choice( + data.flatten(), + (500, 500), + ) + method = method.lower() + method_funcs = { + "li": skimage.filters.threshold_li, + "otsu": skimage.filters.threshold_otsu, + "yen": skimage.filters.threshold_yen, + "mean": skimage.filters.threshold_mean, + "minimum": skimage.filters.threshold_minimum, + "triangle": skimage.filters.threshold_triangle, + "isodata": skimage.filters.threshold_isodata, + } + if method not in method_funcs: + raise ValueError("Method must be one of: " + ", ".join(list(method_funcs.keys()))) + return method_funcs[method](data) + + +def _trim_intergranules(segmented_image, *, mark=False): + """ + Remove the erroneous identification of intergranule material in the middle + of granules that the pure threshold segmentation produces. + + Parameters + ---------- + segmented_image : `numpy.ndarray` + The segmented image containing incorrect extra intergranules. + mark : `bool` + If `False` (the default), remove erroneous intergranules. + If `True`, mark them as 3 instead (for later examination). + + Returns + ------- + segmented_image_fixed : `numpy.ndarray` + The segmented image without incorrect extra intergranules. + """ + if len(np.unique(segmented_image)) > 2: + msg = "segmented_image must only have values of 1 and 0." + raise ValueError(msg) + # Float conversion for correct region labeling. + segmented_image_fixed = np.copy(segmented_image).astype(float) + # Add padding of intergranule around edges. + # Avoids the case where all edge pixels are granule, + # which will result in all dim centers as intergranules. + pad = int(np.shape(segmented_image)[0] / 200) + segmented_image_fixed[:, 0:pad] = 0 + segmented_image_fixed[0:pad, :] = 0 + segmented_image_fixed[:, -pad:] = 0 + segmented_image_fixed[-pad:, :] = 0 + labeled_seg = skimage.measure.label(segmented_image_fixed + 1, connectivity=2) + values = np.unique(labeled_seg) + # Find value of the large continuous 0-valued region. + size = 0 + for value in values: + if len(labeled_seg[labeled_seg == value]) > size and sum(segmented_image[labeled_seg == value] == 0): + real_IG_value = value + size = len(labeled_seg[labeled_seg == value]) + # Set all other 0 regions to mark value (3). + for value in values: + if np.sum(segmented_image[labeled_seg == value]) == 0 and value != real_IG_value: + segmented_image_fixed[labeled_seg == value] = 3 if mark else 1 + return segmented_image_fixed + + +def _mark_brightpoint(segmented_image, data, he_data, resolution, bp_min_flux=None): + """ + Mark brightpoints separately from granules - give them a value of 2. + + Parameters + ---------- + segmented_image : `numpy.ndarray` + The segmented image containing incorrect middles. + data : `numpy array` + The original image. + he_data : `numpy array` + Original image with local histogram equalization applied. + resolution : `float` + Spatial resolution (arcsec/pixel) of the data. + bp_min_flux : `float`, optional + Minimum flux per pixel for a region to be considered a brightpoint. + Default is `None` which will use data mean + 0.5 * sigma. + + Returns + ------- + segmented_image_fixed : `numpy.ndrray` + The segmented image with brightpoints marked as 2. + brightpoint_count: `int` + The number of brightpoints identified in the image. + granule_count: `int` + The number of granules identified, after re-classification of brightpoint. + """ + # General size limits + bp_size_limit = ( + 0.1 # Approximate max size of a photosphere bright point in square arcsec (see doi 10.3847/1538-4357/aab150) + ) + bp_pix_upper_limit = (bp_size_limit / resolution) ** 2 # Max area in pixels + bp_pix_lower_limit = 4 # Very small bright regions are likely artifacts + # General flux limit determined by visual inspection (set using equalized map) + if bp_min_flux is None: + stand_devs = 1.25 # General flux limit determined by visual inspection (set using equalized map) + bp_brightness_limit = np.nanmean(he_data) + stand_devs * np.nanstd(he_data) + else: + bp_brightness_limit = bp_min_flux + if len(np.unique(segmented_image)) > 3: + msg = "segmented_image must have only values of 1, 0 and 3 (if dim centers marked)" + raise ValueError(msg) + # Obtain gradient map and set threshold for gradient on BP edges + grad = np.abs(np.gradient(data)[0] + np.gradient(data)[1]) + bp_min_grad = np.quantile(grad, 0.95) + # Label all regions of flux greater than brightness limit (candidate regions) + bright_dim_seg = np.zeros_like(data) + bright_dim_seg[he_data > bp_brightness_limit] = 1 + labeled_bright_dim_seg = skimage.measure.label(bright_dim_seg + 1, connectivity=2) + values = np.unique(labeled_bright_dim_seg) + # From candidate regions, select those within pixel limit and gradient limit + segmented_image_fixed = np.copy(segmented_image.astype(float)) # Make type float to enable adding float values + bp_count = 0 + for value in values: + if (bright_dim_seg[labeled_bright_dim_seg == value])[0] == 1: # Check region is not the non-bp region + # check that region is within pixel limits. + region_size = len(labeled_bright_dim_seg[labeled_bright_dim_seg == value]) + if region_size < bp_pix_upper_limit and region_size > bp_pix_lower_limit: + # check that region has high average gradient (maybe try max gradient?) + region_mean_grad = np.mean(grad[labeled_bright_dim_seg == value]) + if region_mean_grad > bp_min_grad: + segmented_image_fixed[labeled_bright_dim_seg == value] = 2 + bp_count += 1 + gran_count = len(values) - 1 - bp_count # Subtract 1 for IG region. + return segmented_image_fixed, bp_count, gran_count + + +def segments_overlap_fraction(segment1, segment2): + """ + Compute the fraction of overlap between two segmented + `~sunpy.map.GenericMap`. + + Designed for comparing output Map from `segment` with other segmentation methods. + + Parameters + ---------- + segment1: `~sunpy.map.GenericMap` + Main `~sunpy.map.GenericMap` to compare against. Must have 0 = intergranule, 1 = granule. + segment2 :`~sunpy.map.GenericMap` + Comparison `~sunpy.map.GenericMap`. Must have 0 = intergranule, 1 = granule. + As an example, this could come from a simple segment using sklearn.cluster.KMeans + + Returns + ------- + confidence : `float` + The numeric confidence metric: 0 = no agreement and 1 = complete agreement. + """ + segment1 = np.array(segment1.data) + segment2 = np.array(segment2.data) + total_granules = np.count_nonzero(segment1 == 1) + total_intergranules = np.count_nonzero(segment1 == 0) + if total_granules == 0: + msg = "No granules in `segment1`. It is possible the clustering failed." + raise ValueError(msg) + if total_intergranules == 0: + msg = "No intergranules in `segment1`. It is possible the clustering failed." + raise ValueError(msg) + granule_agreement_count = 0 + intergranule_agreement_count = 0 + granule_agreement_count = ((segment1 == 1) * (segment2 == 1)).sum() + intergranule_agreement_count = ((segment1 == 0) * (segment2 == 0)).sum() + percentage_agreement_granules = granule_agreement_count / total_granules + percentage_agreement_intergranules = intergranule_agreement_count / total_intergranules + return np.mean([percentage_agreement_granules, percentage_agreement_intergranules]) diff --git a/sunkit_image/tests/test_granule.py b/sunkit_image/tests/test_granule.py index 1d4379a1..1f8c17cb 100644 --- a/sunkit_image/tests/test_granule.py +++ b/sunkit_image/tests/test_granule.py @@ -1,124 +1,124 @@ -import numpy as np -import pytest - -import sunpy -from sunpy.map import all_pixel_indices_from_map - -from sunkit_image.granule import ( - _get_threshold, - _mark_brightpoint, - _trim_intergranules, - segment, - segments_overlap_fraction, -) - - -def test_segment(granule_map): - segmented = segment(granule_map, skimage_method="li", mark_dim_centers=True) - assert isinstance(segmented, sunpy.map.mapbase.GenericMap) - # Check pixels are not empty. - initial_pix = all_pixel_indices_from_map(granule_map).value - seg_pixels = all_pixel_indices_from_map(segmented).value - assert np.size(seg_pixels) > 0 - assert seg_pixels.shape == initial_pix.shape - # Check that the values in the array have changed - assert np.any(np.not_equal(granule_map.data, segmented.data)) - - -def test_segment_errors(granule_map): - with pytest.raises(TypeError, match="Input must be an instance of a sunpy.map.GenericMap"): - segment(np.array([[1, 2, 3], [1, 2, 3]])) - with pytest.raises(ValueError, match="Method must be one of: li, otsu, yen, mean, minimum, triangle, isodata"): - segment(granule_map, skimage_method="banana") - - -def test_get_threshold(): - test_arr1 = np.array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]) - threshold1 = _get_threshold(test_arr1, "li") - assert isinstance(threshold1, np.float64) - # Check that different arrays return different thresholds. - test_arr2 = np.array([[2, 3, 4, 5, 6], [2, 3, 4, 5, 6]]) - threshold2 = _get_threshold(test_arr2, "li") - assert threshold1 != threshold2 - - -def test_get_threshold_range(): - test_arr1 = np.array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]) - threshold1 = _get_threshold(test_arr1, "li") - assert 0 < threshold1 < np.max(test_arr1) - - -def test_get_threshold_errors(): - with pytest.raises(TypeError, match="Input data must be an instance of a np.ndarray"): - _get_threshold([], "li") - with pytest.raises(ValueError, match="Method must be one of: li, otsu, yen, mean, minimum, triangle, isodata"): - _get_threshold(np.array([[1, 2], [1, 2]]), "banana") - - -def test_trim_intergranules(granule_map): - thresholded = np.uint8(granule_map.data > np.nanmedian(granule_map.data)) - # Check that returned array is not empty. - assert np.size(thresholded) > 0 - # Check that the correct dimensions are returned. - assert thresholded.shape == _trim_intergranules(thresholded).shape - # Check that erroneous zero values are caught and re-assigned - # e.g. inside of pad region, returned array has fewer 0-valued pixels then input - middles_removed = _trim_intergranules(thresholded) - pad = int(np.shape(thresholded)[0] / 200) - assert not np.count_nonzero(middles_removed[pad:-pad, pad:-pad]) < np.count_nonzero(thresholded[pad:-pad, pad:-pad]) - # Check that when mark=True, erroneous 0 values are set to 3 - middles_marked = _trim_intergranules(thresholded, mark=True) - marked_as_3 = np.count_nonzero(middles_marked[middles_marked == 3]) - assert marked_as_3 != 0 - # Check that when mark=False, erroneous 0 values are "removed" (set to 1), returning NO 3 values - middles_marked = _trim_intergranules(thresholded, mark=False) - marked_as_3 = np.count_nonzero(middles_marked[middles_marked == 3]) - assert marked_as_3 == 0 - - -def test_trim_intergranules_errors(): - rng = np.random.default_rng() - # Check that raises error if passed array is not binary. - data = rng.integers(low=0, high=10, size=(10, 10)) - with pytest.raises(ValueError, match="segmented_image must only have values of 1 and 0."): - _trim_intergranules(data) - - -def test_mark_brightpoint(granule_map, granule_map_he): - thresholded = np.uint8(granule_map.data > np.nanmedian(granule_map_he)) - brightpoint_marked, _, _ = _mark_brightpoint( - thresholded, - granule_map.data, - granule_map_he, - resolution=0.016, - bp_min_flux=None, - ) - # Check that the correct dimensions are returned. - assert thresholded.shape == brightpoint_marked.shape - # Check that returned array is not empty. - assert np.size(brightpoint_marked) > 0 - # Check that the returned array has some pixels of value 2 (for a dataset that we know has brightpoints by eye). - assert (brightpoint_marked == 2).sum() > 0 - - -def test_mark_brightpoint_error(granule_map, granule_map_he): - # Check that errors are raised for incorrect granule_map. - with pytest.raises(ValueError, match="segmented_image must have only"): - _mark_brightpoint(granule_map.data, granule_map.data, granule_map_he, resolution=0.016, bp_min_flux=None) - - -def test_segments_overlap_fraction(granule_minimap1): - # Check that segments_overlap_fraction is 1 when Maps are equal. - assert segments_overlap_fraction(granule_minimap1, granule_minimap1) == 1.0 - - -def test_segments_overlap_fraction2(granule_minimap1, granule_minimap2): - # Check that segments_overlap_fraction is between 0 and 1 when Maps are not equal. - assert segments_overlap_fraction(granule_minimap1, granule_minimap2) <= 1 - assert segments_overlap_fraction(granule_minimap1, granule_minimap2) >= 0 - - -def test_segments_overlap_fraction_errors(granule_minimap3): - # Check that error is raised if there are no granules or intergranules in image. - with pytest.raises(Exception, match="clustering failed"): - segments_overlap_fraction(granule_minimap3, granule_minimap3) +import numpy as np +import pytest + +import sunpy +from sunpy.map import all_pixel_indices_from_map + +from sunkit_image.granule import ( + _get_threshold, + _mark_brightpoint, + _trim_intergranules, + segment, + segments_overlap_fraction, +) + + +def test_segment(granule_map): + segmented = segment(granule_map, skimage_method="li", mark_dim_centers=True) + assert isinstance(segmented, sunpy.map.mapbase.GenericMap) + # Check pixels are not empty. + initial_pix = all_pixel_indices_from_map(granule_map).value + seg_pixels = all_pixel_indices_from_map(segmented).value + assert np.size(seg_pixels) > 0 + assert seg_pixels.shape == initial_pix.shape + # Check that the values in the array have changed + assert np.any(np.not_equal(granule_map.data, segmented.data)) + + +def test_segment_errors(granule_map): + with pytest.raises(TypeError, match="Input must be an instance of a sunpy.map.GenericMap"): + segment(np.array([[1, 2, 3], [1, 2, 3]])) + with pytest.raises(ValueError, match="Method must be one of: li, otsu, yen, mean, minimum, triangle, isodata"): + segment(granule_map, skimage_method="banana") + + +def test_get_threshold(): + test_arr1 = np.array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]) + threshold1 = _get_threshold(test_arr1, "li") + assert isinstance(threshold1, np.float64) + # Check that different arrays return different thresholds. + test_arr2 = np.array([[2, 3, 4, 5, 6], [2, 3, 4, 5, 6]]) + threshold2 = _get_threshold(test_arr2, "li") + assert threshold1 != threshold2 + + +def test_get_threshold_range(): + test_arr1 = np.array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]) + threshold1 = _get_threshold(test_arr1, "li") + assert 0 < threshold1 < np.max(test_arr1) + + +def test_get_threshold_errors(): + with pytest.raises(TypeError, match="Input data must be an instance of a np.ndarray"): + _get_threshold([], "li") + with pytest.raises(ValueError, match="Method must be one of: li, otsu, yen, mean, minimum, triangle, isodata"): + _get_threshold(np.array([[1, 2], [1, 2]]), "banana") + + +def test_trim_intergranules(granule_map): + thresholded = np.uint8(granule_map.data > np.nanmedian(granule_map.data)) + # Check that returned array is not empty. + assert np.size(thresholded) > 0 + # Check that the correct dimensions are returned. + assert thresholded.shape == _trim_intergranules(thresholded).shape + # Check that erroneous zero values are caught and re-assigned + # e.g. inside of pad region, returned array has fewer 0-valued pixels then input + middles_removed = _trim_intergranules(thresholded) + pad = int(np.shape(thresholded)[0] / 200) + assert not np.count_nonzero(middles_removed[pad:-pad, pad:-pad]) < np.count_nonzero(thresholded[pad:-pad, pad:-pad]) + # Check that when mark=True, erroneous 0 values are set to 3 + middles_marked = _trim_intergranules(thresholded, mark=True) + marked_as_3 = np.count_nonzero(middles_marked[middles_marked == 3]) + assert marked_as_3 != 0 + # Check that when mark=False, erroneous 0 values are "removed" (set to 1), returning NO 3 values + middles_marked = _trim_intergranules(thresholded, mark=False) + marked_as_3 = np.count_nonzero(middles_marked[middles_marked == 3]) + assert marked_as_3 == 0 + + +def test_trim_intergranules_errors(): + rng = np.random.default_rng() + # Check that raises error if passed array is not binary. + data = rng.integers(low=0, high=10, size=(10, 10)) + with pytest.raises(ValueError, match="segmented_image must only have values of 1 and 0."): + _trim_intergranules(data) + + +def test_mark_brightpoint(granule_map, granule_map_he): + thresholded = np.uint8(granule_map.data > np.nanmedian(granule_map_he)) + brightpoint_marked, _, _ = _mark_brightpoint( + thresholded, + granule_map.data, + granule_map_he, + resolution=0.016, + bp_min_flux=None, + ) + # Check that the correct dimensions are returned. + assert thresholded.shape == brightpoint_marked.shape + # Check that returned array is not empty. + assert np.size(brightpoint_marked) > 0 + # Check that the returned array has some pixels of value 2 (for a dataset that we know has brightpoints by eye). + assert (brightpoint_marked == 2).sum() > 0 + + +def test_mark_brightpoint_error(granule_map, granule_map_he): + # Check that errors are raised for incorrect granule_map. + with pytest.raises(ValueError, match="segmented_image must have only"): + _mark_brightpoint(granule_map.data, granule_map.data, granule_map_he, resolution=0.016, bp_min_flux=None) + + +def test_segments_overlap_fraction(granule_minimap1): + # Check that segments_overlap_fraction is 1 when Maps are equal. + assert segments_overlap_fraction(granule_minimap1, granule_minimap1) == 1.0 + + +def test_segments_overlap_fraction2(granule_minimap1, granule_minimap2): + # Check that segments_overlap_fraction is between 0 and 1 when Maps are not equal. + assert segments_overlap_fraction(granule_minimap1, granule_minimap2) <= 1 + assert segments_overlap_fraction(granule_minimap1, granule_minimap2) >= 0 + + +def test_segments_overlap_fraction_errors(granule_minimap3): + # Check that error is raised if there are no granules or intergranules in image. + with pytest.raises(Exception, match="clustering failed"): + segments_overlap_fraction(granule_minimap3, granule_minimap3)