diff --git a/src/shiver/__init__.py b/src/shiver/__init__.py index 0af88963..743c996e 100644 --- a/src/shiver/__init__.py +++ b/src/shiver/__init__.py @@ -15,6 +15,7 @@ import shiver.models.makeslice # noqa: F401 pylint: disable=unused-import import shiver.models.convert_dgs_to_single_mde # noqa: F401 pylint: disable=unused-import import shiver.models.generate_dgs_mde # noqa: F401 pylint: disable=unused-import +from shiver.configuration import Configuration from .version import __version__ # make sure matplotlib is correctly set before we import shiver @@ -31,6 +32,17 @@ def main(): """ Main entry point for Qt application """ + config = Configuration() + if not config.is_valid(): + msg = ( + "Error with configuration settings!", + f"Check and update your file: {config.config_file_path}", + "with the latest settings found here:", + f"{config.template_file_path} and start the application again.", + ) + + print(" ".join(msg)) + return app = QApplication(sys.argv) window = Shiver() window.show() diff --git a/src/shiver/configuration.py b/src/shiver/configuration.py new file mode 100644 index 00000000..6395b04e --- /dev/null +++ b/src/shiver/configuration.py @@ -0,0 +1,98 @@ +"""Module to load the the settings from SHOME/.shiver/configuration.ini file + +Will fall back to a default""" +import os +import shutil + +from configparser import ConfigParser +from pathlib import Path +from mantid.kernel import Logger + +logger = Logger("SHIVER") + +# configuration settings file path +CONFIG_PATH_FILE = os.path.join(Path.home(), ".shiver", "configuration.ini") + + +class Configuration: + """Load and validate Configuration Data""" + + def __init__(self): + """initialization of configuration mechanism""" + # capture the current state + self.valid = False + + # locate the template configuration file + project_directory = Path(__file__).resolve().parent + self.template_file_path = os.path.join(project_directory, "configuration_template.ini") + + # retrieve the file path of the file + self.config_file_path = CONFIG_PATH_FILE + logger.information(f"{self.config_file_path} with be used") + + # if template conf file path exists + if os.path.exists(self.template_file_path): + # file does not exist create it from template + if not os.path.exists(self.config_file_path): + # if directory structure does not exist create it + if not os.path.exists(os.path.dirname(self.config_file_path)): + os.makedirs(os.path.dirname(self.config_file_path)) + shutil.copy2(self.template_file_path, self.config_file_path) + + self.config = ConfigParser() + # parse the file + try: + self.config.read(self.config_file_path) + # validate the file has the all the latest variables + self.validate() + except ValueError as err: + logger.error(str(err)) + logger.error(f"Problem with the file: {self.config_file_path}") + else: + logger.error(f"Template configuration file: {self.template_file_path} is missing!") + + def validate(self): + """validates that the fields exist at the config_file_path and writes any missing fields/data + using the template configuration file: configuration_template.ini as a guide""" + template_config = ConfigParser() + template_config.read(self.template_file_path) + for section in template_config.sections(): + # if section is missing + if section not in self.config.sections(): + # copy the whole section + self.config.add_section(section) + + for field in template_config[section]: + if field not in self.config[section]: + # copy the field + self.config[section][field] = template_config[section][field] + with open(self.config_file_path, "w", encoding="utf8") as config_file: + self.config.write(config_file) + self.valid = True + + def is_valid(self): + """returns the configuration state""" + return self.valid + + +def get_data(section, name=None): + """retrieves the configuration data for a variable with name""" + # default file path location + config_file_path = CONFIG_PATH_FILE + if os.path.exists(config_file_path): + config = ConfigParser() + # parse the file + config.read(config_file_path) + try: + if name: + value = config[section][name] + # in case of boolean string value cast it to bool + if value in ("True", "False"): + return value == "True" + return value + return config[section] + except KeyError as err: + # requested section/field do not exist + logger.error(str(err)) + return None + return None diff --git a/src/shiver/configuration_template.ini b/src/shiver/configuration_template.ini new file mode 100644 index 00000000..e0a01093 --- /dev/null +++ b/src/shiver/configuration_template.ini @@ -0,0 +1,16 @@ +[generate_tab.oncat] +#url to oncat portal +oncat_url = https://oncat.ornl.gov +#client id for on cat; it is unique for Shiver +client_id = 99025bb3-ce06-4f4b-bcf2-36ebf925cd1d +#the flag (bool: True/False) indicates the location of the names of the datasets (notes/comments vs. sequence name) +use_notes = False + +[main_tab.plot] +#options: full prints dimension data, name_only, workspace title, None: no title is printed in plots +title = True +#the flag (bool: True/False) indicates the plot scale (logarithmic or not) +logarithmic_intensity = False + +[global.other] +help_url = https://neutrons.github.io/Shiver/GUI/ diff --git a/src/shiver/models/help.py b/src/shiver/models/help.py index c03bf4dc..44ad46d7 100644 --- a/src/shiver/models/help.py +++ b/src/shiver/models/help.py @@ -1,10 +1,12 @@ """ single help module """ import webbrowser +from shiver.configuration import get_data def help_function(context): """ open a browser with the appropriate help page """ + help_url = get_data("global.other", "help_url") if context: - webbrowser.open("https://neutrons.github.io/Shiver/GUI/") + webbrowser.open(help_url) diff --git a/src/shiver/views/histogram.py b/src/shiver/views/histogram.py index 5933de28..b5c22a89 100644 --- a/src/shiver/views/histogram.py +++ b/src/shiver/views/histogram.py @@ -6,6 +6,7 @@ ) from qtpy.QtCore import Signal +from shiver.configuration import get_data from .loading_buttons import LoadingButtons from .histogram_parameters import HistogramParameter from .workspace_tables import InputWorkspaces, HistogramWorkspaces @@ -78,13 +79,7 @@ def make_slice_finish(self, ws_name, ndims): self.makeslice_finish_signal.emit(ws_name, ndims) def _make_slice_finish(self, ws_name, ndims): - display_name = self.plot_display_name_callback(ws_name, ndims) - min_intensity = self.histogram_parameters.dimensions.intensity_min.text() - max_intensity = self.histogram_parameters.dimensions.intensity_max.text() - intensity_limits = { - "min": float(min_intensity) if min_intensity != "" else None, - "max": float(max_intensity) if max_intensity != "" else None, - } + display_name, intensity_limits = self.get_plot_data(ws_name, ndims) do_default_plot(ws_name, ndims, display_name, intensity_limits) self.histogram_workspaces.histogram_workspaces.set_selected(ws_name) @@ -203,3 +198,19 @@ def unset_all(self): self.input_workspaces.mde_workspaces.unset_all() self.input_workspaces.norm_workspaces.deselect_all() self.set_field_invalid_state(self.input_workspaces.mde_workspaces) + + def get_plot_data(self, ws_name, ndims): + """Get display name and intensities data for plotting.""" + plot_title_preference = get_data("main_tab.plot", "title") + display_name = None + if plot_title_preference == "full": + display_name = self.plot_display_name_callback(ws_name, ndims) + if plot_title_preference == "name_only": + display_name = ws_name.name() + min_intensity = self.histogram_parameters.dimensions.intensity_min.text() + max_intensity = self.histogram_parameters.dimensions.intensity_max.text() + intensity_limits = { + "min": float(min_intensity) if min_intensity != "" else None, + "max": float(max_intensity) if max_intensity != "" else None, + } + return (display_name, intensity_limits) diff --git a/src/shiver/views/mainwindow.py b/src/shiver/views/mainwindow.py index f1e4aa4f..90b552ff 100644 --- a/src/shiver/views/mainwindow.py +++ b/src/shiver/views/mainwindow.py @@ -21,7 +21,6 @@ def __init__(self, parent=None): super().__init__(parent) self.tabs = QTabWidget() - histogram = Histogram(self) histogram_model = HistogramModel() self.histogram_presenter = HistogramPresenter(histogram, histogram_model) diff --git a/src/shiver/views/oncat.py b/src/shiver/views/oncat.py index 6f0ca1d0..c0141c96 100644 --- a/src/shiver/views/oncat.py +++ b/src/shiver/views/oncat.py @@ -15,11 +15,7 @@ QDoubleSpinBox, ) from qtpy.QtCore import QTimer - -# CONSTANTS -# NOTE: the client ID is unique for Shiver -ONCAT_URL = "https://oncat.ornl.gov" -CLIENT_ID = "99025bb3-ce06-4f4b-bcf2-36ebf925cd1d" +from shiver.configuration import get_data class OncatToken: @@ -53,15 +49,20 @@ def write_token(self, token): class OnCatAgent: """Agent to interface with OnCat""" - def __init__(self) -> None: + def __init__(self, use_notes=False) -> None: """Initialize OnCat agent""" + # get configuration settings + self._use_notes = use_notes + self._oncat_url = get_data("generate_tab.oncat", "oncat_url") + self._client_id = get_data("generate_tab.oncat", "client_id") + user_home_dir = os.path.expanduser("~") self._token = OncatToken( os.path.abspath(f"{user_home_dir}/.shiver/oncat_token.json"), ) self._agent = pyoncat.ONCat( - ONCAT_URL, - client_id=CLIENT_ID, + self._oncat_url, + client_id=self._client_id, # Pass in token getter/setter callbacks here: token_getter=self._token.read_token, token_setter=self._token.write_token, @@ -149,6 +150,7 @@ def get_datasets(self, facility: str, instrument: str, ipts: int) -> list: self._agent, ipts_number=ipts, instrument=instrument, + use_notes=self._use_notes, facility=facility, ) @@ -275,8 +277,9 @@ def __init__(self, parent=None): # error message callback self.error_message_callback = None + self.use_notes = get_data("generate_tab.oncat", "use_notes") # OnCat agent - self.oncat_agent = OnCatAgent() + self.oncat_agent = OnCatAgent(self.use_notes) # Sync with remote self.sync_with_remote(refresh=True) @@ -326,6 +329,7 @@ def get_suggested_selected_files(self) -> list: login=self.oncat_agent.get_agent_instance(), ipts_number=self.get_ipts_number(), instrument=self.get_instrument(), + use_notes=self.use_notes, facility=self.get_facility(), group_by_angle=group_by_angle, angle_bin=self.angle_target.value(), diff --git a/src/shiver/views/plots.py b/src/shiver/views/plots.py index f713cdf5..3d37d9f5 100644 --- a/src/shiver/views/plots.py +++ b/src/shiver/views/plots.py @@ -2,29 +2,44 @@ import matplotlib.pyplot as plt from mantidqt.widgets.sliceviewer.presenters.presenter import SliceViewer from mantidqt.plotting.functions import manage_workspace_names, plot_md_ws_from_names +from shiver.configuration import get_data @manage_workspace_names -def do_1d_plot(workspaces, display_name, intensity_limits=None): +def do_1d_plot(workspaces, display_name, intensity_limits=None, log_scale=False): """Create an 1D plot for the provided workspace""" fig = plot_md_ws_from_names(workspaces, False, False) plot_title = display_name if display_name else workspaces[0].name() min_limit = intensity_limits["min"] if intensity_limits is not None and "min" in intensity_limits else None max_limit = intensity_limits["max"] if intensity_limits is not None and "max" in intensity_limits else None - fig.axes[0].set_ylim(min_limit, max_limit) + + # logarithic case + if log_scale: + fig.axes[0].set_yscale("log") + + # y limits + if not (min_limit is None and max_limit is None): + fig.axes[0].set_ylim(min_limit, max_limit) + fig.canvas.manager.set_window_title(display_name) fig.axes[0].set_title(plot_title) return fig @manage_workspace_names -def do_colorfill_plot(workspaces, display_name=None, intensity_limits=None): +def do_colorfill_plot(workspaces, display_name=None, intensity_limits=None, log_scale=False): """Create a colormesh plot for the provided workspace""" fig, axis = plt.subplots(subplot_kw={"projection": "mantid"}) plot_title = display_name or workspaces[0].name() + # y limits min_limit = intensity_limits["min"] if intensity_limits is not None and "min" in intensity_limits else None max_limit = intensity_limits["max"] if intensity_limits is not None and "max" in intensity_limits else None - colormesh = axis.pcolormesh(workspaces[0], vmin=min_limit, vmax=max_limit) + + # logarithic case + scale_norm = "linear" + if log_scale: + scale_norm = "log" + colormesh = axis.pcolormesh(workspaces[0], vmin=min_limit, vmax=max_limit, norm=scale_norm) axis.set_title(plot_title) fig.canvas.manager.set_window_title(plot_title) @@ -34,7 +49,7 @@ def do_colorfill_plot(workspaces, display_name=None, intensity_limits=None): @manage_workspace_names -def do_slice_viewer(workspaces, parent=None, intensity_limits=None): +def do_slice_viewer(workspaces, parent=None, intensity_limits=None, log_scale=False): """Open sliceviewer for the provided workspace""" presenter = SliceViewer(ws=workspaces[0], parent=parent) @@ -48,10 +63,19 @@ def do_slice_viewer(workspaces, parent=None, intensity_limits=None): if intensity_limits is not None and intensity_limits["max"] is not None else presenter.view.data_view.colorbar.cmax_value ) - presenter.view.data_view.colorbar.cmin.setText(f"{min_limit:.4}") - presenter.view.data_view.colorbar.clim_changed() - presenter.view.data_view.colorbar.cmax.setText(f"{max_limit:.4}") - presenter.view.data_view.colorbar.clim_changed() + + # y limits + if not (min_limit is None and max_limit is None): + presenter.view.data_view.colorbar.cmin.setText(f"{min_limit:.4}") + presenter.view.data_view.colorbar.clim_changed() + presenter.view.data_view.colorbar.cmax.setText(f"{max_limit:.4}") + presenter.view.data_view.colorbar.clim_changed() + + # logarithic case + if log_scale: + norm_scale = "Log" + presenter.view.data_view.colorbar.norm.setCurrentText(norm_scale) + presenter.view.data_view.colorbar.norm_changed() presenter.view.show() return presenter.view @@ -59,10 +83,11 @@ def do_slice_viewer(workspaces, parent=None, intensity_limits=None): def do_default_plot(workspace, ndims, display_name=None, intensity_limits=None): """Create the default plot for the workspace and number of dimensions""" + log_scale = get_data("main_tab.plot", "logarithmic_intensity") if ndims == 1: - return do_1d_plot([workspace], display_name, intensity_limits) + return do_1d_plot([workspace], display_name, intensity_limits, log_scale) if ndims == 2: - return do_colorfill_plot([workspace], display_name, intensity_limits) + return do_colorfill_plot([workspace], display_name, intensity_limits, log_scale) if ndims in (3, 4): - return do_slice_viewer([workspace], intensity_limits=intensity_limits) + return do_slice_viewer([workspace], intensity_limits=intensity_limits, log_scale=log_scale) return None diff --git a/tests/conftest.py b/tests/conftest.py index c22bff39..466adb3f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,7 @@ """pytest config""" +import os +from configparser import ConfigParser + import pytest from mantid.simpleapi import mtd from shiver import Shiver @@ -16,3 +19,16 @@ def shiver_app(): def clear_ads(): """clear the ADS after every test""" mtd.clear() + + +@pytest.fixture(scope="session") +def user_conf_file(tmp_path_factory, request): + """Fixture to create a custom configuration file in tmp_path""" + # custom configuration file + config_data = request.param + user_config = ConfigParser(allow_no_value=True) + user_config.read_string(config_data) + user_path = os.path.join(tmp_path_factory.mktemp("data"), "test_config.ini") + with open(user_path, "w", encoding="utf8") as config_file: + user_config.write(config_file) + return user_path diff --git a/tests/models/test_configuration.py b/tests/models/test_configuration.py new file mode 100644 index 00000000..12573514 --- /dev/null +++ b/tests/models/test_configuration.py @@ -0,0 +1,207 @@ +"""Tests for Configuration mechanism""" +import os +from configparser import ConfigParser +from pathlib import Path + +import pytest +from shiver import main +from shiver.configuration import Configuration, get_data + + +def test_config_path_default(): + """Test configuration default file path""" + config = Configuration() + assert config.config_file_path.endswith(".shiver/configuration.ini") is True + # check the valid state + assert config.is_valid() + assert config.valid == config.is_valid() + + +def test_config_path_in_folder(monkeypatch, tmp_path): + """Test configuration configuration user-defined file path that does not exist in a new directory""" + + user_path = os.path.join(tmp_path, "temp2", "test_config.ini") + assert not os.path.exists(user_path) + + monkeypatch.setattr("shiver.configuration.CONFIG_PATH_FILE", user_path) + + config = Configuration() + # check if the file exists now + assert os.path.exists(user_path) + assert config.is_valid() + + +def test_config_path_does_not_exist(monkeypatch, tmp_path): + """Test configuration user-defined file path that does not exist""" + user_path = os.path.join(tmp_path, "test_config.ini") + assert not os.path.exists(user_path) + + monkeypatch.setattr("shiver.configuration.CONFIG_PATH_FILE", user_path) + + config = Configuration() + # check if the file is exists now + assert os.path.exists(user_path) + assert config.is_valid() + + +@pytest.mark.parametrize( + "user_conf_file", + [ + """ + [generate_tab.oncat] + oncat_url = https://oncat.ornl.gov + client_id = 99025bb3-ce06-4f4b-bcf2-36ebf925cd1d + use_notes = False + + """ + ], + indirect=True, +) +def test_field_validate_fields_exist(monkeypatch, user_conf_file): + """Test configuration validate all fields exist with the same values as templates + Note: update the parameters if the fields increase""" + # read the custom configuration file + monkeypatch.setattr("shiver.configuration.CONFIG_PATH_FILE", user_conf_file) + user_config = Configuration() + + assert user_config.config_file_path.endswith(user_conf_file) is True + # check if the file exists + assert os.path.exists(user_conf_file) + + # check all fields are the same as the configuration template file + project_directory = Path(__file__).resolve().parent.parent.parent + template_file_path = os.path.join(project_directory, "src", "shiver", "configuration_template.ini") + template_config = ConfigParser() + template_config.read(template_file_path) + for section in user_config.config.sections(): + for field in user_config.config[section]: + assert user_config.config[section][field] == template_config[section][field] + + +@pytest.mark.parametrize( + "user_conf_file", + [ + """ + [generate_tab.oncat] + oncat_url = test_url + client_id = 0000-0000 + use_notes = True + """ + ], + indirect=True, +) +def test_field_validate_fields_same(monkeypatch, user_conf_file): + """Test configuration validate all fields exist with their values; different from the template""" + + # read the custom configuration file + monkeypatch.setattr("shiver.configuration.CONFIG_PATH_FILE", user_conf_file) + user_config = Configuration() + + # check if the file exists + assert os.path.exists(user_conf_file) + assert user_config.config_file_path == user_conf_file + + # check all field values have the same values as the user configuration file + assert get_data("generate_tab.oncat", "oncat_url") == "test_url" + assert get_data("generate_tab.oncat", "client_id") == "0000-0000" + # cast to bool + assert get_data("generate_tab.oncat", "use_notes") is True + + +@pytest.mark.parametrize( + "user_conf_file", + [ + """ + [generate_tab.oncat] + client_id = 0000-0000 + """ + ], + indirect=True, +) +def test_field_validate_fields_missing(monkeypatch, user_conf_file): + """Test configuration validate missing fields added from the template""" + + # read the custom configuration file + monkeypatch.setattr("shiver.configuration.CONFIG_PATH_FILE", user_conf_file) + user_config = Configuration() + + # check if the file exists + assert os.path.exists(user_conf_file) + assert user_config.config_file_path == user_conf_file + + # check all field values have the same values as the user configuration file + assert get_data("generate_tab.oncat", "oncat_url") == "https://oncat.ornl.gov" + assert get_data("generate_tab.oncat", "client_id") == "0000-0000" + assert get_data("generate_tab.oncat", "use_notes") is False + + +@pytest.mark.parametrize("user_conf_file", ["""[generate_tab.oncat]"""], indirect=True) +def test_get_data_valid(monkeypatch, user_conf_file): + """Test configuration get_data - valid""" + + monkeypatch.setattr("shiver.configuration.CONFIG_PATH_FILE", user_conf_file) + config = Configuration() + assert config.config_file_path.endswith(user_conf_file) is True + # get the data + # section + assert len(get_data("generate_tab.oncat", "")) == 3 + # fields + assert get_data("generate_tab.oncat", "oncat_url") == "https://oncat.ornl.gov" + assert get_data("generate_tab.oncat", "client_id") == "99025bb3-ce06-4f4b-bcf2-36ebf925cd1d" + assert get_data("generate_tab.oncat", "use_notes") is False + + assert config.is_valid() + + +@pytest.mark.parametrize( + "user_conf_file", + [ + """ + [generate_tab.oncat] + oncat_url = test_url + client_id = 0000-0000 + use_notes = 1 + """ + ], + indirect=True, +) +def test_get_data_invalid(monkeypatch, user_conf_file): + """Test configuration get_data - invalid""" + # read the custom configuration file + monkeypatch.setattr("shiver.configuration.CONFIG_PATH_FILE", user_conf_file) + config = Configuration() + assert config.config_file_path.endswith(user_conf_file) is True + + # section + assert get_data("section_not_here", "") is None + + assert len(get_data("generate_tab.oncat", "")) == 3 + # field + assert get_data("generate_tab.oncat", "field_not_here") is None + + +@pytest.mark.parametrize( + "user_conf_file", + [ + """ + [main_tab.plot] + title = name_only + logarithmic_intensity = False + """ + ], + indirect=True, +) +def test_conf_init_invalid(capsys, user_conf_file, monkeypatch): + """Test starting the app with invalid configuration""" + + # mock conf info + monkeypatch.setattr("shiver.configuration.CONFIG_PATH_FILE", user_conf_file) + + def mock_is_valid(self): # pylint: disable=unused-argument + return False + + monkeypatch.setattr("shiver.configuration.Configuration.is_valid", mock_is_valid) + + main() + captured = capsys.readouterr() + assert captured[0].startswith("Error with configuration settings!") diff --git a/tests/views/test_oncat.py b/tests/views/test_oncat.py index 14350cfc..8d818c09 100644 --- a/tests/views/test_oncat.py +++ b/tests/views/test_oncat.py @@ -1,6 +1,10 @@ #!/usr/bin/env python # pylint: disable=all """Test the views for the ONCat application.""" +import os +from pathlib import Path +from configparser import ConfigParser + import pytest import pyoncat from qtpy.QtWidgets import QGroupBox @@ -12,11 +16,18 @@ get_data_from_oncat, get_dataset_names, get_dataset_info, - ONCAT_URL, - CLIENT_ID, ) +def get_configuration_settings(): + """get configuration settings from the configuration_template file""" + template_config = ConfigParser() + project_directory = Path(__file__).resolve().parent.parent.parent + template_file_path = os.path.join(project_directory, "src", "shiver", "configuration_template.ini") + template_config.read(template_file_path) + return template_config + + class MockRecord: def __init__(self, *args, **kwargs) -> None: pass @@ -62,8 +73,20 @@ def mock_get_dataset_names(*args, **kwargs): # mockey patch class pyoncat.ONCat monkeypatch.setattr("pyoncat.ONCat", MockONcat) + # mockey patch get_settings data + def mock_get_data(*args, **kwargs): + # get configuration setting from the configuration_template file + template_config = get_configuration_settings() + return template_config[args[1], args[2]] + + monkeypatch.setattr("shiver.configuration.get_data", mock_get_data) + # test the class agent = OnCatAgent() + # test configuration settings are stored from template configuration file + assert agent._oncat_url == "https://oncat.ornl.gov" + assert agent._client_id == "99025bb3-ce06-4f4b-bcf2-36ebf925cd1d" + assert agent._use_notes is False # test login agent.login("test_login", "test_password") # test is_connected @@ -89,9 +112,12 @@ def login(self, *args, **kwargs) -> None: class DummyOnCatAgent: def __init__(self) -> None: + template_config = get_configuration_settings() + oncat_url = template_config["generate_tab.oncat"]["oncat_url"] + client_id = template_config["generate_tab.oncat"]["client_id"] self._agent = pyoncat.ONCat( - ONCAT_URL, - client_id=CLIENT_ID, + oncat_url, + client_id=client_id, flow=pyoncat.RESOURCE_OWNER_CREDENTIALS_FLOW, ) @@ -155,12 +181,24 @@ def error_message_callback(msg): assert err_msgs[-1] == "Invalid username or password. Please try again." -def test_oncat(monkeypatch, qtbot): +@pytest.mark.parametrize( + "user_conf_file", + [ + """ + [generate_tab.oncat] + oncat_url = test_url + client_id = 0000-0000 + use_notes = False + """ + ], + indirect=True, +) +def test_oncat(monkeypatch, user_conf_file, qtbot): """Test the Oncat class.""" # mockpatch OnCatAgent class MockOnCatAgent: - def __init__(self) -> None: + def __init__(self, use_notes) -> None: pass def login(self, *args, **kwargs) -> None: @@ -191,6 +229,9 @@ def mock_get_dataset_info(*args, **kwargs): monkeypatch.setattr("shiver.views.oncat.get_dataset_info", mock_get_dataset_info) + # mock get_oncat_url, client_id and use_notes info + monkeypatch.setattr("shiver.configuration.CONFIG_PATH_FILE", user_conf_file) + err_msgs = [] def error_message_callback(msg): @@ -201,6 +242,8 @@ def error_message_callback(msg): oncat.connect_error_callback(error_message_callback) qtbot.addWidget(oncat) oncat.show() + # test use_notes are saved from configuration settings + assert oncat.use_notes is False # test connect status check assert oncat.connected_to_oncat is True # test get_suggested_path diff --git a/tests/views/test_plots.py b/tests/views/test_plots.py index 52f54d24..70e5cccf 100644 --- a/tests/views/test_plots.py +++ b/tests/views/test_plots.py @@ -1,16 +1,33 @@ #!/usr/bin/env python """UI tests for Plots""" +import pytest + # pylint: disable=no-name-in-module from mantid.simpleapi import ( mtd, CreateMDHistoWorkspace, ) from shiver.views.plots import do_default_plot +from shiver.views.histogram import Histogram -def test_plot1d(qtbot): +@pytest.mark.parametrize( + "user_conf_file", + [ + """ + [main_tab.plot] + title = name_only + logarithmic_intensity = False + """ + ], + indirect=True, +) +def test_plot1d(qtbot, user_conf_file, monkeypatch): """Test for 1D plot with intensities and display title""" + # mock get_oncat_url, client_id and use_notes info + monkeypatch.setattr("shiver.configuration.CONFIG_PATH_FILE", user_conf_file) + # clear mantid workspace mtd.clear() @@ -30,12 +47,66 @@ def test_plot1d(qtbot): fig = do_default_plot(workspace, 1, title, {"min": intensity_min, "max": intensity_max}) assert fig.axes[0].get_title() == title assert fig.axes[0].get_ylim() == (intensity_min, intensity_max) - qtbot.wait(500) + assert fig.axes[0].get_yscale() == "linear" + + qtbot.wait(100) + + +@pytest.mark.parametrize( + "user_conf_file", + [ + """ + [main_tab.plot] + title = name_only + logarithmic_intensity = True + """ + ], + indirect=True, +) +def test_plot1d_scale(qtbot, user_conf_file, monkeypatch): + """Test for 1D plot with log scale""" + + # mock get_oncat_url, client_id and use_notes info + monkeypatch.setattr("shiver.configuration.CONFIG_PATH_FILE", user_conf_file) + + # clear mantid workspace + mtd.clear() + workspace = CreateMDHistoWorkspace( + Dimensionality=1, + Extents="-3,3", + SignalInput=range(0, 10), + ErrorInput=range(0, 10), + NumberOfBins="10", + Names="Dim1", + Units="MomentumTransfer", + ) -def test_plot2d(qtbot): + intensity_min = None + intensity_max = None + title = "1D Plot" + fig = do_default_plot(workspace, 1, title, {"min": intensity_min, "max": intensity_max}) + assert fig.axes[0].get_yscale() == "log" + qtbot.wait(100) + + +@pytest.mark.parametrize( + "user_conf_file", + [ + """ + [main_tab.plot] + title = name_only + logarithmic_intensity = False + """ + ], + indirect=True, +) +def test_plot2d(qtbot, user_conf_file, monkeypatch): """Test for 2D plot with intensities and display title""" + # mock get_oncat_url, client_id and use_notes info + monkeypatch.setattr("shiver.configuration.CONFIG_PATH_FILE", user_conf_file) + # clear mantid workspace mtd.clear() @@ -54,13 +125,70 @@ def test_plot2d(qtbot): title = "2D Plot" fig = do_default_plot(workspace, 2, title, {"min": intensity_min, "max": intensity_max}) assert fig.axes[0].get_title() == title - assert fig.axes[0].collections[0].get_clim() == (intensity_min, intensity_max) - qtbot.wait(500) + assert fig.axes[1].collections[1].get_clim() == (intensity_min, intensity_max) + assert fig.axes[1].get_yscale() == "linear" + + qtbot.wait(100) + + +@pytest.mark.parametrize( + "user_conf_file", + [ + """ + [main_tab.plot] + title = name_only + logarithmic_intensity = True + """ + ], + indirect=True, +) +def test_plot2d_scale(qtbot, user_conf_file, monkeypatch): + """Test for 2D plot with log scale""" + # mock get_oncat_url, client_id and use_notes info + monkeypatch.setattr("shiver.configuration.CONFIG_PATH_FILE", user_conf_file) -def test_plot3d(qtbot): + # clear mantid workspace + mtd.clear() + + workspace = CreateMDHistoWorkspace( + Dimensionality=2, + Extents="-2,2,-5,5", + SignalInput=range(0, 100), + ErrorInput=range(0, 100), + NumberOfBins="10,10", + Names="Dim1,Dim2", + Units="Momentum,Energy", + ) + + intensity_min = 0.001 + intensity_max = 1 + title = "2D Plot" + fig = do_default_plot(workspace, 2, title, {"min": intensity_min, "max": intensity_max}) + assert fig.axes[0].get_title() == title + assert fig.axes[1].collections[1].get_clim() == (intensity_min, intensity_max) + assert fig.axes[1].get_yscale() == "log" + + qtbot.wait(100) + + +@pytest.mark.parametrize( + "user_conf_file", + [ + """ + [main_tab.plot] + title = name_only + logarithmic_intensity = False + """ + ], + indirect=True, +) +def test_plot3d(qtbot, user_conf_file, monkeypatch): """Test for 3D plot with intensities""" + # mock get_oncat_url, client_id and use_notes info + monkeypatch.setattr("shiver.configuration.CONFIG_PATH_FILE", user_conf_file) + # clear mantid workspace mtd.clear() @@ -80,13 +208,64 @@ def test_plot3d(qtbot): view = do_default_plot(workspace, 3, title, {"min": intensity_min, "max": intensity_max}) assert view.data_view.colorbar.cmin_value == intensity_min assert view.data_view.colorbar.cmax_value == intensity_max + assert view.data_view.colorbar.norm.currentText() == "Linear" + + qtbot.wait(100) + + +@pytest.mark.parametrize( + "user_conf_file", + [ + """ + [main_tab.plot] + title = name_only + logarithmic_intensity = True + """ + ], + indirect=True, +) +def test_plot3d_scale(qtbot, user_conf_file, monkeypatch): + """Test for 3D plot with log scale""" + + # mock get_oncat_url, client_id and use_notes info + monkeypatch.setattr("shiver.configuration.CONFIG_PATH_FILE", user_conf_file) + + # clear mantid workspace + mtd.clear() - qtbot.wait(500) + workspace = CreateMDHistoWorkspace( + Dimensionality=3, + Extents="-5,5,-10,10,-20,20", + SignalInput=range(0, 1000), + ErrorInput=range(0, 1000), + NumberOfBins="10,10,10", + Names="Dim1,Dim2,Dim3", + Units="Energy,Momentum,Other", + ) + title = None + view = do_default_plot(workspace, 3, title) + assert view.data_view.colorbar.norm.currentText() == "Log" + qtbot.wait(100) -def test_plot4d(qtbot): + +@pytest.mark.parametrize( + "user_conf_file", + [ + """ + [main_tab.plot] + title = name_only + logarithmic_intensity = False + """ + ], + indirect=True, +) +def test_plot4d(qtbot, user_conf_file, monkeypatch): """Test for 4D plot with intensities""" + # mock get_oncat_url, client_id and use_notes info + monkeypatch.setattr("shiver.configuration.CONFIG_PATH_FILE", user_conf_file) + # clear mantid workspace mtd.clear() @@ -106,13 +285,68 @@ def test_plot4d(qtbot): view = do_default_plot(workspace, 4, title, {"min": intensity_min, "max": intensity_max}) assert view.data_view.colorbar.cmin_value == intensity_min assert view.data_view.colorbar.cmax_value == intensity_max + assert view.data_view.colorbar.norm.currentText() == "Linear" + + qtbot.wait(100) - qtbot.wait(500) +@pytest.mark.parametrize( + "user_conf_file", + [ + """ + [main_tab.plot] + title = name_only + logarithmic_intensity = True + """ + ], + indirect=True, +) +def test_plot4d_invalid_scale(qtbot, user_conf_file, monkeypatch): + """Test for 4D plot with invalid intensities""" + + # mock get_oncat_url, client_id and use_notes info + monkeypatch.setattr("shiver.configuration.CONFIG_PATH_FILE", user_conf_file) + + # clear mantid workspace + mtd.clear() -def test_plot5d(qtbot): + workspace = CreateMDHistoWorkspace( + Dimensionality=4, + Extents="-5,5,-10,10,-20,20,-30,30", + SignalInput=range(0, 10000), + ErrorInput=range(0, 10000), + NumberOfBins="10,10,10,10", + Names="Dim1,Dim2,Dim3,Dim4", + Units="M,E,O,EX", + ) + + intensity_min = -10.34 + intensity_max = 12.1 + title = None + view = do_default_plot(workspace, 4, title, {"min": intensity_min, "max": intensity_max}) + qtbot.wait(200) + # mantid plot updates user values if invalid are passed + assert view.data_view.colorbar.cmin_value != intensity_min + assert view.data_view.colorbar.norm.currentText() == "Log" + + +@pytest.mark.parametrize( + "user_conf_file", + [ + """ + [main_tab.plot] + title = name_only + logarithmic_intensity = False + """ + ], + indirect=True, +) +def test_plot5d(qtbot, user_conf_file, monkeypatch): """Test for 5D plot with intensities -invalid""" + # mock get_oncat_url, client_id and use_notes info + monkeypatch.setattr("shiver.configuration.CONFIG_PATH_FILE", user_conf_file) + # clear mantid workspace mtd.clear() @@ -132,3 +366,133 @@ def test_plot5d(qtbot): view = do_default_plot(workspace, 5, title, {"min": intensity_min, "max": intensity_max}) assert view is None qtbot.wait(100) + + +@pytest.mark.parametrize( + "user_conf_file", + [ + """ + [main_tab.plot] + title = name_only + logarithmic_intensity = False + """ + ], + indirect=True, +) +def test_plot_data_name_only(qtbot, user_conf_file, monkeypatch): + """Test plot inputs with name_only title""" + + # mock get_oncat_url, client_id and use_notes info + monkeypatch.setattr("shiver.configuration.CONFIG_PATH_FILE", user_conf_file) + histogram = Histogram() + # clear mantid workspace + mtd.clear() + + workspace = CreateMDHistoWorkspace( + Dimensionality=4, + Extents="-3,3,-10,10,-20,20,-30,30", + SignalInput=range(0, 10000), + ErrorInput=range(0, 10000), + NumberOfBins="10,10,10,10", + Names="Dim1,Dim2,Dim3,Dim4", + Units="EnergyT,MomentumT,Other,Extra", + ) + + data = histogram.get_plot_data(workspace, 4) + assert data[0] == "workspace" + assert data[1]["min"] is None + assert data[1]["max"] is None + + qtbot.wait(100) + + +@pytest.mark.parametrize( + "user_conf_file", + [ + """ + [main_tab.plot] + title = full + logarithmic_intensity = False + """ + ], + indirect=True, +) +def test_plot_data_full(qtbot, user_conf_file, monkeypatch): + """Test plot inputs with a customized full title and intensitites""" + + # mock plot info + monkeypatch.setattr("shiver.configuration.CONFIG_PATH_FILE", user_conf_file) + histogram = Histogram() + + def monk_display_name(ws_name, ndims): + return f"full {ws_name}: {ndims}" + + histogram.plot_display_name_callback = monk_display_name + + # clear mantid workspace + mtd.clear() + + workspace = CreateMDHistoWorkspace( + Dimensionality=4, + Extents="-3,3,-10,10,-20,20,-30,30", + SignalInput=range(0, 10000), + ErrorInput=range(0, 10000), + NumberOfBins="10,10,10,10", + Names="Dim1,Dim2,Dim3,Dim4", + Units="EnergyT,MomentumT,Other,Extra", + ) + + intensity_min = -10.34 + intensity_max = 12.09 + histogram.histogram_parameters.dimensions.intensity_min.setText(str(intensity_min)) + histogram.histogram_parameters.dimensions.intensity_max.setText(str(intensity_max)) + + data = histogram.get_plot_data(workspace, 4) + assert data[0] == "full workspace: 4" + assert data[1]["min"] == intensity_min + assert data[1]["max"] == intensity_max + + qtbot.wait(100) + + +@pytest.mark.parametrize( + "user_conf_file", + [ + """ + [main_tab.plot] + title = None + logarithmic_intensity = False + """ + ], + indirect=True, +) +def test_plot_data_none(qtbot, user_conf_file, monkeypatch): + """Test plot inputs with no title and one intensity""" + + # mock plot info + monkeypatch.setattr("shiver.configuration.CONFIG_PATH_FILE", user_conf_file) + histogram = Histogram() + + # clear mantid workspace + mtd.clear() + + workspace = CreateMDHistoWorkspace( + Dimensionality=4, + Extents="-3,3,-10,10,-20,20,-30,30", + SignalInput=range(0, 10000), + ErrorInput=range(0, 10000), + NumberOfBins="10,10,10,10", + Names="Dim1,Dim2,Dim3,Dim4", + Units="EnergyT,MomentumT,Other,Extra", + ) + + intensity_max = 12.09 + + histogram.histogram_parameters.dimensions.intensity_max.setText(str(intensity_max)) + + data = histogram.get_plot_data(workspace, 4) + assert data[0] is None + assert data[1]["min"] is None + assert data[1]["max"] == intensity_max + + qtbot.wait(100)