Skip to content

Commit

Permalink
Merge pull request #1569 from bnmajor/masking-boundaries
Browse files Browse the repository at this point in the history
Masking boundaries
  • Loading branch information
bnmajor authored Sep 20, 2023
2 parents 4a552a9 + 9c2b7ee commit ba1880b
Show file tree
Hide file tree
Showing 8 changed files with 384 additions and 195 deletions.
4 changes: 4 additions & 0 deletions hexrd/ui/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,7 @@ class LLNLTransform:
'IMAGE-PLATE-4',
],
}

KEY_ROTATE_ANGLE_FINE = 0.00175
KEY_ROTATE_ANGLE_COARSE = 0.01
KEY_TRANSLATE_DELTA = 0.5
3 changes: 2 additions & 1 deletion hexrd/ui/hexrd_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,8 @@ def raw_masks_dict(self):
for det, mask in data:
if det == name:
final_mask = np.logical_and(final_mask, mask)
if self.threshold_mask_status:
if (self.threshold_mask_status and
self.threshold_masks.get(name) is not None):
idx = self.current_imageseries_idx
thresh_mask = self.threshold_masks[name][idx]
final_mask = np.logical_and(final_mask, thresh_mask)
Expand Down
4 changes: 3 additions & 1 deletion hexrd/ui/image_canvas.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import math

from PySide2.QtCore import QThreadPool, QTimer, Signal
from PySide2.QtCore import QThreadPool, QTimer, Signal, Qt
from PySide2.QtWidgets import QFileDialog, QMessageBox

from matplotlib.backends.backend_qt5agg import FigureCanvas
Expand Down Expand Up @@ -71,6 +71,8 @@ def __init__(self, parent=None, image_names=None):
if image_names is not None:
self.load_images(image_names)

self.setFocusPolicy(Qt.ClickFocus)

self.setup_connections()

def setup_connections(self):
Expand Down
192 changes: 140 additions & 52 deletions hexrd/ui/interactive_template.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,21 @@
import numpy as np

from PySide2.QtCore import Qt

from matplotlib import patches
from matplotlib.path import Path
from matplotlib.transforms import Affine2D

from skimage.draw import polygon

from hexrd.ui.create_hedm_instrument import create_hedm_instrument
from hexrd.ui import resource_loader
from hexrd.ui.constants import (
KEY_ROTATE_ANGLE_FINE, KEY_TRANSLATE_DELTA, ViewType
)
from hexrd.ui.hexrd_config import HexrdConfig
from hexrd.ui.utils import has_nan


class InteractiveTemplate:
def __init__(self, parent=None):
self.parent = parent.image_tab_widget.image_canvases[0]
self.ax = self.parent.axes_images[0]
self.panels = create_hedm_instrument().detectors
def __init__(self, canvas, detector, axes=None, instrument=None):
self.current_canvas = canvas
self.img = None
self.shape = None
self.press = None
Expand All @@ -27,11 +24,54 @@ def __init__(self, parent=None):
self.translation = [0, 0]
self.complete = False
self.event_key = None
self.parent.setFocusPolicy(Qt.ClickFocus)
self.detector = detector
self.instrument = instrument
self._static = True
self.axis_image = (
axes.get_images()[0] if axes else canvas.axes_images[0])
self._key_angle = KEY_ROTATE_ANGLE_FINE

self.button_press_cid = None
self.button_release_cid = None
self.motion_cid = None
self.key_press_cid = None
self.button_drag_cid = None

@property
def axis(self):
if not self.current_canvas.raw_axes:
return self.current_canvas.axis

for axes in self.current_canvas.raw_axes.values():
if axes.get_title() == self.detector:
return axes

return list(self.current_canvas.raw_axes.values())[0]

@property
def static_mode(self):
return self._static

@static_mode.setter
def static_mode(self, mode):
if mode == self._static:
return

self._static = mode
self.update_style(color='black')
if not mode:
self.connect_translate_rotate()
self.update_style(color='red')

@property
def raw_axes(self):
return list(self.parent.raw_axes.values())[0]
def key_rotation_angle(self):
return self._key_angle

@key_rotation_angle.setter
def key_rotation_angle(self, angle=None):
if angle is None:
angle = KEY_ROTATE_ANGLE
self._key_angle = angle

def update_image(self, img):
self.img = img
Expand All @@ -41,32 +81,42 @@ def rotate_shape(self, angle):
self.rotate_template(self.shape.xy, angle)
self.redraw()

def create_shape(self, module, file_name, det, instr):
def create_polygon(self, verts, **polygon_kwargs):
self.complete = False
with resource_loader.resource_path(module, file_name) as f:
data = np.loadtxt(f)
verts = self.panels['default'].cartToPixel(data)
verts[:, [0, 1]] = verts[:, [1, 0]]
self.shape = patches.Polygon(verts, fill=False, lw=1, color='cyan')
self.shape = patches.Polygon(verts, **polygon_kwargs)
if has_nan(verts):
# This template contains more than one polygon and the last point
# should not be connected to the first. See Tardis IP for example.
self.shape.set_closed(False)
self.shape_styles.append({'line': '-', 'width': 1, 'color': 'cyan'})
self.update_position(instr, det)
self.shape_styles.append(polygon_kwargs)
self.update_position()
self.connect_translate_rotate()
self.raw_axes.add_patch(self.shape)
self.axis.add_patch(self.shape)
self.redraw()

def update_style(self, style, width, color):
self.shape_styles[-1] = {'line': style, 'width': width, 'color': color}
self.shape.set_linestyle(style)
self.shape.set_linewidth(width)
self.shape.set_edgecolor(color)
def update_style(self, style=None, width=None, color=None):
if not self.shape:
return

if style:
self.shape.set_linestyle(style)
if width:
self.shape.set_linewidth(width)
if color:
self.shape.set_edgecolor(color)
self.shape_styles[-1] = {
'line': self.shape.get_linestyle(),
'width': self.shape.get_linewidth(),
'color': self.shape.get_edgecolor()
}
self.shape.set_fill(False)
self.redraw()

def update_position(self, instr, det):
pos = HexrdConfig().boundary_position(instr, det)
def update_position(self):
pos = None
if self.instrument is not None:
pos = HexrdConfig().boundary_position(
self.instrument, self.detector)
if pos is None:
self.center = self.get_midpoint()
else:
Expand All @@ -75,7 +125,7 @@ def update_position(self, instr, det):
self.translate_template(dx, dy)
self.total_rotation = pos['angle']
self.rotate_template(self.shape.xy, pos['angle'])
if instr == 'PXRDIP':
if self.instrument == 'PXRDIP':
self.rotate_shape(angle=90)

@property
Expand All @@ -89,7 +139,7 @@ def masked_image(self):

@property
def bounds(self):
l, r, b, t = self.ax.get_extent()
l, r, b, t = self.axis_image.get_extent()
x0, y0 = np.nanmin(self.shape.xy, axis=0)
x1, y1 = np.nanmax(self.shape.xy, axis=0)
return np.array([max(np.floor(y0), t),
Expand All @@ -110,13 +160,13 @@ def rotation(self):
return self.total_rotation

def clear(self):
if self.shape in self.raw_axes.patches:
if self.shape in self.axis.patches:
self.shape.remove()
self.redraw()
self.total_rotation = 0.

def save_boundary(self, color):
if self.shape in self.raw_axes.patches:
if self.shape in self.axis.patches:
self.shape.set_linestyle('--')
self.redraw()

Expand All @@ -134,26 +184,26 @@ def toggle_boundaries(self, show):
# This template contains more than one polygon and the last point
# should not be connected to the first. See Tardis IP for example.
shape.set_closed(False)
self.raw_axes.add_patch(shape)
self.axis.add_patch(shape)
if self.shape:
self.shape = self.raw_axes.patches[-1]
self.shape = self.axis.patches[-1]
self.shape.remove()
self.shape.set_linestyle(self.shape_styles[-1]['line'])
self.raw_axes.add_patch(self.shape)
self.axis.add_patch(self.shape)
self.connect_translate_rotate()
self.redraw()
else:
if self.shape:
self.disconnect()
self.patches = [p for p in self.raw_axes.patches]
self.patches = [p for p in self.axis.patches]
self.redraw()

def disconnect(self):
self.parent.mpl_disconnect(self.button_press_cid)
self.parent.mpl_disconnect(self.button_release_cid)
self.parent.mpl_disconnect(self.motion_cid)
self.parent.mpl_disconnect(self.key_press_cid)
self.parent.mpl_disconnect(self.button_drag_cid)
self.current_canvas.mpl_disconnect(self.button_press_cid)
self.current_canvas.mpl_disconnect(self.button_release_cid)
self.current_canvas.mpl_disconnect(self.motion_cid)
self.current_canvas.mpl_disconnect(self.key_press_cid)
self.current_canvas.mpl_disconnect(self.button_drag_cid)

def completed(self):
self.disconnect()
Expand Down Expand Up @@ -199,7 +249,7 @@ def get_paths(self):
return all_paths

def redraw(self):
self.parent.draw_idle()
self.current_canvas.draw_idle()

def scale_template(self, sx=1, sy=1):
xy = self.shape.xy
Expand All @@ -214,6 +264,9 @@ def scale_template(self, sx=1, sy=1):
self.redraw()

def on_press(self, event):
if self.static_mode:
return

self.event_key = event.key
if event.key is None:
self.on_press_translate(event)
Expand All @@ -227,23 +280,31 @@ def on_release(self, event):
self.on_rotate_release(event)

def on_key(self, event):
if self.static_mode:
return

if 'shift' in event.key:
self.on_key_rotate(event)
else:
self.on_key_translate(event)

def connect_translate_rotate(self):
self.button_press_cid = self.parent.mpl_connect(
if self.static_mode:
return

self.disconnect()

self.button_press_cid = self.current_canvas.mpl_connect(
'button_press_event', self.on_press)
self.button_release_cid = self.parent.mpl_connect(
self.button_release_cid = self.current_canvas.mpl_connect(
'button_release_event', self.on_release)
self.motion_cid = self.parent.mpl_connect(
self.motion_cid = self.current_canvas.mpl_connect(
'motion_notify_event', self.on_translate)
self.key_press_cid = self.parent.mpl_connect(
self.key_press_cid = self.current_canvas.mpl_connect(
'key_press_event', self.on_key)
self.button_drag_cid = self.parent.mpl_connect(
self.button_drag_cid = self.current_canvas.mpl_connect(
'motion_notify_event', self.on_rotate)
self.parent.setFocus()
self.current_canvas.setFocus()

def translate_template(self, dx, dy):
self.shape.set_xy(self.shape.xy + np.array([dx, dy]))
Expand All @@ -253,7 +314,7 @@ def translate_template(self, dx, dy):
def on_key_translate(self, event):
dx0, dy0 = self.translation
dx1, dy1 = 0, 0
delta = 0.5
delta = KEY_TRANSLATE_DELTA
if event.key == 'right':
dx1 = delta
elif event.key == 'left':
Expand Down Expand Up @@ -315,13 +376,33 @@ def on_press_rotate(self, event):
# need to set the press value twice
self.press = self.shape.xy, event.xdata, event.ydata
self.center = self.get_midpoint()
self.shape.set_transform(self.ax.axes.transData)
self.shape.set_transform(self.axis_image.axes.transData)
self.press = self.shape.xy, event.xdata, event.ydata

def rotate_template(self, points, angle):
center = self.center
canvas = self.current_canvas
if canvas.mode == ViewType.polar:
# We need to correct for the extent ratio and the aspect ratio
# Make a copy to modify (we should *not* modify the original)
points = np.array(points)
extent = canvas.iviewer.pv.extent

canvas_aspect = compute_aspect_ratio(canvas.axis)
extent_aspect = (extent[2] - extent[3]) / (extent[1] - extent[0])

aspect_ratio = extent_aspect * canvas_aspect
points[:, 0] *= aspect_ratio
center = (center[0] * aspect_ratio, center[1])

x = [np.cos(angle), np.sin(angle)]
y = [-np.sin(angle), np.cos(angle)]
verts = np.dot(points - self.center, np.array([x, y])) + self.center
verts = np.dot(points - center, np.array([x, y])) + center

if canvas.mode == ViewType.polar:
# Reverse the aspect ratio correction
verts[:, 0] /= aspect_ratio

self.shape.set_xy(verts)

def on_rotate(self, event):
Expand All @@ -337,7 +418,7 @@ def on_rotate(self, event):
self.redraw()

def on_key_rotate(self, event):
angle = 0.00175
angle = self.key_rotation_angle
# !!! only catch arrow keys
if event.key == 'shift+left' or event.key == 'shift+up':
angle *= -1.
Expand All @@ -353,7 +434,7 @@ def get_midpoint(self):
return [(x1 + x0)/2, (y1 + y0)/2]

def mouse_position(self, e):
xmin, xmax, ymin, ymax = self.ax.get_extent()
xmin, xmax, ymin, ymax = self.axis_image.get_extent()
x, y = self.get_midpoint()
xdata = e.xdata
ydata = e.ydata
Expand Down Expand Up @@ -388,3 +469,10 @@ def on_rotate_release(self, event):
self.press = None
self.rotate_template(xy, angle)
self.redraw()


def compute_aspect_ratio(axis):
# Compute the aspect ratio of a matplotlib axis
ll, ur = axis.get_position() * axis.figure.get_size_inches()
width, height = ur - ll
return width / height
Loading

0 comments on commit ba1880b

Please sign in to comment.