Skip to content

Commit

Permalink
Merge pull request #95 from neutrons/config_managment
Browse files Browse the repository at this point in the history
Configuration management mechanism and initial settings variables
  • Loading branch information
mpatrou authored Aug 22, 2023
2 parents 21c7b88 + e73eacf commit eccdce4
Show file tree
Hide file tree
Showing 12 changed files with 843 additions and 46 deletions.
12 changes: 12 additions & 0 deletions src/shiver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import shiver.models.makeslice # noqa: F401 pylint: disable=unused-import
import shiver.models.convert_dgs_to_single_mde # noqa: F401 pylint: disable=unused-import
import shiver.models.generate_dgs_mde # noqa: F401 pylint: disable=unused-import
from shiver.configuration import Configuration
from .version import __version__

# make sure matplotlib is correctly set before we import shiver
Expand All @@ -31,6 +32,17 @@ def main():
"""
Main entry point for Qt application
"""
config = Configuration()
if not config.is_valid():
msg = (
"Error with configuration settings!",
f"Check and update your file: {config.config_file_path}",
"with the latest settings found here:",
f"{config.template_file_path} and start the application again.",
)

print(" ".join(msg))
return
app = QApplication(sys.argv)
window = Shiver()
window.show()
Expand Down
98 changes: 98 additions & 0 deletions src/shiver/configuration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""Module to load the the settings from SHOME/.shiver/configuration.ini file
Will fall back to a default"""
import os
import shutil

from configparser import ConfigParser
from pathlib import Path
from mantid.kernel import Logger

logger = Logger("SHIVER")

# configuration settings file path
CONFIG_PATH_FILE = os.path.join(Path.home(), ".shiver", "configuration.ini")


class Configuration:
"""Load and validate Configuration Data"""

def __init__(self):
"""initialization of configuration mechanism"""
# capture the current state
self.valid = False

# locate the template configuration file
project_directory = Path(__file__).resolve().parent
self.template_file_path = os.path.join(project_directory, "configuration_template.ini")

# retrieve the file path of the file
self.config_file_path = CONFIG_PATH_FILE
logger.information(f"{self.config_file_path} with be used")

# if template conf file path exists
if os.path.exists(self.template_file_path):
# file does not exist create it from template
if not os.path.exists(self.config_file_path):
# if directory structure does not exist create it
if not os.path.exists(os.path.dirname(self.config_file_path)):
os.makedirs(os.path.dirname(self.config_file_path))
shutil.copy2(self.template_file_path, self.config_file_path)

self.config = ConfigParser()
# parse the file
try:
self.config.read(self.config_file_path)
# validate the file has the all the latest variables
self.validate()
except ValueError as err:
logger.error(str(err))
logger.error(f"Problem with the file: {self.config_file_path}")
else:
logger.error(f"Template configuration file: {self.template_file_path} is missing!")

def validate(self):
"""validates that the fields exist at the config_file_path and writes any missing fields/data
using the template configuration file: configuration_template.ini as a guide"""
template_config = ConfigParser()
template_config.read(self.template_file_path)
for section in template_config.sections():
# if section is missing
if section not in self.config.sections():
# copy the whole section
self.config.add_section(section)

for field in template_config[section]:
if field not in self.config[section]:
# copy the field
self.config[section][field] = template_config[section][field]
with open(self.config_file_path, "w", encoding="utf8") as config_file:
self.config.write(config_file)
self.valid = True

def is_valid(self):
"""returns the configuration state"""
return self.valid


def get_data(section, name=None):
"""retrieves the configuration data for a variable with name"""
# default file path location
config_file_path = CONFIG_PATH_FILE
if os.path.exists(config_file_path):
config = ConfigParser()
# parse the file
config.read(config_file_path)
try:
if name:
value = config[section][name]
# in case of boolean string value cast it to bool
if value in ("True", "False"):
return value == "True"
return value
return config[section]
except KeyError as err:
# requested section/field do not exist
logger.error(str(err))
return None
return None
16 changes: 16 additions & 0 deletions src/shiver/configuration_template.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[generate_tab.oncat]
#url to oncat portal
oncat_url = https://oncat.ornl.gov
#client id for on cat; it is unique for Shiver
client_id = 99025bb3-ce06-4f4b-bcf2-36ebf925cd1d
#the flag (bool: True/False) indicates the location of the names of the datasets (notes/comments vs. sequence name)
use_notes = False

[main_tab.plot]
#options: full prints dimension data, name_only, workspace title, None: no title is printed in plots
title = True
#the flag (bool: True/False) indicates the plot scale (logarithmic or not)
logarithmic_intensity = False

[global.other]
help_url = https://neutrons.github.io/Shiver/GUI/
4 changes: 3 additions & 1 deletion src/shiver/models/help.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
""" single help module """
import webbrowser
from shiver.configuration import get_data


def help_function(context):
"""
open a browser with the appropriate help page
"""
help_url = get_data("global.other", "help_url")
if context:
webbrowser.open("https://neutrons.github.io/Shiver/GUI/")
webbrowser.open(help_url)
25 changes: 18 additions & 7 deletions src/shiver/views/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
)
from qtpy.QtCore import Signal

from shiver.configuration import get_data
from .loading_buttons import LoadingButtons
from .histogram_parameters import HistogramParameter
from .workspace_tables import InputWorkspaces, HistogramWorkspaces
Expand Down Expand Up @@ -78,13 +79,7 @@ def make_slice_finish(self, ws_name, ndims):
self.makeslice_finish_signal.emit(ws_name, ndims)

def _make_slice_finish(self, ws_name, ndims):
display_name = self.plot_display_name_callback(ws_name, ndims)
min_intensity = self.histogram_parameters.dimensions.intensity_min.text()
max_intensity = self.histogram_parameters.dimensions.intensity_max.text()
intensity_limits = {
"min": float(min_intensity) if min_intensity != "" else None,
"max": float(max_intensity) if max_intensity != "" else None,
}
display_name, intensity_limits = self.get_plot_data(ws_name, ndims)
do_default_plot(ws_name, ndims, display_name, intensity_limits)
self.histogram_workspaces.histogram_workspaces.set_selected(ws_name)

Expand Down Expand Up @@ -203,3 +198,19 @@ def unset_all(self):
self.input_workspaces.mde_workspaces.unset_all()
self.input_workspaces.norm_workspaces.deselect_all()
self.set_field_invalid_state(self.input_workspaces.mde_workspaces)

def get_plot_data(self, ws_name, ndims):
"""Get display name and intensities data for plotting."""
plot_title_preference = get_data("main_tab.plot", "title")
display_name = None
if plot_title_preference == "full":
display_name = self.plot_display_name_callback(ws_name, ndims)
if plot_title_preference == "name_only":
display_name = ws_name.name()
min_intensity = self.histogram_parameters.dimensions.intensity_min.text()
max_intensity = self.histogram_parameters.dimensions.intensity_max.text()
intensity_limits = {
"min": float(min_intensity) if min_intensity != "" else None,
"max": float(max_intensity) if max_intensity != "" else None,
}
return (display_name, intensity_limits)
1 change: 0 additions & 1 deletion src/shiver/views/mainwindow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def __init__(self, parent=None):
super().__init__(parent)

self.tabs = QTabWidget()

histogram = Histogram(self)
histogram_model = HistogramModel()
self.histogram_presenter = HistogramPresenter(histogram, histogram_model)
Expand Down
22 changes: 13 additions & 9 deletions src/shiver/views/oncat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@
QDoubleSpinBox,
)
from qtpy.QtCore import QTimer

# CONSTANTS
# NOTE: the client ID is unique for Shiver
ONCAT_URL = "https://oncat.ornl.gov"
CLIENT_ID = "99025bb3-ce06-4f4b-bcf2-36ebf925cd1d"
from shiver.configuration import get_data


class OncatToken:
Expand Down Expand Up @@ -53,15 +49,20 @@ def write_token(self, token):
class OnCatAgent:
"""Agent to interface with OnCat"""

def __init__(self) -> None:
def __init__(self, use_notes=False) -> None:
"""Initialize OnCat agent"""
# get configuration settings
self._use_notes = use_notes
self._oncat_url = get_data("generate_tab.oncat", "oncat_url")
self._client_id = get_data("generate_tab.oncat", "client_id")

user_home_dir = os.path.expanduser("~")
self._token = OncatToken(
os.path.abspath(f"{user_home_dir}/.shiver/oncat_token.json"),
)
self._agent = pyoncat.ONCat(
ONCAT_URL,
client_id=CLIENT_ID,
self._oncat_url,
client_id=self._client_id,
# Pass in token getter/setter callbacks here:
token_getter=self._token.read_token,
token_setter=self._token.write_token,
Expand Down Expand Up @@ -149,6 +150,7 @@ def get_datasets(self, facility: str, instrument: str, ipts: int) -> list:
self._agent,
ipts_number=ipts,
instrument=instrument,
use_notes=self._use_notes,
facility=facility,
)

Expand Down Expand Up @@ -275,8 +277,9 @@ def __init__(self, parent=None):
# error message callback
self.error_message_callback = None

self.use_notes = get_data("generate_tab.oncat", "use_notes")
# OnCat agent
self.oncat_agent = OnCatAgent()
self.oncat_agent = OnCatAgent(self.use_notes)

# Sync with remote
self.sync_with_remote(refresh=True)
Expand Down Expand Up @@ -326,6 +329,7 @@ def get_suggested_selected_files(self) -> list:
login=self.oncat_agent.get_agent_instance(),
ipts_number=self.get_ipts_number(),
instrument=self.get_instrument(),
use_notes=self.use_notes,
facility=self.get_facility(),
group_by_angle=group_by_angle,
angle_bin=self.angle_target.value(),
Expand Down
49 changes: 37 additions & 12 deletions src/shiver/views/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,44 @@
import matplotlib.pyplot as plt
from mantidqt.widgets.sliceviewer.presenters.presenter import SliceViewer
from mantidqt.plotting.functions import manage_workspace_names, plot_md_ws_from_names
from shiver.configuration import get_data


@manage_workspace_names
def do_1d_plot(workspaces, display_name, intensity_limits=None):
def do_1d_plot(workspaces, display_name, intensity_limits=None, log_scale=False):
"""Create an 1D plot for the provided workspace"""
fig = plot_md_ws_from_names(workspaces, False, False)
plot_title = display_name if display_name else workspaces[0].name()
min_limit = intensity_limits["min"] if intensity_limits is not None and "min" in intensity_limits else None
max_limit = intensity_limits["max"] if intensity_limits is not None and "max" in intensity_limits else None
fig.axes[0].set_ylim(min_limit, max_limit)

# logarithic case
if log_scale:
fig.axes[0].set_yscale("log")

# y limits
if not (min_limit is None and max_limit is None):
fig.axes[0].set_ylim(min_limit, max_limit)

fig.canvas.manager.set_window_title(display_name)
fig.axes[0].set_title(plot_title)
return fig


@manage_workspace_names
def do_colorfill_plot(workspaces, display_name=None, intensity_limits=None):
def do_colorfill_plot(workspaces, display_name=None, intensity_limits=None, log_scale=False):
"""Create a colormesh plot for the provided workspace"""
fig, axis = plt.subplots(subplot_kw={"projection": "mantid"})
plot_title = display_name or workspaces[0].name()
# y limits
min_limit = intensity_limits["min"] if intensity_limits is not None and "min" in intensity_limits else None
max_limit = intensity_limits["max"] if intensity_limits is not None and "max" in intensity_limits else None
colormesh = axis.pcolormesh(workspaces[0], vmin=min_limit, vmax=max_limit)

# logarithic case
scale_norm = "linear"
if log_scale:
scale_norm = "log"
colormesh = axis.pcolormesh(workspaces[0], vmin=min_limit, vmax=max_limit, norm=scale_norm)
axis.set_title(plot_title)
fig.canvas.manager.set_window_title(plot_title)

Expand All @@ -34,7 +49,7 @@ def do_colorfill_plot(workspaces, display_name=None, intensity_limits=None):


@manage_workspace_names
def do_slice_viewer(workspaces, parent=None, intensity_limits=None):
def do_slice_viewer(workspaces, parent=None, intensity_limits=None, log_scale=False):
"""Open sliceviewer for the provided workspace"""
presenter = SliceViewer(ws=workspaces[0], parent=parent)

Expand All @@ -48,21 +63,31 @@ def do_slice_viewer(workspaces, parent=None, intensity_limits=None):
if intensity_limits is not None and intensity_limits["max"] is not None
else presenter.view.data_view.colorbar.cmax_value
)
presenter.view.data_view.colorbar.cmin.setText(f"{min_limit:.4}")
presenter.view.data_view.colorbar.clim_changed()
presenter.view.data_view.colorbar.cmax.setText(f"{max_limit:.4}")
presenter.view.data_view.colorbar.clim_changed()

# y limits
if not (min_limit is None and max_limit is None):
presenter.view.data_view.colorbar.cmin.setText(f"{min_limit:.4}")
presenter.view.data_view.colorbar.clim_changed()
presenter.view.data_view.colorbar.cmax.setText(f"{max_limit:.4}")
presenter.view.data_view.colorbar.clim_changed()

# logarithic case
if log_scale:
norm_scale = "Log"
presenter.view.data_view.colorbar.norm.setCurrentText(norm_scale)
presenter.view.data_view.colorbar.norm_changed()

presenter.view.show()
return presenter.view


def do_default_plot(workspace, ndims, display_name=None, intensity_limits=None):
"""Create the default plot for the workspace and number of dimensions"""
log_scale = get_data("main_tab.plot", "logarithmic_intensity")
if ndims == 1:
return do_1d_plot([workspace], display_name, intensity_limits)
return do_1d_plot([workspace], display_name, intensity_limits, log_scale)
if ndims == 2:
return do_colorfill_plot([workspace], display_name, intensity_limits)
return do_colorfill_plot([workspace], display_name, intensity_limits, log_scale)
if ndims in (3, 4):
return do_slice_viewer([workspace], intensity_limits=intensity_limits)
return do_slice_viewer([workspace], intensity_limits=intensity_limits, log_scale=log_scale)
return None
16 changes: 16 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""pytest config"""
import os
from configparser import ConfigParser

import pytest
from mantid.simpleapi import mtd
from shiver import Shiver
Expand All @@ -16,3 +19,16 @@ def shiver_app():
def clear_ads():
"""clear the ADS after every test"""
mtd.clear()


@pytest.fixture(scope="session")
def user_conf_file(tmp_path_factory, request):
"""Fixture to create a custom configuration file in tmp_path"""
# custom configuration file
config_data = request.param
user_config = ConfigParser(allow_no_value=True)
user_config.read_string(config_data)
user_path = os.path.join(tmp_path_factory.mktemp("data"), "test_config.ini")
with open(user_path, "w", encoding="utf8") as config_file:
user_config.write(config_file)
return user_path
Loading

0 comments on commit eccdce4

Please sign in to comment.