Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Configuration management mechanism and initial settings variables #95

Merged
merged 10 commits into from
Aug 22, 2023
12 changes: 12 additions & 0 deletions src/shiver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import shiver.models.makeslice # noqa: F401 pylint: disable=unused-import
import shiver.models.convert_dgs_to_single_mde # noqa: F401 pylint: disable=unused-import
import shiver.models.generate_dgs_mde # noqa: F401 pylint: disable=unused-import
from shiver.configuration import Configuration
from .version import __version__

# make sure matplotlib is correctly set before we import shiver
Expand All @@ -31,6 +32,17 @@ def main():
"""
Main entry point for Qt application
"""
config = Configuration()
if not config.is_valid():
msg = (
"Error with configuration settings!",
f"Check and update your file: {config.config_file_path}",
"with the latest settings found here:",
f"{config.template_file_path} and start the application again.",
)

print(" ".join(msg))
return
app = QApplication(sys.argv)
window = Shiver()
window.show()
Expand Down
98 changes: 98 additions & 0 deletions src/shiver/configuration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""Module to load the the settings from SHOME/.shiver/configuration.ini file

Will fall back to a default"""
import os
import shutil

from configparser import ConfigParser
from pathlib import Path
from mantid.kernel import Logger

logger = Logger("SHIVER")

# configuration settings file path
CONFIG_PATH_FILE = os.path.join(Path.home(), ".shiver", "configuration.ini")


class Configuration:
"""Load and validate Configuration Data"""

def __init__(self):
"""initialization of configuration mechanism"""
# capture the current state
self.valid = False

# locate the template configuration file
project_directory = Path(__file__).resolve().parent
self.template_file_path = os.path.join(project_directory, "configuration_template.ini")

# retrieve the file path of the file
self.config_file_path = CONFIG_PATH_FILE
logger.information(f"{self.config_file_path} with be used")

# if template conf file path exists
if os.path.exists(self.template_file_path):
# file does not exist create it from template
if not os.path.exists(self.config_file_path):
# if directory structure does not exist create it
if not os.path.exists(os.path.dirname(self.config_file_path)):
os.makedirs(os.path.dirname(self.config_file_path))
shutil.copy2(self.template_file_path, self.config_file_path)

self.config = ConfigParser()
# parse the file
try:
self.config.read(self.config_file_path)
# validate the file has the all the latest variables
self.validate()
except ValueError as err:
logger.error(str(err))
logger.error(f"Problem with the file: {self.config_file_path}")
else:
logger.error(f"Template configuration file: {self.template_file_path} is missing!")

def validate(self):
"""validates that the fields exist at the config_file_path and writes any missing fields/data
using the template configuration file: configuration_template.ini as a guide"""
template_config = ConfigParser()
template_config.read(self.template_file_path)
for section in template_config.sections():
# if section is missing
if section not in self.config.sections():
# copy the whole section
self.config.add_section(section)

for field in template_config[section]:
if field not in self.config[section]:
# copy the field
self.config[section][field] = template_config[section][field]
with open(self.config_file_path, "w", encoding="utf8") as config_file:
self.config.write(config_file)
self.valid = True

def is_valid(self):
"""returns the configuration state"""
return self.valid


def get_data(section, name=None):
"""retrieves the configuration data for a variable with name"""
# default file path location
config_file_path = CONFIG_PATH_FILE
if os.path.exists(config_file_path):
config = ConfigParser()
# parse the file
config.read(config_file_path)
try:
if name:
value = config[section][name]
# in case of boolean string value cast it to bool
if value in ("True", "False"):
return value == "True"
return value
return config[section]
except KeyError as err:
# requested section/field do not exist
logger.error(str(err))
return None
return None
16 changes: 16 additions & 0 deletions src/shiver/configuration_template.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[generate_tab.oncat]
#url to oncat portal
oncat_url = https://oncat.ornl.gov
#client id for on cat; it is unique for Shiver
client_id = 99025bb3-ce06-4f4b-bcf2-36ebf925cd1d
#the flag (bool: True/False) indicates the location of the names of the datasets (notes/comments vs. sequence name)
use_notes = False

[main_tab.plot]
#options: full prints dimension data, name_only, workspace title, None: no title is printed in plots
title = True
#the flag (bool: True/False) indicates the plot scale (logarithmic or not)
logarithmic_intensity = False

[global.other]
help_url = https://neutrons.github.io/Shiver/GUI/
4 changes: 3 additions & 1 deletion src/shiver/models/help.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
""" single help module """
import webbrowser
from shiver.configuration import get_data


def help_function(context):
"""
open a browser with the appropriate help page
"""
help_url = get_data("global.other", "help_url")
if context:
webbrowser.open("https://neutrons.github.io/Shiver/GUI/")
webbrowser.open(help_url)
25 changes: 18 additions & 7 deletions src/shiver/views/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
)
from qtpy.QtCore import Signal

from shiver.configuration import get_data
from .loading_buttons import LoadingButtons
from .histogram_parameters import HistogramParameter
from .workspace_tables import InputWorkspaces, HistogramWorkspaces
Expand Down Expand Up @@ -78,13 +79,7 @@ def make_slice_finish(self, ws_name, ndims):
self.makeslice_finish_signal.emit(ws_name, ndims)

def _make_slice_finish(self, ws_name, ndims):
display_name = self.plot_display_name_callback(ws_name, ndims)
min_intensity = self.histogram_parameters.dimensions.intensity_min.text()
max_intensity = self.histogram_parameters.dimensions.intensity_max.text()
intensity_limits = {
"min": float(min_intensity) if min_intensity != "" else None,
"max": float(max_intensity) if max_intensity != "" else None,
}
display_name, intensity_limits = self.get_plot_data(ws_name, ndims)
do_default_plot(ws_name, ndims, display_name, intensity_limits)
self.histogram_workspaces.histogram_workspaces.set_selected(ws_name)

Expand Down Expand Up @@ -203,3 +198,19 @@ def unset_all(self):
self.input_workspaces.mde_workspaces.unset_all()
self.input_workspaces.norm_workspaces.deselect_all()
self.set_field_invalid_state(self.input_workspaces.mde_workspaces)

def get_plot_data(self, ws_name, ndims):
"""Get display name and intensities data for plotting."""
plot_title_preference = get_data("main_tab.plot", "title")
display_name = None
if plot_title_preference == "full":
display_name = self.plot_display_name_callback(ws_name, ndims)
if plot_title_preference == "name_only":
display_name = ws_name.name()
min_intensity = self.histogram_parameters.dimensions.intensity_min.text()
max_intensity = self.histogram_parameters.dimensions.intensity_max.text()
intensity_limits = {
"min": float(min_intensity) if min_intensity != "" else None,
"max": float(max_intensity) if max_intensity != "" else None,
}
return (display_name, intensity_limits)
1 change: 0 additions & 1 deletion src/shiver/views/mainwindow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def __init__(self, parent=None):
super().__init__(parent)

self.tabs = QTabWidget()

histogram = Histogram(self)
histogram_model = HistogramModel()
self.histogram_presenter = HistogramPresenter(histogram, histogram_model)
Expand Down
22 changes: 13 additions & 9 deletions src/shiver/views/oncat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@
QDoubleSpinBox,
)
from qtpy.QtCore import QTimer

# CONSTANTS
# NOTE: the client ID is unique for Shiver
ONCAT_URL = "https://oncat.ornl.gov"
CLIENT_ID = "99025bb3-ce06-4f4b-bcf2-36ebf925cd1d"
from shiver.configuration import get_data


class OncatToken:
Expand Down Expand Up @@ -53,15 +49,20 @@ def write_token(self, token):
class OnCatAgent:
"""Agent to interface with OnCat"""

def __init__(self) -> None:
def __init__(self, use_notes=False) -> None:
"""Initialize OnCat agent"""
# get configuration settings
self._use_notes = use_notes
self._oncat_url = get_data("generate_tab.oncat", "oncat_url")
self._client_id = get_data("generate_tab.oncat", "client_id")

user_home_dir = os.path.expanduser("~")
self._token = OncatToken(
os.path.abspath(f"{user_home_dir}/.shiver/oncat_token.json"),
)
self._agent = pyoncat.ONCat(
ONCAT_URL,
client_id=CLIENT_ID,
self._oncat_url,
client_id=self._client_id,
# Pass in token getter/setter callbacks here:
token_getter=self._token.read_token,
token_setter=self._token.write_token,
Expand Down Expand Up @@ -149,6 +150,7 @@ def get_datasets(self, facility: str, instrument: str, ipts: int) -> list:
self._agent,
ipts_number=ipts,
instrument=instrument,
use_notes=self._use_notes,
facility=facility,
)

Expand Down Expand Up @@ -275,8 +277,9 @@ def __init__(self, parent=None):
# error message callback
self.error_message_callback = None

self.use_notes = get_data("generate_tab.oncat", "use_notes")
# OnCat agent
self.oncat_agent = OnCatAgent()
self.oncat_agent = OnCatAgent(self.use_notes)

# Sync with remote
self.sync_with_remote(refresh=True)
Expand Down Expand Up @@ -326,6 +329,7 @@ def get_suggested_selected_files(self) -> list:
login=self.oncat_agent.get_agent_instance(),
ipts_number=self.get_ipts_number(),
instrument=self.get_instrument(),
use_notes=self.use_notes,
facility=self.get_facility(),
group_by_angle=group_by_angle,
angle_bin=self.angle_target.value(),
Expand Down
49 changes: 37 additions & 12 deletions src/shiver/views/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,44 @@
import matplotlib.pyplot as plt
from mantidqt.widgets.sliceviewer.presenters.presenter import SliceViewer
from mantidqt.plotting.functions import manage_workspace_names, plot_md_ws_from_names
from shiver.configuration import get_data


@manage_workspace_names
def do_1d_plot(workspaces, display_name, intensity_limits=None):
def do_1d_plot(workspaces, display_name, intensity_limits=None, log_scale=False):
"""Create an 1D plot for the provided workspace"""
fig = plot_md_ws_from_names(workspaces, False, False)
plot_title = display_name if display_name else workspaces[0].name()
min_limit = intensity_limits["min"] if intensity_limits is not None and "min" in intensity_limits else None
max_limit = intensity_limits["max"] if intensity_limits is not None and "max" in intensity_limits else None
fig.axes[0].set_ylim(min_limit, max_limit)

# logarithic case
if log_scale:
fig.axes[0].set_yscale("log")

# y limits
if not (min_limit is None and max_limit is None):
fig.axes[0].set_ylim(min_limit, max_limit)

fig.canvas.manager.set_window_title(display_name)
fig.axes[0].set_title(plot_title)
return fig


@manage_workspace_names
def do_colorfill_plot(workspaces, display_name=None, intensity_limits=None):
def do_colorfill_plot(workspaces, display_name=None, intensity_limits=None, log_scale=False):
"""Create a colormesh plot for the provided workspace"""
fig, axis = plt.subplots(subplot_kw={"projection": "mantid"})
plot_title = display_name or workspaces[0].name()
# y limits
min_limit = intensity_limits["min"] if intensity_limits is not None and "min" in intensity_limits else None
max_limit = intensity_limits["max"] if intensity_limits is not None and "max" in intensity_limits else None
colormesh = axis.pcolormesh(workspaces[0], vmin=min_limit, vmax=max_limit)

# logarithic case
scale_norm = "linear"
if log_scale:
scale_norm = "log"
colormesh = axis.pcolormesh(workspaces[0], vmin=min_limit, vmax=max_limit, norm=scale_norm)
axis.set_title(plot_title)
fig.canvas.manager.set_window_title(plot_title)

Expand All @@ -34,7 +49,7 @@ def do_colorfill_plot(workspaces, display_name=None, intensity_limits=None):


@manage_workspace_names
def do_slice_viewer(workspaces, parent=None, intensity_limits=None):
def do_slice_viewer(workspaces, parent=None, intensity_limits=None, log_scale=False):
"""Open sliceviewer for the provided workspace"""
presenter = SliceViewer(ws=workspaces[0], parent=parent)

Expand All @@ -48,21 +63,31 @@ def do_slice_viewer(workspaces, parent=None, intensity_limits=None):
if intensity_limits is not None and intensity_limits["max"] is not None
else presenter.view.data_view.colorbar.cmax_value
)
presenter.view.data_view.colorbar.cmin.setText(f"{min_limit:.4}")
presenter.view.data_view.colorbar.clim_changed()
presenter.view.data_view.colorbar.cmax.setText(f"{max_limit:.4}")
presenter.view.data_view.colorbar.clim_changed()

# y limits
if not (min_limit is None and max_limit is None):
presenter.view.data_view.colorbar.cmin.setText(f"{min_limit:.4}")
presenter.view.data_view.colorbar.clim_changed()
presenter.view.data_view.colorbar.cmax.setText(f"{max_limit:.4}")
presenter.view.data_view.colorbar.clim_changed()

# logarithic case
if log_scale:
norm_scale = "Log"
presenter.view.data_view.colorbar.norm.setCurrentText(norm_scale)
presenter.view.data_view.colorbar.norm_changed()

presenter.view.show()
return presenter.view


def do_default_plot(workspace, ndims, display_name=None, intensity_limits=None):
"""Create the default plot for the workspace and number of dimensions"""
log_scale = get_data("main_tab.plot", "logarithmic_intensity")
if ndims == 1:
return do_1d_plot([workspace], display_name, intensity_limits)
return do_1d_plot([workspace], display_name, intensity_limits, log_scale)
if ndims == 2:
return do_colorfill_plot([workspace], display_name, intensity_limits)
return do_colorfill_plot([workspace], display_name, intensity_limits, log_scale)
if ndims in (3, 4):
return do_slice_viewer([workspace], intensity_limits=intensity_limits)
return do_slice_viewer([workspace], intensity_limits=intensity_limits, log_scale=log_scale)
return None
16 changes: 16 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""pytest config"""
import os
from configparser import ConfigParser

import pytest
from mantid.simpleapi import mtd
from shiver import Shiver
Expand All @@ -16,3 +19,16 @@ def shiver_app():
def clear_ads():
"""clear the ADS after every test"""
mtd.clear()


@pytest.fixture(scope="session")
def user_conf_file(tmp_path_factory, request):
"""Fixture to create a custom configuration file in tmp_path"""
# custom configuration file
config_data = request.param
user_config = ConfigParser(allow_no_value=True)
user_config.read_string(config_data)
user_path = os.path.join(tmp_path_factory.mktemp("data"), "test_config.ini")
with open(user_path, "w", encoding="utf8") as config_file:
user_config.write(config_file)
return user_path
Loading