From ac2cd8d31bc2cbde61fda289f1a932ed4fad4305 Mon Sep 17 00:00:00 2001 From: Christopher Sherman Date: Thu, 18 Jan 2024 09:59:16 -0800 Subject: [PATCH] Enabling code formatting checks, applying yapf --- docs/conf.py | 35 +-- .../geos_ats/command_line_parsers.py | 116 ++++++-- geos_ats_package/geos_ats/common_utilities.py | 32 ++- .../geos_ats/configuration_record.py | 241 +++++++++++----- .../geos_ats/environment_setup.py | 29 +- geos_ats_package/geos_ats/geos_ats_debug.py | 6 +- .../geos_ats/helpers/curve_check.py | 114 +++++--- .../geos_ats/helpers/permute_array.py | 18 +- .../geos_ats/helpers/restart_check.py | 270 ++++++++++++------ .../geos_ats/machine_utilities.py | 4 +- .../geos_ats/machines/batchGeosatsMoab.py | 22 +- .../geos_ats/machines/bgqos_0_ASQ.py | 59 ++-- geos_ats_package/geos_ats/machines/darwin.py | 2 +- .../geosAtsSlurmProcessorScheduled.py | 12 +- geos_ats_package/geos_ats/machines/lassen.py | 17 +- geos_ats_package/geos_ats/machines/nersc.py | 9 +- geos_ats_package/geos_ats/machines/openmpi.py | 16 +- geos_ats_package/geos_ats/machines/summit.py | 17 +- geos_ats_package/geos_ats/main.py | 103 ++++--- geos_ats_package/geos_ats/reporting.py | 165 +++++++---- geos_ats_package/geos_ats/rules.py | 30 +- geos_ats_package/geos_ats/scheduler.py | 20 +- geos_ats_package/geos_ats/suite_settings.py | 2 +- geos_ats_package/geos_ats/test_builder.py | 15 +- geos_ats_package/geos_ats/test_case.py | 101 ++++--- geos_ats_package/geos_ats/test_modifier.py | 2 +- geos_ats_package/geos_ats/test_steps.py | 160 +++++++---- geos_ats_package/geos_ats/user_utilities.py | 2 +- geosx_mesh_doctor/checks/check_fractures.py | 71 +++-- geosx_mesh_doctor/checks/collocated_nodes.py | 12 +- geosx_mesh_doctor/checks/element_volumes.py | 17 +- .../checks/fix_elements_orderings.py | 14 +- geosx_mesh_doctor/checks/generate_cube.py | 32 ++- .../checks/generate_fractures.py | 180 +++++++----- .../checks/generate_global_ids.py | 9 +- geosx_mesh_doctor/checks/non_conformal.py | 98 ++++--- geosx_mesh_doctor/checks/reorient_mesh.py | 52 ++-- .../checks/self_intersecting_elements.py | 32 +-- .../checks/supported_elements.py | 74 ++--- geosx_mesh_doctor/checks/triangle_distance.py | 46 +-- geosx_mesh_doctor/checks/vtk_polyhedron.py | 30 +- geosx_mesh_doctor/checks/vtk_utils.py | 44 +-- geosx_mesh_doctor/mesh_doctor.py | 4 +- geosx_mesh_doctor/parsing/__init__.py | 1 - geosx_mesh_doctor/parsing/cli_parsing.py | 27 +- .../parsing/collocated_nodes_parsing.py | 25 +- .../parsing/element_volumes_parsing.py | 29 +- .../parsing/fix_elements_orderings_parsing.py | 23 +- .../parsing/generate_cube_parsing.py | 61 ++-- .../parsing/generate_fractures_parsing.py | 52 ++-- .../parsing/generate_global_ids_parsing.py | 15 +- .../parsing/non_conformal_parsing.py | 40 +-- .../self_intersecting_elements_parsing.py | 35 ++- .../parsing/supported_elements_parsing.py | 55 ++-- .../parsing/vtk_output_parsing.py | 29 +- geosx_mesh_doctor/register.py | 20 +- geosx_mesh_doctor/tests/test_cli_parsing.py | 19 +- .../tests/test_collocated_nodes.py | 3 +- .../tests/test_element_volumes.py | 6 +- geosx_mesh_doctor/tests/test_generate_cube.py | 26 +- .../tests/test_generate_fractures.py | 134 ++++++--- geosx_mesh_doctor/tests/test_non_conformal.py | 20 +- geosx_mesh_doctor/tests/test_reorient_mesh.py | 25 +- .../tests/test_self_intersecting_elements.py | 4 +- .../tests/test_supported_elements.py | 151 ++++++---- .../tests/test_triangle_distance.py | 108 ++++--- .../geosx_mesh_tools/abaqus_converter.py | 35 ++- .../geosx_mesh_tools/main.py | 18 
+- .../geosx_xml_tools/attribute_coverage.py | 47 +-- .../geosx_xml_tools/command_line_parsers.py | 95 ++++-- .../geosx_xml_tools/main.py | 36 ++- .../geosx_xml_tools/regex_tools.py | 10 +- .../geosx_xml_tools/table_generator.py | 11 +- .../tests/generate_test_xml.py | 24 +- .../geosx_xml_tools/tests/test_manager.py | 104 ++++--- .../geosx_xml_tools/unit_manager.py | 27 +- .../geosx_xml_tools/xml_formatter.py | 56 ++-- .../geosx_xml_tools/xml_processor.py | 65 +++-- .../geosx_xml_tools/xml_redundancy_check.py | 5 +- .../hdf5_wrapper/use_example.py | 3 +- hdf5_wrapper_package/hdf5_wrapper/wrapper.py | 7 +- .../hdf5_wrapper/wrapper_tests.py | 41 ++- .../pygeosx_tools/file_io.py | 30 +- .../pygeosx_tools/mesh_interpolation.py | 9 +- .../pygeosx_tools/well_log.py | 7 +- .../pygeosx_tools/wrapper.py | 69 +++-- .../timehistory/plot_time_history.py | 104 ++++--- yapf.cfg | 5 + 88 files changed, 2624 insertions(+), 1426 deletions(-) create mode 100644 yapf.cfg diff --git a/docs/conf.py b/docs/conf.py index 06f479f..deadd37 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -18,17 +18,13 @@ # Add python modules to be documented python_root = '..' -python_modules = ('geosx_mesh_tools_package', - 'geosx_xml_tools_package', - 'geosx_mesh_doctor', - 'geos_ats_package', - 'hdf5_wrapper_package', - 'pygeosx_tools_package', +python_modules = ('geosx_mesh_tools_package', 'geosx_xml_tools_package', + 'geosx_mesh_doctor', 'geos_ats_package', + 'hdf5_wrapper_package', 'pygeosx_tools_package', 'timehistory_package') for m in python_modules: sys.path.insert(0, os.path.abspath(os.path.join(python_root, m))) - # -- Project information ----------------------------------------------------- project = u'GEOS Python Packages' @@ -40,7 +36,6 @@ # The full version, including alpha/beta/rc tags release = u'' - # -- General configuration --------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. @@ -51,18 +46,14 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx_design', - 'sphinx.ext.todo', - 'sphinx.ext.autodoc', - 'sphinx.ext.doctest', - 'sphinx.ext.imgmath', - 'sphinxarg.ext', - 'sphinx.ext.napoleon', - 'sphinxcontrib.programoutput' + 'sphinx_design', 'sphinx.ext.todo', 'sphinx.ext.autodoc', + 'sphinx.ext.doctest', 'sphinx.ext.imgmath', 'sphinxarg.ext', + 'sphinx.ext.napoleon', 'sphinxcontrib.programoutput' ] - -autodoc_mock_imports = ["pygeosx", "pylvarray", "meshio", "lxml", "mpi4py", "h5py", "ats", "scipy"] +autodoc_mock_imports = [ + "pygeosx", "pylvarray", "meshio", "lxml", "mpi4py", "h5py", "ats", "scipy" +] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: @@ -90,7 +81,6 @@ # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' - # -- Theme options ---------------------------------------------- extensions += [ 'sphinx_rtd_theme', @@ -98,10 +88,7 @@ html_theme = "sphinx_rtd_theme" -html_theme_options = { - 'navigation_depth': -1, - 'collapse_navigation': False -} +html_theme_options = {'navigation_depth': -1, 'collapse_navigation': False} html_static_path = ['./_static'] @@ -109,8 +96,6 @@ 'theme_overrides.css', ] - # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. 
htmlhelp_basename = 'geosPythonPackagesDoc' - diff --git a/geos_ats_package/geos_ats/command_line_parsers.py b/geos_ats_package/geos_ats/command_line_parsers.py index 95332d4..8242aee 100644 --- a/geos_ats_package/geos_ats/command_line_parsers.py +++ b/geos_ats_package/geos_ats/command_line_parsers.py @@ -5,18 +5,30 @@ from pydoc import locate action_options = { - "run": "execute the test cases that previously did not pass.", - "rerun": "ignore the status from previous runs, and rerun the tests.", - "continue": "continue running, ignoring tests that have either passed or failed", - "list": "list the test cases.", - "commands": "display the command line of each test step.", - "reset": "Removes Failed status from any test case.", - "clean": "remove files generated by the test cases.", - "veryclean": "does a clean plus removes non testcase created files (TestLog, results, ...)", - "check": "skip the action steps and just run the check steps.", - "rebaseline": "rebaseline the testcases from a previous run.", - "rebaselinefailed": "rebaseline only failed testcases from a previous run.", - "report": "generate a text or html report, see config for the reporting options.", + "run": + "execute the test cases that previously did not pass.", + "rerun": + "ignore the status from previous runs, and rerun the tests.", + "continue": + "continue running, ignoring tests that have either passed or failed", + "list": + "list the test cases.", + "commands": + "display the command line of each test step.", + "reset": + "Removes Failed status from any test case.", + "clean": + "remove files generated by the test cases.", + "veryclean": + "does a clean plus removes non testcase created files (TestLog, results, ...)", + "check": + "skip the action steps and just run the check steps.", + "rebaseline": + "rebaseline the testcases from a previous run.", + "rebaselinefailed": + "rebaseline only failed testcases from a previous run.", + "report": + "generate a text or html report, see config for the reporting options.", } check_options = { @@ -38,22 +50,47 @@ def build_command_line_parser(): parser = argparse.ArgumentParser(description="Runs GEOS integrated tests") - parser.add_argument("geos_bin_dir", type=str, help="GEOS binary directory.") + parser.add_argument("geos_bin_dir", + type=str, + help="GEOS binary directory.") - parser.add_argument("-w", "--workingDir", type=str, help="Initial working directory") + parser.add_argument("-w", + "--workingDir", + type=str, + help="Initial working directory") action_names = ','.join(action_options.keys()) - parser.add_argument("-a", "--action", type=str, default="run", help=f"Test actions options ({action_names})") + parser.add_argument("-a", + "--action", + type=str, + default="run", + help=f"Test actions options ({action_names})") check_names = ','.join(check_options.keys()) - parser.add_argument("-c", "--check", type=str, default="all", help=f"Test check options ({check_names})") + parser.add_argument("-c", + "--check", + type=str, + default="all", + help=f"Test check options ({check_names})") verbosity_names = ','.join(verbose_options.keys()) - parser.add_argument("-v", "--verbose", type=str, default="info", help=f"Log verbosity options ({verbosity_names})") - - parser.add_argument("-d", "--detail", action="store_true", default=False, help="Show detailed action/check options") - - parser.add_argument("-i", "--info", action="store_true", default=False, help="Info on various topics") + parser.add_argument("-v", + "--verbose", + type=str, + default="info", + help=f"Log 
verbosity options ({verbosity_names})") + + parser.add_argument("-d", + "--detail", + action="store_true", + default=False, + help="Show detailed action/check options") + + parser.add_argument("-i", + "--info", + action="store_true", + default=False, + help="Info on various topics") parser.add_argument("-r", "--restartCheckOverrides", @@ -62,21 +99,32 @@ def build_command_line_parser(): help='Restart check parameter override (name value)', default=[]) - parser.add_argument("--salloc", - default=True, - help="Used by the chaosM machine to first allocate nodes with salloc, before running the tests") + parser.add_argument( + "--salloc", + default=True, + help= + "Used by the chaosM machine to first allocate nodes with salloc, before running the tests" + ) parser.add_argument( "--sallocoptions", type=str, default="", - help="Used to override all command-line options for salloc. No other options with be used or added.") + help= + "Used to override all command-line options for salloc. No other options with be used or added." + ) - parser.add_argument("--ats", nargs='+', default=[], action="append", help="pass arguments to ats") + parser.add_argument("--ats", + nargs='+', + default=[], + action="append", + help="pass arguments to ats") parser.add_argument("--machine", default=None, help="name of the machine") - parser.add_argument("--machine-dir", default=None, help="Search path for machine definitions") + parser.add_argument("--machine-dir", + default=None, + help="Search path for machine definitions") parser.add_argument("-l", "--logs", type=str, default=None) @@ -84,12 +132,16 @@ def build_command_line_parser(): "--failIfTestsFail", action="store_true", default=False, - help="geos_ats normally exits with 0. This will cause it to exit with an error code if there was a failed test." + help= + "geos_ats normally exits with 0. This will cause it to exit with an error code if there was a failed test." ) parser.add_argument("-n", "-N", "--numNodes", type=int, default="2") - parser.add_argument("ats_targets", type=str, nargs='*', help="ats files or directories.") + parser.add_argument("ats_targets", + type=str, + nargs='*', + help="ats files or directories.") return parser @@ -103,7 +155,8 @@ def parse_command_line_arguments(args): check = options.check if check not in check_options: print( - f"Selected check option ({check}) not recognized. Try running with --help/--details for more information") + f"Selected check option ({check}) not recognized. 
Try running with --help/--details for more information" + ) exit_flag = True action = options.action @@ -120,7 +173,8 @@ def parse_command_line_arguments(args): # Print detailed information if options.detail: - for option_type, details in zip(['action', 'check'], [action_options, check_options]): + for option_type, details in zip(['action', 'check'], + [action_options, check_options]): print(f'\nAvailable {option_type} options:') for k, v in details.items(): print(f' {k}: {v}') diff --git a/geos_ats_package/geos_ats/common_utilities.py b/geos_ats_package/geos_ats/common_utilities.py index 0ad4ad8..bab9da7 100644 --- a/geos_ats_package/geos_ats/common_utilities.py +++ b/geos_ats_package/geos_ats/common_utilities.py @@ -18,7 +18,7 @@ def Error(msg): def Log(msg): - import ats # type: ignore[import] + import ats # type: ignore[import] testmode = False try: testmode = ats.tests.AtsTest.getOptions().get("testmode") @@ -48,7 +48,8 @@ def _getwidth(self): maxwidth = 100 if os.name == "posix": try: - sttyout = subprocess.Popen(["stty", "size"], stdout=subprocess.PIPE).communicate()[0] + sttyout = subprocess.Popen( + ["stty", "size"], stdout=subprocess.PIPE).communicate()[0] maxwidth = int(sttyout.split()[1]) except: # If the stty size approach does not work, the use a default maxwidth @@ -79,7 +80,10 @@ def printTable(self, outfile=sys.stdout): # find the max column sizes colWidth = [] for i in range(self.columns): - colWidth.append(max([len(str(row[i])) for row in self.table if row is not None])) + colWidth.append( + max([ + len(str(row[i])) for row in self.table if row is not None + ])) # adjust the colWidths down if colmax is step for i in range(self.columns): @@ -89,7 +93,8 @@ def printTable(self, outfile=sys.stdout): # last column is floating - total = sum(colWidth) + self.columns * (1 + len(self.sep)) + len(self.indent) + total = sum(colWidth) + self.columns * (1 + len(self.sep)) + len( + self.indent) if total > self.maxwidth: colWidth[-1] = max(10, self.maxwidth - (total - colWidth[-1])) @@ -99,7 +104,9 @@ def printTable(self, outfile=sys.stdout): # row break controls. 
# if row is None then this is a break - addBreak = (row is None) or (self.rowbreak and rowbreakindex > 0 and rowbreakindex % self.rowbreak == 0) + addBreak = (row + is None) or (self.rowbreak and rowbreakindex > 0 + and rowbreakindex % self.rowbreak == 0) if addBreak: table_str += self.indent for i in range(self.columns): @@ -125,9 +132,14 @@ def printTable(self, outfile=sys.stdout): drow = str(row[i]) if i == self.columns - 1: - lines.append(textwrap.wrap(drow, colWidth[i], break_long_words=False)) + lines.append( + textwrap.wrap(drow, + colWidth[i], + break_long_words=False)) else: - lines.append(textwrap.wrap(drow, colWidth[i], break_long_words=True)) + lines.append( + textwrap.wrap(drow, colWidth[i], + break_long_words=True)) maxlines = max([len(x) for x in lines]) @@ -236,8 +248,10 @@ def removeLogDirectories(dir): ff = os.path.join(dir, f) if os.path.isdir(ff) and not os.path.islink(ff): tests = [ - all([os.path.exists(os.path.join(ff, "ats.log")), - os.path.exists(os.path.join(ff, "geos_ats.config"))]), + all([ + os.path.exists(os.path.join(ff, "ats.log")), + os.path.exists(os.path.join(ff, "geos_ats.config")) + ]), f.find("TestLogs.") == 0 ] if any(tests): diff --git a/geos_ats_package/geos_ats/configuration_record.py b/geos_ats_package/geos_ats/configuration_record.py index bfb721e..2388e7c 100644 --- a/geos_ats_package/geos_ats/configuration_record.py +++ b/geos_ats_package/geos_ats/configuration_record.py @@ -42,7 +42,8 @@ def set(self, name, value): value = item.type(value) except ValueError: - Error("Attempted to set config.%s (which is %s) with %s" % (name, str(item.type), str(value))) + Error("Attempted to set config.%s (which is %s) with %s" % + (name, str(item.type), str(value))) item.value = item.type(value) @@ -67,12 +68,14 @@ def checkname(self, name): matches = difflib.get_close_matches(name, self._items.keys()) if len(matches) == 0: Error("Unknown config name: %s. " - "See 'geos_ats -i config' for the complete list." % (name)) + "See 'geos_ats -i config' for the complete list." % + (name)) else: Error("Unknown config name: %s. " "Perhaps you meant '%s'. " - "See 'geos_ats -i config' for the complete list." % (name, matches[0])) + "See 'geos_ats -i config' for the complete list." 
% + (name, matches[0])) def __setattr__(self, name, value): if name in self._items: @@ -90,7 +93,7 @@ def __getattr__(self, name): # The global config object config = Config() # Global testTimings object -globalTestTimings = {} # type: ignore[var-annotated] +globalTestTimings = {} # type: ignore[var-annotated] # Depth of testconfig recursion configDepth = 0 @@ -98,7 +101,7 @@ def __getattr__(self, name): def infoConfigShow(public, outfile=sys.stdout): topic = InfoTopic("config show", outfile) topic.startBanner() - import ats # type: ignore[import] + import ats # type: ignore[import] keys = sorted(config._items.keys()) table = TextTable(3) @@ -151,12 +154,18 @@ def infoConfigDocumentation(public): def infoConfig(*args): menu = InfoTopic("config") - menu.addTopic("show", "Show all the config options", lambda *x: infoConfigShow(True)) - menu.addTopic("doc", "Documentation for the config options", lambda *x: infoConfigDocumentation(True)) - menu.addTopic("showall", "Show all the config options (including the internal options)", - lambda: infoConfigShow(False)) - menu.addTopic("docall", "Documentation for the config options (including the internal options)", - lambda: infoConfigDocumentation(False)) + menu.addTopic("show", "Show all the config options", + lambda *x: infoConfigShow(True)) + menu.addTopic("doc", "Documentation for the config options", + lambda *x: infoConfigDocumentation(True)) + menu.addTopic( + "showall", + "Show all the config options (including the internal options)", + lambda: infoConfigShow(False)) + menu.addTopic( + "docall", + "Documentation for the config options (including the internal options)", + lambda: infoConfigDocumentation(False)) menu.process(args) @@ -167,95 +176,151 @@ def initializeConfig(configFile, configOverride, options): geos_atsdir = os.path.realpath(os.path.dirname(__file__)) # configfile - config.add("testbaseline_dir", str, "", "Base directory that contains all the baselines") + config.add("testbaseline_dir", str, "", + "Base directory that contains all the baselines") - config.add("geos_bin_dir", str, "", "Directory that contains 'geos' and related executables.") + config.add("geos_bin_dir", str, "", + "Directory that contains 'geos' and related executables.") - config.add("userscript_path", str, "", - "Directory that contains scripts for testing, searched after test directory and executable_path.") + config.add( + "userscript_path", str, "", + "Directory that contains scripts for testing, searched after test directory and executable_path." 
+ ) - config.add("clean_on_pass", bool, False, "If True, then after a TestCase passes, " - "all temporary files are removed.") + config.add( + "clean_on_pass", bool, False, "If True, then after a TestCase passes, " + "all temporary files are removed.") # geos options - config.add("geos_default_args", str, "-i", - "A string containing arguments that will always appear on the geos commandline") + config.add( + "geos_default_args", str, "-i", + "A string containing arguments that will always appear on the geos commandline" + ) # reporting - config.add("report_html", bool, True, "True if HTML formatted results will be generated with the report action") - config.add("report_html_file", str, "test_results.html", "Location to write the html report") - config.add("report_html_periodic", bool, True, "True to update the html file during the periodic reports") - config.add("browser_command", str, "firefox -no-remote", "Command to use to launch a browser to view html results") - config.add("browser", bool, False, "If True, then launch the browser_command to view the report_html_file") - config.add("report_doc_dir", str, os.path.normpath(os.path.join(geos_atsdir, "..", "doc")), + config.add( + "report_html", bool, True, + "True if HTML formatted results will be generated with the report action" + ) + config.add("report_html_file", str, "test_results.html", + "Location to write the html report") + config.add("report_html_periodic", bool, True, + "True to update the html file during the periodic reports") + config.add("browser_command", str, "firefox -no-remote", + "Command to use to launch a browser to view html results") + config.add( + "browser", bool, False, + "If True, then launch the browser_command to view the report_html_file" + ) + config.add("report_doc_dir", str, + os.path.normpath(os.path.join(geos_atsdir, "..", "doc")), "Location to the test doc directory (used with html reports)") - config.add("report_doc_link", bool, True, "Link against docgen (used with html reports)") - config.add("report_doc_remake", bool, False, - "Remake test documentation, even if it already exists (used with html reports)") + config.add("report_doc_link", bool, True, + "Link against docgen (used with html reports)") + config.add( + "report_doc_remake", bool, False, + "Remake test documentation, even if it already exists (used with html reports)" + ) - config.add("report_text", bool, True, "True if you want text results to be generated with the report action") - config.add("report_text_file", str, "test_results.txt", "Location to write the text report") - config.add("report_text_echo", bool, True, "If True, echo the report to stdout") - config.add("report_wait", bool, False, "Wait until all tests are complete before reporting") + config.add( + "report_text", bool, True, + "True if you want text results to be generated with the report action") + config.add("report_text_file", str, "test_results.txt", + "Location to write the text report") + config.add("report_text_echo", bool, True, + "If True, echo the report to stdout") + config.add("report_wait", bool, False, + "Wait until all tests are complete before reporting") - config.add("report_ini", bool, True, "True if you want ini results to be generated with the report action") - config.add("report_ini_file", str, "test_results.ini", "Location to write the ini report") + config.add( + "report_ini", bool, True, + "True if you want ini results to be generated with the report action") + config.add("report_ini_file", str, "test_results.ini", + "Location to write the ini 
report") - config.add("report_notations", type([]), [], "Lines of text that are inserted into the reports.") + config.add("report_notations", type([]), [], + "Lines of text that are inserted into the reports.") - config.add("report_notbuilt_regexp", str, "(not built into this version)", - "Regular expression that must appear in output to indicate that feature is not built.") + config.add( + "report_notbuilt_regexp", str, "(not built into this version)", + "Regular expression that must appear in output to indicate that feature is not built." + ) - config.add("checkmessages_always_ignore_regexp", type([]), ["not available in this version"], + config.add("checkmessages_always_ignore_regexp", type([]), + ["not available in this version"], "Regular expression to ignore in all checkmessages steps.") - config.add("checkmessages_never_ignore_regexp", type([]), ["not yet implemented"], + config.add("checkmessages_never_ignore_regexp", type([]), + ["not yet implemented"], "Regular expression to not ignore in all checkmessages steps.") - config.add("report_timing", bool, False, "True if you want timing file to be generated with the report action") - config.add("report_timing_overwrite", bool, False, - "True if you want timing file to overwrite existing timing file rather than augment it") + config.add( + "report_timing", bool, False, + "True if you want timing file to be generated with the report action") + config.add( + "report_timing_overwrite", bool, False, + "True if you want timing file to overwrite existing timing file rather than augment it" + ) # timing and priority - config.add("priority", str, "equal", "Method of prioritization of tests: [\"equal\", \"processors\",\"timing\"]") + config.add( + "priority", str, "equal", + "Method of prioritization of tests: [\"equal\", \"processors\",\"timing\"]" + ) config.add("timing_file", str, "timing.txt", "Location of timing file") # batch - config.add("batch_dryrun", bool, False, - "If true, the batch jobs will not be submitted, but the batch scripts will be created") - config.add("batch_interactive", bool, False, "If true, the batch jobs will be treated as interactive jobs") + config.add( + "batch_dryrun", bool, False, + "If true, the batch jobs will not be submitted, but the batch scripts will be created" + ) + config.add("batch_interactive", bool, False, + "If true, the batch jobs will be treated as interactive jobs") config.add("batch_bank", str, "", "The name of the bank to use") config.add("batch_ppn", int, 0, "Number of processors per node") - config.add("batch_partition", str, "", "the batch partition, if not specified the default will be used.") + config.add( + "batch_partition", str, "", + "the batch partition, if not specified the default will be used.") config.add("batch_queue", str, "pbatch", "the batch queue.") - config.add("batch_header", type([]), [], "Additional lines to add to the batch header") + config.add("batch_header", type([]), [], + "Additional lines to add to the batch header") # retry - config.add("max_retry", int, 2, "Maximum number of times to retry failed runs.") - config.add("retry_err_regexp", str, - "(launch failed|Failure in initializing endpoint|channel initialization failed)", - "Regular expression that must appear in error log in order to retry.") + config.add("max_retry", int, 2, + "Maximum number of times to retry failed runs.") + config.add( + "retry_err_regexp", str, + "(launch failed|Failure in initializing endpoint|channel initialization failed)", + "Regular expression that must appear in error log in order 
to retry.") # timeout - config.add("default_timelimit", str, "30m", - "This sets a default timelimit for all test steps which do not explicitly set a timelimit.") - config.add("override_timelimit", bool, False, - "If true, the value used for the default time limit will override the time limit for each test step.") + config.add( + "default_timelimit", str, "30m", + "This sets a default timelimit for all test steps which do not explicitly set a timelimit." + ) + config.add( + "override_timelimit", bool, False, + "If true, the value used for the default time limit will override the time limit for each test step." + ) # Decomposition Multiplication config.add( "decomp_factor", int, 1, "This sets the multiplication factor to be applied to the decomposition and number of procs of all eligible tests." ) - config.add("override_np", int, 0, "If non-zero, maximum number of processors to use for each test step.") + config.add( + "override_np", int, 0, + "If non-zero, maximum number of processors to use for each test step.") # global environment variables - config.add("environment", dict, {}, "Additional environment variables to use during testing") + config.add("environment", dict, {}, + "Additional environment variables to use during testing") # General check config for check in ("restartcheck", ): config.add( - "%s_enabled" % check, bool, True, "If True, this check has the possibility of running, " + "%s_enabled" % check, bool, True, + "If True, this check has the possibility of running, " "but might not run depending on the '--check' command line option. " "If False, this check will never be run.") @@ -267,13 +332,18 @@ def initializeConfig(configFile, configOverride, options): public=False) # Checks: Restartcheck - config.add("restart_skip_missing", bool, False, "Determines whether new/missing fields are ignored") - config.add("restart_exclude_pattern", list, [], "A list of field names to ignore in restart files") + config.add("restart_skip_missing", bool, False, + "Determines whether new/missing fields are ignored") + config.add("restart_exclude_pattern", list, [], + "A list of field names to ignore in restart files") # Checks: Curvecheck - config.add("curvecheck_enabled", bool, True, "Determines whether curvecheck steps are run.") - config.add("curvecheck_tapestry_mode", bool, False, - "Provide temporary backwards compatibility for nighty and weekly suites until they are using geos_ats") + config.add("curvecheck_enabled", bool, True, + "Determines whether curvecheck steps are run.") + config.add( + "curvecheck_tapestry_mode", bool, False, + "Provide temporary backwards compatibility for nighty and weekly suites until they are using geos_ats" + ) config.add("curvecheck_absolute", float, 1e-5, "absolute tolerance") config.add("curvecheck_relative", float, 1e-5, "relative tolerance") config.add( @@ -288,29 +358,41 @@ def initializeConfig(configFile, configOverride, options): "curvecheck_delete_temps", bool, True, "Curvecheck generates a number of temporary data files that are used to create the images for the html file. If this parameter is true, curvecheck will delete these temporary files. By default, the parameter is true." 
) - config.add("gnuplot_executable", str, os.path.join("/usr", "bin", "gnuplot"), "Location to gnuplot") + config.add("gnuplot_executable", str, + os.path.join("/usr", "bin", "gnuplot"), "Location to gnuplot") # Rebaseline: config.add( - "rebaseline_undo", bool, False, "If True, and the action is set to 'rebaseline'," + "rebaseline_undo", bool, False, + "If True, and the action is set to 'rebaseline'," " this option will undo (revert) a previous rebaseline.") - config.add("rebaseline_ask", bool, True, "If True, the rebaseline will not occur until the user has anwered an" - " 'are you sure?' question") + config.add( + "rebaseline_ask", bool, True, + "If True, the rebaseline will not occur until the user has anwered an" + " 'are you sure?' question") # test modifier config.add("testmodifier", str, "", "Name of a test modifier to apply") # filters - config.add("filter_maxprocessors", int, -1, "If not -1, Run only those tests where the number of" - " processors is less than or equal to this value") + config.add( + "filter_maxprocessors", int, -1, + "If not -1, Run only those tests where the number of" + " processors is less than or equal to this value") # machines - config.add("machine_options", list, [], "Arguments to pass to the machine module") + config.add("machine_options", list, [], + "Arguments to pass to the machine module") - config.add("script_launch", int, 0, "Whether to launch scripts (and other serial steps) on compute nodes") - config.add("openmpi_install", str, "", "Location to the openmpi installation") - config.add("openmpi_maxprocs", int, 0, "Number of maximum processors openmpi") - config.add("openmpi_procspernode", int, 1, "Number of processors per node for openmpi") + config.add( + "script_launch", int, 0, + "Whether to launch scripts (and other serial steps) on compute nodes") + config.add("openmpi_install", str, "", + "Location to the openmpi installation") + config.add("openmpi_maxprocs", int, 0, + "Number of maximum processors openmpi") + config.add("openmpi_procspernode", int, 1, + "Number of processors per node for openmpi") config.add( "openmpi_precommand", str, "", "A string that will be" @@ -325,9 +407,12 @@ def initializeConfig(configFile, configOverride, options): " it will be replaced by the unique name of the test.") config.add("windows_mpiexe", str, "", "Location to mpiexe") - config.add("windows_nompi", bool, False, "Run executables on nompi processor") - config.add("windows_oversubscribe", int, 1, - "Multiplier to number of processors to allow oversubscription of processors") + config.add("windows_nompi", bool, False, + "Run executables on nompi processor") + config.add( + "windows_oversubscribe", int, 1, + "Multiplier to number of processors to allow oversubscription of processors" + ) # populate the config with overrides from the command line for key, value in configOverride.items(): diff --git a/geos_ats_package/geos_ats/environment_setup.py b/geos_ats_package/geos_ats/environment_setup.py index 145e548..f68b660 100644 --- a/geos_ats_package/geos_ats/environment_setup.py +++ b/geos_ats_package/geos_ats/environment_setup.py @@ -7,7 +7,8 @@ def setup_ats(src_path, build_path, ats_xargs, ats_machine, ats_machine_dir): bin_dir = os.path.join(build_path, "bin") geos_ats_fname = os.path.join(bin_dir, "run_geos_ats") - ats_dir = os.path.abspath(os.path.join(src_path, "integratedTests", "tests", "allTests")) + ats_dir = os.path.abspath( + os.path.join(src_path, "integratedTests", "tests", "allTests")) test_path = os.path.join(build_path, "integratedTests") 
link_path = os.path.join(test_path, "integratedTests") run_script_fname = os.path.join(test_path, "geos_ats.sh") @@ -30,11 +31,14 @@ def setup_ats(src_path, build_path, ats_xargs, ats_machine, ats_machine_dir): # Write the bash script to run ats. with open(run_script_fname, "w") as g: g.write("#!/bin/bash\n") - g.write(f"{geos_ats_fname} {bin_dir} --workingDir {ats_dir} --logs {log_dir} {ats_args} \"$@\"\n") + g.write( + f"{geos_ats_fname} {bin_dir} --workingDir {ats_dir} --logs {log_dir} {ats_args} \"$@\"\n" + ) # Make the script executable st = os.stat(run_script_fname) - os.chmod(run_script_fname, st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH) + os.chmod(run_script_fname, + st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH) def main(): @@ -46,11 +50,22 @@ def main(): parser = argparse.ArgumentParser(description="Setup ATS script") parser.add_argument("src_path", type=str, help="GEOS src path") parser.add_argument("build_path", type=str, help="GEOS build path") - parser.add_argument("--ats", nargs='+', default=[], action="append", help="Arguments that should be passed to ats") - parser.add_argument("--machine", type=str, default='', help="ATS machine name") - parser.add_argument("--machine-dir", type=str, default='', help="ATS machine directory") + parser.add_argument("--ats", + nargs='+', + default=[], + action="append", + help="Arguments that should be passed to ats") + parser.add_argument("--machine", + type=str, + default='', + help="ATS machine name") + parser.add_argument("--machine-dir", + type=str, + default='', + help="ATS machine directory") options, unkown_args = parser.parse_known_args() - setup_ats(options.src_path, options.build_path, options.ats, options.machine, options.machine_dir) + setup_ats(options.src_path, options.build_path, options.ats, + options.machine, options.machine_dir) if __name__ == '__main__': diff --git a/geos_ats_package/geos_ats/geos_ats_debug.py b/geos_ats_package/geos_ats/geos_ats_debug.py index 1535a18..2043ab0 100644 --- a/geos_ats_package/geos_ats/geos_ats_debug.py +++ b/geos_ats_package/geos_ats/geos_ats_debug.py @@ -8,13 +8,15 @@ from geos_ats import main -def debug_geosats(build_root='~/GEOS/build-quartz-gcc@12-release', extra_args=[]): +def debug_geosats(build_root='~/GEOS/build-quartz-gcc@12-release', + extra_args=[]): # Search for and parse the ats script build_root = os.path.expanduser(build_root) ats_script = os.path.join(build_root, 'integratedTests', 'geos_ats.sh') if not os.path.isfile(ats_script): raise InputError( - 'Could not find geos_ats.sh at the expected location... Make sure to run \"make ats_environment\"') + 'Could not find geos_ats.sh at the expected location... 
Make sure to run \"make ats_environment\"' + ) with open(ats_script, 'r') as f: header = f.readline() diff --git a/geos_ats_package/geos_ats/helpers/curve_check.py b/geos_ats_package/geos_ats/helpers/curve_check.py index e2c8bc2..c955334 100644 --- a/geos_ats_package/geos_ats/helpers/curve_check.py +++ b/geos_ats_package/geos_ats/helpers/curve_check.py @@ -71,12 +71,20 @@ def evaluate_external_script(script, fn, data): if hasattr(module, fn): return getattr(module, fn)(**data) else: - raise Exception(f'External script does not contain the expected function ({fn})') + raise Exception( + f'External script does not contain the expected function ({fn})' + ) else: raise FileNotFoundError(f'Could not find script: {script}') -def check_diff(parameter_name, set_name, target, baseline, tolerance, errors, modifier='baseline'): +def check_diff(parameter_name, + set_name, + target, + baseline, + tolerance, + errors, + modifier='baseline'): """ Compute the L2-norm of the diff and compare to the set tolerance @@ -99,7 +107,8 @@ def check_diff(parameter_name, set_name, target, baseline, tolerance, errors, mo ) -def curve_check_figure(parameter_name, location_str, set_name, data, data_sizes, output_root, ncol, units_time): +def curve_check_figure(parameter_name, location_str, set_name, data, + data_sizes, output_root, ncol, units_time): """ Generate figures associated with the curve check @@ -165,7 +174,8 @@ def curve_check_figure(parameter_name, location_str, set_name, data, data_sizes, kwargs['label'] = k ax.plot(t, x[:, jj], color=c, **style[k], **kwargs) except Exception as e: - print(f'Error rendering curve {value_key}: {str(e)}') + print( + f'Error rendering curve {value_key}: {str(e)}') else: # Spatial axis horizontal_label = 'X (m)' @@ -175,9 +185,14 @@ def curve_check_figure(parameter_name, location_str, set_name, data, data_sizes, kwargs = {} if (jj == 0): kwargs['label'] = k - ax.plot(position, x[jj, :], color=c, **style[k], **kwargs) + ax.plot(position, + x[jj, :], + color=c, + **style[k], + **kwargs) except Exception as e: - print(f'Error rendering curve {value_key}: {str(e)}') + print( + f'Error rendering curve {value_key}: {str(e)}') # Set labels ax.set_xlabel(horizontal_label) @@ -185,10 +200,12 @@ def curve_check_figure(parameter_name, location_str, set_name, data, data_sizes, # ax.set_xlim(t[[0, -1]]) ax.legend(loc=2) plt.tight_layout() - fig.savefig(os.path.join(output_root, f'{parameter_name}_{set_name}'), dpi=200) + fig.savefig(os.path.join(output_root, f'{parameter_name}_{set_name}'), + dpi=200) -def compare_time_history_curves(fname, baseline, curve, tolerance, output, output_n_column, units_time, +def compare_time_history_curves(fname, baseline, curve, tolerance, output, + output_n_column, units_time, script_instructions): """ Compute time history curves @@ -216,7 +233,8 @@ def compare_time_history_curves(fname, baseline, curve, tolerance, output, outpu if len(curve) != len(tolerance): raise Exception( - f'Curvecheck inputs must be of the same length: curves ({len(curve)}) and tolerance ({len(tolerance)})') + f'Curvecheck inputs must be of the same length: curves ({len(curve)}) and tolerance ({len(tolerance)})' + ) # Load data and check sizes data = {} @@ -285,8 +303,10 @@ def compare_time_history_curves(fname, baseline, curve, tolerance, output, outpu key += f' {s}' key2 += f' {s}' data['script'][key] = data['target'][key] - data['script'][key2] = evaluate_external_script(script, fn, data['target']) - data_sizes[p][s]['script'] = list(np.shape(data['script'][key2])) + 
data['script'][key2] = evaluate_external_script( + script, fn, data['target']) + data_sizes[p][s]['script'] = list( + np.shape(data['script'][key2])) except Exception as e: errors.append(str(e)) @@ -317,16 +337,22 @@ def compare_time_history_curves(fname, baseline, curve, tolerance, output, outpu if set_sizes['target'] == set_sizes['baseline']: check_diff(p, s, xa, xb, tol[p][s], errors) else: - warnings.append(size_err.format(p, s, *set_sizes['target'], *set_sizes['baseline'])) + warnings.append( + size_err.format(p, s, *set_sizes['target'], + *set_sizes['baseline'])) # Check whether the data can be interpolated - if (len(set_sizes['baseline']) == 1) or (set_sizes['target'][1:] == set_sizes['baseline'][1:]): - warnings.append(f'Interpolating target curve in time: {p}_{s}') + if (len(set_sizes['baseline']) + == 1) or (set_sizes['target'][1:] + == set_sizes['baseline'][1:]): + warnings.append( + f'Interpolating target curve in time: {p}_{s}') ta = data['target'][f'{p} Time'] tb = data['baseline'][f'{p} Time'] xc = interpolate_values_time(ta, xa, tb) check_diff(p, s, xc, xb, tol[p][s], errors) else: - errors.append(f'Cannot perform a curve check for {p}_{s}') + errors.append( + f'Cannot perform a curve check for {p}_{s}') if (('script' in set_sizes) and ('target' in set_sizes)): xa = data['target'][key] xb = data['script'][key] @@ -337,7 +363,8 @@ def compare_time_history_curves(fname, baseline, curve, tolerance, output, outpu os.makedirs(output, exist_ok=True) for p, set_data in data_sizes.items(): for s, set_sizes in set_data.items(): - curve_check_figure(p, location_strings[p], s, data, data_sizes, output, output_n_column, units_time) + curve_check_figure(p, location_strings[p], s, data, data_sizes, + output, output_n_column, units_time) return warnings, errors @@ -360,7 +387,8 @@ def __call__(self, parser, namespace, values, option_string=None): elif len(values) == 2: pairs.append((values[0], values[1])) else: - raise Exception('Only a single value or a pair of values are expected') + raise Exception( + 'Only a single value or a pair of values are expected') setattr(namespace, self.dest, pairs) @@ -388,31 +416,40 @@ def __call__(self, parser, namespace, values, option_string=None): action=PairAction, help='Curves to check (value) or (value, setname)', default=[]) - parser.add_argument("-t", - "--tolerance", - nargs='+', - action='append', - help=f"The tolerance for each curve check diffs (||x-y||/N)", - default=[]) - parser.add_argument("-w", - "--Werror", - action="store_true", - help="Force all warnings to be errors, default is False.", - default=False) - parser.add_argument("-o", "--output", help="Output figures to this directory", default='./curve_check_figures') + parser.add_argument( + "-t", + "--tolerance", + nargs='+', + action='append', + help=f"The tolerance for each curve check diffs (||x-y||/N)", + default=[]) + parser.add_argument( + "-w", + "--Werror", + action="store_true", + help="Force all warnings to be errors, default is False.", + default=False) + parser.add_argument("-o", + "--output", + help="Output figures to this directory", + default='./curve_check_figures') unit_choices = list(unit_map.keys()) - parser.add_argument("-n", "--n-column", help="Number of columns for the output figure", default=1) + parser.add_argument("-n", + "--n-column", + help="Number of columns for the output figure", + default=1) parser.add_argument("-u", "--units-time", help=f"Time units for plots (default=seconds)", choices=unit_choices, default='seconds') - parser.add_argument("-s", - 
"--script", - nargs='+', - action=ScriptAction, - help='Python script instructions (path, function, value, setname)', - default=[]) + parser.add_argument( + "-s", + "--script", + nargs='+', + action=ScriptAction, + help='Python script instructions (path, function, value, setname)', + default=[]) return parser @@ -423,8 +460,9 @@ def main(): """ parser = curve_check_parser() args = parser.parse_args() - warnings, errors = compare_time_history_curves(args.filename, args.baseline, args.curve, args.tolerance, - args.output, args.n_column, args.units_time, args.script) + warnings, errors = compare_time_history_curves( + args.filename, args.baseline, args.curve, args.tolerance, args.output, + args.n_column, args.units_time, args.script) # Write errors/warnings to the screen if args.Werror: diff --git a/geos_ats_package/geos_ats/helpers/permute_array.py b/geos_ats_package/geos_ats/helpers/permute_array.py index bbc9491..e150de5 100644 --- a/geos_ats_package/geos_ats/helpers/permute_array.py +++ b/geos_ats_package/geos_ats/helpers/permute_array.py @@ -1,4 +1,4 @@ -import numpy as np # type: ignore[import] +import numpy as np # type: ignore[import] import logging logger = logging.getLogger('geos_ats') @@ -10,16 +10,18 @@ def permuteArray(data, shape, permutation): return None, msg if len(permutation.shape) != 1: - msg = "The permutation must be a 1D array, not %s" % len(permutation.shape) + msg = "The permutation must be a 1D array, not %s" % len( + permutation.shape) return None, msg if shape.size != permutation.size: - msg = "The shape and permutation arrays must have the same length. %s != %s" % (shape.size, permutation.size) + msg = "The shape and permutation arrays must have the same length. %s != %s" % ( + shape.size, permutation.size) return None, msg if np.prod(shape) != data.size: - msg = "The shape is %s which yields a total size of %s but the real size is %s." % (shape, np.prod(shape), - data.size) + msg = "The shape is %s which yields a total size of %s but the real size is %s." % ( + shape, np.prod(shape), data.size) return None, msg if np.any(np.sort(permutation) != np.arange(shape.size)): @@ -38,7 +40,8 @@ def permuteArray(data, shape, permutation): data = np.transpose(data, reverse_permutation) if np.any(data.shape != shape): - msg = "Reshaping failed. Shape is %s but should be %s" % (data.shape, shape) + msg = "Reshaping failed. 
Shape is %s but should be %s" % (data.shape, + shape) return None, msg return data, None @@ -50,7 +53,8 @@ def testPermuteArray(shape, permutation): original_data = np.arange(np.prod(shape)).reshape(shape) transposed_data = original_data.transpose(permutation) - reshaped_data, error_msg = permuteArray(transposed_data.flatten(), shape, permutation) + reshaped_data, error_msg = permuteArray(transposed_data.flatten(), + shape, permutation) assert (error_msg is None) assert (np.all(original_data == reshaped_data)) diff --git a/geos_ats_package/geos_ats/helpers/restart_check.py b/geos_ats_package/geos_ats/helpers/restart_check.py index 261c3bf..33a32b6 100644 --- a/geos_ats_package/geos_ats/helpers/restart_check.py +++ b/geos_ats_package/geos_ats/helpers/restart_check.py @@ -1,6 +1,6 @@ -import h5py # type: ignore[import] -from mpi4py import MPI # type: ignore[import] -import numpy as np # type: ignore[import] +import h5py # type: ignore[import] +from mpi4py import MPI # type: ignore[import] +import numpy as np # type: ignore[import] import sys import os import re @@ -8,14 +8,17 @@ import logging from pathlib import Path try: - from geos_ats.helpers.permute_array import permuteArray # type: ignore[import] + from geos_ats.helpers.permute_array import permuteArray # type: ignore[import] except ImportError: # Fallback method to be used if geos_ats isn't found - from permute_array import permuteArray # type: ignore[import] + from permute_array import permuteArray # type: ignore[import] RTOL_DEFAULT = 0.0 ATOL_DEFAULT = 0.0 -EXCLUDE_DEFAULT = [".*/commandLine", ".*/schema$", ".*/globalToLocalMap", ".*/timeHistoryOutput.*/restart"] +EXCLUDE_DEFAULT = [ + ".*/commandLine", ".*/schema$", ".*/globalToLocalMap", + ".*/timeHistoryOutput.*/restart" +] logger = logging.getLogger('geos_ats') @@ -80,11 +83,14 @@ def __init__(self, def filesDiffer(self): try: - with h5py.File(self.file_path, "r") as file, h5py.File(self.baseline_path, "r") as base_file: + with h5py.File(self.file_path, + "r") as file, h5py.File(self.baseline_path, + "r") as base_file: self.file_path = file.filename self.baseline_path = base_file.filename self.output.write("\nRank %s is comparing %s with %s \n" % - (MPI.COMM_WORLD.Get_rank(), self.file_path, self.baseline_path)) + (MPI.COMM_WORLD.Get_rank(), self.file_path, + self.baseline_path)) self.compareGroups(file, base_file) except IOError as e: @@ -163,7 +169,8 @@ def compareFloatScalars(self, path, val, base_val): """ dif = abs(val - base_val) if dif > self.atol and dif > self.rtol * abs(base_val): - msg = "Scalar values of types %s and %s differ: %s, %s.\n" % (val.dtype, base_val.dtype, val, base_val) + msg = "Scalar values of types %s and %s differ: %s, %s.\n" % ( + val.dtype, base_val.dtype, val, base_val) self.errorMsg(path, msg, True) def compareIntScalars(self, path, val, base_val): @@ -175,7 +182,8 @@ def compareIntScalars(self, path, val, base_val): BASE_VAL [in]: The baseline value to compare against. """ if val != base_val: - msg = "Scalar values of types %s and %s differ: %s, %s.\n" % (val.dtype, base_val.dtype, val, base_val) + msg = "Scalar values of types %s and %s differ: %s, %s.\n" % ( + val.dtype, base_val.dtype, val, base_val) self.errorMsg(path, msg, True) def compareStringScalars(self, path, val, base_val): @@ -187,7 +195,8 @@ def compareStringScalars(self, path, val, base_val): BASE_VAL [in]: The baseline value to compare against. 
""" if val != base_val: - msg = "Scalar values of types %s and %s differ: %s, %s.\n" % (val.dtype, base_val.dtype, val, base_val) + msg = "Scalar values of types %s and %s differ: %s, %s.\n" % ( + val.dtype, base_val.dtype, val, base_val) self.errorMsg(path, msg, True) def compareFloatArrays(self, path, arr, base_arr): @@ -225,8 +234,8 @@ def compareFloatArrays(self, path, arr, base_arr): # If the shapes are different they can't be compared. if arr.shape != base_arr.shape: - msg = "Datasets have different shapes and therefore can't be compared: %s, %s.\n" % (arr.shape, - base_arr.shape) + msg = "Datasets have different shapes and therefore can't be compared: %s, %s.\n" % ( + arr.shape, base_arr.shape) self.errorMsg(path, msg, True) return @@ -252,7 +261,8 @@ def compareFloatArrays(self, path, arr, base_arr): absTol = self.atol # Get the indices of the max absolute and relative error - max_absolute_index = np.unravel_index(np.argmax(difference), difference.shape) + max_absolute_index = np.unravel_index(np.argmax(difference), + difference.shape) relative_difference = difference / (abs_base_arr + 1e-20) @@ -260,7 +270,8 @@ def compareFloatArrays(self, path, arr, base_arr): if self.atol != 0: relative_difference = np.nan_to_num(relative_difference, 0) - max_relative_index = np.unravel_index(np.argmax(relative_difference), relative_difference.shape) + max_relative_index = np.unravel_index(np.argmax(relative_difference), + relative_difference.shape) if self.rtol != 0.0: relative_difference /= self.rtol @@ -274,7 +285,9 @@ def compareFloatArrays(self, path, arr, base_arr): absolute_limited = np.zeros(q.shape, dtype=bool) else: # Multiply ABS_BASE_ARR by RTOL and rename it to RTOL_ABS_BASE - rtol_abs_base = np.multiply(self.rtol, abs_base_arr, out=abs_base_arr) + rtol_abs_base = np.multiply(self.rtol, + abs_base_arr, + out=abs_base_arr) # Calculate which entries are limited by the relative tolerance. relative_limited = rtol_abs_base > absTol @@ -284,7 +297,8 @@ def compareFloatArrays(self, path, arr, base_arr): q[relative_limited] = relative_difference[relative_limited] # Compute q for the entries which are limited by the absolute tolerance. - absolute_limited = np.logical_not(relative_limited, out=relative_limited) + absolute_limited = np.logical_not(relative_limited, + out=relative_limited) q[absolute_limited] /= absTol # If the maximum q value is greater than 1.0 than issue an error. 
@@ -292,42 +306,55 @@ def compareFloatArrays(self, path, arr, base_arr): offenders = np.greater(q, 1.0) n_offenders = np.sum(offenders) - absolute_offenders = np.logical_and(offenders, absolute_limited, out=offenders) + absolute_offenders = np.logical_and(offenders, + absolute_limited, + out=offenders) q_num_absolute = np.sum(absolute_offenders) if q_num_absolute > 0: absolute_qs = q * absolute_offenders q_max_absolute = np.max(absolute_qs) - q_max_absolute_index = np.unravel_index(np.argmax(absolute_qs), absolute_qs.shape) + q_max_absolute_index = np.unravel_index( + np.argmax(absolute_qs), absolute_qs.shape) q_mean_absolute = np.mean(absolute_qs) q_std_absolute = np.std(absolute_qs) offenders = np.greater(q, 1.0, out=offenders) - relative_limited = np.logical_not(absolute_limited, out=absolute_limited) - relative_offenders = np.logical_and(offenders, relative_limited, out=offenders) + relative_limited = np.logical_not(absolute_limited, + out=absolute_limited) + relative_offenders = np.logical_and(offenders, + relative_limited, + out=offenders) q_num_relative = np.sum(relative_offenders) if q_num_relative > 0: relative_qs = q * relative_offenders q_max_relative = np.max(relative_qs) - q_max_relative_index = np.unravel_index(np.argmax(relative_qs), q.shape) + q_max_relative_index = np.unravel_index( + np.argmax(relative_qs), q.shape) q_mean_relative = np.mean(relative_qs) q_std_relative = np.std(relative_qs) message = "Arrays of types %s and %s have %d values of which %d fail both the relative and absolute tests.\n" % ( arr.dtype, base_arr.dtype, offenders.size, n_offenders) message += "\tMax absolute difference is at index %s: value = %s, base_value = %s\n" % ( - max_absolute_index, arr[max_absolute_index], base_arr[max_absolute_index]) + max_absolute_index, arr[max_absolute_index], + base_arr[max_absolute_index]) message += "\tMax relative difference is at index %s: value = %s, base_value = %s\n" % ( - max_relative_index, arr[max_relative_index], base_arr[max_relative_index]) + max_relative_index, arr[max_relative_index], + base_arr[max_relative_index]) message += "Statistics of the q values greater than 1.0 defined by absolute tolerance: N = %d\n" % q_num_absolute if q_num_absolute > 0: - message += "\tmax = %s, mean = %s, std = %s\n" % (q_max_absolute, q_mean_absolute, q_std_absolute) + message += "\tmax = %s, mean = %s, std = %s\n" % ( + q_max_absolute, q_mean_absolute, q_std_absolute) message += "\tmax is at index %s, value = %s, base_value = %s\n" % ( - q_max_absolute_index, arr[q_max_absolute_index], base_arr[q_max_absolute_index]) + q_max_absolute_index, arr[q_max_absolute_index], + base_arr[q_max_absolute_index]) message += "Statistics of the q values greater than 1.0 defined by relative tolerance: N = %d\n" % q_num_relative if q_num_relative > 0: - message += "\tmax = %s, mean = %s, std = %s\n" % (q_max_relative, q_mean_relative, q_std_relative) + message += "\tmax = %s, mean = %s, std = %s\n" % ( + q_max_relative, q_mean_relative, q_std_relative) message += "\tmax is at index %s, value = %s, base_value = %s\n" % ( - q_max_relative_index, arr[q_max_relative_index], base_arr[q_max_relative_index]) + q_max_relative_index, arr[q_max_relative_index], + base_arr[q_max_relative_index]) self.errorMsg(path, message, True) def compareIntArrays(self, path, arr, base_arr): @@ -340,8 +367,8 @@ def compareIntArrays(self, path, arr, base_arr): """ # If the shapes are different they can't be compared. 
if arr.shape != base_arr.shape: - msg = "Datasets have different shapes and therefore can't be compared: %s, %s.\n" % (arr.shape, - base_arr.shape) + msg = "Datasets have different shapes and therefore can't be compared: %s, %s.\n" % ( + arr.shape, base_arr.shape) self.errorMsg(path, msg, True) return @@ -355,7 +382,8 @@ def compareIntArrays(self, path, arr, base_arr): n_offenders = np.sum(offenders) if n_offenders != 0: - max_index = np.unravel_index(np.argmax(difference), difference.shape) + max_index = np.unravel_index(np.argmax(difference), + difference.shape) max_difference = difference[max_index] offenders_mean = np.mean(difference[offenders]) offenders_std = np.std(difference[offenders]) @@ -363,8 +391,8 @@ def compareIntArrays(self, path, arr, base_arr): message = "Arrays of types %s and %s have %s values of which %d have differing values.\n" % ( arr.dtype, base_arr.dtype, offenders.size, n_offenders) message += "Statistics of the differences greater than 0:\n" - message += "\tmax_index = %s, max = %s, mean = %s, std = %s\n" % (max_index, max_difference, offenders_mean, - offenders_std) + message += "\tmax_index = %s, max = %s, mean = %s, std = %s\n" % ( + max_index, max_difference, offenders_mean, offenders_std) self.errorMsg(path, message, True) def compareStringArrays(self, path, arr, base_arr): @@ -396,12 +424,14 @@ def compareData(self, path, arr, base_arr): np_strings = set(['S', 'a', 'U']) int_compare = arr.dtype.kind in np_ints and base_arr.dtype.kind in np_ints - float_compare = not int_compare and (arr.dtype.kind in np_numeric and base_arr.dtype.kind in np_numeric) + float_compare = not int_compare and (arr.dtype.kind in np_numeric and + base_arr.dtype.kind in np_numeric) string_compare = arr.dtype.kind in np_strings and base_arr.dtype.kind in np_strings # If the datasets have different types issue a warning. if arr.dtype != base_arr.dtype: - msg = "Datasets have different types: %s, %s.\n" % (arr.dtype, base_arr.dtype) + msg = "Datasets have different types: %s, %s.\n" % (arr.dtype, + base_arr.dtype) self.warningMsg(path, msg) # Handle empty datasets @@ -412,9 +442,15 @@ def compareData(self, path, arr, base_arr): if arr.size == 0 and base_arr.size == 0: return elif arr.size is None and base_arr.size is not None: - self.errorMsg(path, "File to compare has an empty dataset where the baseline's dataset is not empty.\n") + self.errorMsg( + path, + "File to compare has an empty dataset where the baseline's dataset is not empty.\n" + ) elif base_arr.size is None and arr.size is not None: - self.warningMsg(path, "Baseline has an empty dataset where the file to compare's dataset is not empty.\n") + self.warningMsg( + path, + "Baseline has an empty dataset where the file to compare's dataset is not empty.\n" + ) # If either of the datasets is a scalar convert it to an array. if arr.shape == (): @@ -433,7 +469,9 @@ def compareData(self, path, arr, base_arr): elif string_compare: return self.compareStringScalars(path, val, base_val) else: - return self.warningMsg(path, "Unrecognized type combination: %s %s.\n" % (arr.dtype, base_arr.dtype)) + return self.warningMsg( + path, "Unrecognized type combination: %s %s.\n" % + (arr.dtype, base_arr.dtype)) # Do the actual comparison. 
if float_compare: @@ -443,7 +481,9 @@ def compareData(self, path, arr, base_arr): elif string_compare: return self.compareStringArrays(path, arr, base_arr) else: - return self.warningMsg(path, "Unrecognized type combination: %s %s.\n" % (arr.dtype, base_arr.dtype)) + return self.warningMsg( + path, "Unrecognized type combination: %s %s.\n" % + (arr.dtype, base_arr.dtype)) def compareAttributes(self, path, attrs, base_attrs): """ @@ -506,8 +546,10 @@ def canCompare(self, group, base_group, name): return True def compareLvArrays(self, group, base_group, other_children_to_check): - if self.canCompare(group, base_group, "__dimensions__") and self.canCompare( - group, base_group, "__permutation__") and self.canCompare(group, base_group, "__values__"): + if self.canCompare( + group, base_group, "__dimensions__") and self.canCompare( + group, base_group, "__permutation__") and self.canCompare( + group, base_group, "__values__"): other_children_to_check.remove("__dimensions__") other_children_to_check.remove("__permutation__") other_children_to_check.remove("__values__") @@ -516,10 +558,12 @@ def compareLvArrays(self, group, base_group, other_children_to_check): base_dimensions = base_group["__dimensions__"][:] if len(dimensions.shape) != 1: - msg = "The dimensions of an LvArray must itself be a 1D array not %s\n" % len(dimensions.shape) + msg = "The dimensions of an LvArray must itself be a 1D array not %s\n" % len( + dimensions.shape) self.errorMsg(group.name, msg) - if dimensions.shape != base_dimensions.shape or np.any(dimensions != base_dimensions): + if dimensions.shape != base_dimensions.shape or np.any( + dimensions != base_dimensions): msg = "Cannot compare LvArrays because they have different dimensions. Dimensions = %s, base dimensions = %s\n" % ( dimensions, base_dimensions) self.errorMsg(group.name, msg) @@ -529,17 +573,20 @@ def compareLvArrays(self, group, base_group, other_children_to_check): base_permutation = base_group["__permutation__"][:] if len(permutation.shape) != 1: - msg = "The permutation of an LvArray must itself be a 1D array not %s\n" % len(permutation.shape) + msg = "The permutation of an LvArray must itself be a 1D array not %s\n" % len( + permutation.shape) self.errorMsg(group.name, msg) - if permutation.shape != dimensions.shape or np.any(np.sort(permutation) != np.arange(permutation.size)): + if permutation.shape != dimensions.shape or np.any( + np.sort(permutation) != np.arange(permutation.size)): msg = "LvArray in the file to compare has an invalid permutation. Dimensions = %s, Permutation = %s\n" % ( dimensions, permutation) self.errorMsg(group.name, msg) return True if base_permutation.shape != base_dimensions.shape or np.any( - np.sort(base_permutation) != np.arange(base_permutation.size)): + np.sort(base_permutation) != np.arange( + base_permutation.size)): msg = "LvArray in the baseline has an invalid permutation. 
Dimensions = %s, Permutation = %s\n" % ( base_dimensions, base_permutation) self.errorMsg(group.name, msg) @@ -554,7 +601,8 @@ def compareLvArrays(self, group, base_group, other_children_to_check): self.errorMsg(group.name, msg) return True - base_values, errorMsg = permuteArray(base_values, base_dimensions, base_permutation) + base_values, errorMsg = permuteArray(base_values, base_dimensions, + base_permutation) if base_values is None: msg = "Failed to permute the baseline LvArray: %s\n" % errorMsg self.errorMsg(group.name, msg) @@ -598,7 +646,9 @@ def compareGroups(self, group, base_group): elif isinstance(item1, h5py.Dataset): self.compareDatasets(item1, item2) else: - self.warningMsg(path, "Child %s has unknown type: %s.\n" % (name, type(item1))) + self.warningMsg( + path, "Child %s has unknown type: %s.\n" % + (name, type(item1))) def findFiles(file_pattern, baseline_pattern, comparison_args): @@ -619,14 +669,19 @@ def findFiles(file_pattern, baseline_pattern, comparison_args): files_to_compare = None with open(output_path, 'w') as output_file: comparison_args["output"] = output_file - writeHeader(file_pattern, file_path, baseline_pattern, baseline_path, comparison_args) + writeHeader(file_pattern, file_path, baseline_pattern, baseline_path, + comparison_args) # Check if comparing root files. if file_path.endswith(".root") and baseline_path.endswith(".root"): p = [re.compile("/file_pattern"), re.compile("/protocol/version")] - comp = FileComparison(file_path, baseline_path, 0.0, 0.0, p, output_file, True, False) + comp = FileComparison(file_path, baseline_path, 0.0, 0.0, p, + output_file, True, False) if comp.filesDiffer(): - write(output_file, "The root files are different, cannot compare data files.\n") + write( + output_file, + "The root files are different, cannot compare data files.\n" + ) return output_base_path, None else: write(output_file, "The root files are similar.\n") @@ -635,16 +690,20 @@ def findFiles(file_pattern, baseline_pattern, comparison_args): # We know the number of files are the same from the above comparison. with h5py.File(file_path, "r") as f: numberOfFiles = f["number_of_files"][0] - file_data_pattern = "".join(f["file_pattern"][:].tobytes().decode('ascii')[:-1]) + file_data_pattern = "".join( + f["file_pattern"][:].tobytes().decode('ascii')[:-1]) with h5py.File(baseline_path, "r") as f: - baseline_data_pattern = "".join(f["file_pattern"][:].tobytes().decode('ascii')[:-1]) + baseline_data_pattern = "".join( + f["file_pattern"][:].tobytes().decode('ascii')[:-1]) # Get the paths to the data files. files_to_compare = [] for i in range(numberOfFiles): - path_to_data = os.path.join(os.path.dirname(file_path), file_data_pattern % i) - path_to_baseline_data = os.path.join(os.path.dirname(baseline_path), baseline_data_pattern % i) + path_to_data = os.path.join(os.path.dirname(file_path), + file_data_pattern % i) + path_to_baseline_data = os.path.join( + os.path.dirname(baseline_path), baseline_data_pattern % i) files_to_compare += [(path_to_data, path_to_baseline_data)] else: @@ -692,7 +751,8 @@ def findMaxMatchingFile(file_path): return os.path.join(file_directory, max_match) -def writeHeader(file_pattern, file_path, baseline_pattern, baseline_path, args): +def writeHeader(file_pattern, file_path, baseline_pattern, baseline_path, + args): """ Write the header. 
@@ -704,11 +764,13 @@ def writeHeader(file_pattern, file_path, baseline_pattern, baseline_path, args): """ output = args["output"] msg = "Comparison of file %s from pattern %s\n" % (file_path, file_pattern) - msg += "Baseline file %s from pattern %s\n" % (baseline_path, baseline_pattern) + msg += "Baseline file %s from pattern %s\n" % (baseline_path, + baseline_pattern) msg += "Relative tolerance: %s\n" % args["rtol"] msg += "Absolute tolerance: %s\n" % args["atol"] msg += "Output file: %s\n" % output.name - msg += "Excluded groups: %s\n" % list(map(lambda e: e.pattern, args["regex_expressions"])) + msg += "Excluded groups: %s\n" % list( + map(lambda e: e.pattern, args["regex_expressions"])) msg += "Warnings are errors: %s\n\n" % args["warnings_are_errors"] write(output, msg) @@ -727,33 +789,47 @@ def main(): n_ranks = comm.Get_size() parser = argparse.ArgumentParser() - parser.add_argument("file_pattern", help="The pattern used to find the file to compare.") - parser.add_argument("baseline_pattern", help="The pattern used to find the baseline file.") - parser.add_argument("-r", - "--relative", - type=float, - help="The relative tolerance for floating point differences, default is %s." % RTOL_DEFAULT, - default=RTOL_DEFAULT) - parser.add_argument("-a", - "--absolute", - type=float, - help="The absolute tolerance for floating point differences, default is %s." % ATOL_DEFAULT, - default=ATOL_DEFAULT) - parser.add_argument("-e", - "--exclude", - action='append', - help="Regular expressions specifying which groups to skip, default is %s." % EXCLUDE_DEFAULT, - default=EXCLUDE_DEFAULT) - parser.add_argument("-m", - "--skip-missing", - action="store_true", - help="Ignore values that are missing from either the baseline or target file.", - default=False) - parser.add_argument("-w", - "--Werror", - action="store_true", - help="Force all warnings to be errors, default is False.", - default=False) + parser.add_argument("file_pattern", + help="The pattern used to find the file to compare.") + parser.add_argument("baseline_pattern", + help="The pattern used to find the baseline file.") + parser.add_argument( + "-r", + "--relative", + type=float, + help= + "The relative tolerance for floating point differences, default is %s." + % RTOL_DEFAULT, + default=RTOL_DEFAULT) + parser.add_argument( + "-a", + "--absolute", + type=float, + help= + "The absolute tolerance for floating point differences, default is %s." + % ATOL_DEFAULT, + default=ATOL_DEFAULT) + parser.add_argument( + "-e", + "--exclude", + action='append', + help= + "Regular expressions specifying which groups to skip, default is %s." 
% + EXCLUDE_DEFAULT, + default=EXCLUDE_DEFAULT) + parser.add_argument( + "-m", + "--skip-missing", + action="store_true", + help= + "Ignore values that are missing from either the baseline or target file.", + default=False) + parser.add_argument( + "-w", + "--Werror", + action="store_true", + help="Force all warnings to be errors, default is False.", + default=False) args = parser.parse_args() # Check the command line arguments @@ -773,7 +849,9 @@ def main(): comparison_args["skip_missing"] = args.skip_missing if rank == 0: - output_base_path, files_to_compare = findFiles(file_pattern, baseline_pattern, comparison_args) + output_base_path, files_to_compare = findFiles(file_pattern, + baseline_pattern, + comparison_args) else: output_base_path, files_to_compare = None, None @@ -787,13 +865,15 @@ def main(): for i in range(rank, len(files_to_compare), n_ranks): output_path = "%s.%d.restartcheck" % (output_base_path, i) diff_path = "%s.%d.diff.hdf5" % (output_base_path, i) - with open(output_path, 'w') as output_file, h5py.File(diff_path, "w") as diff_file: + with open(output_path, + 'w') as output_file, h5py.File(diff_path, "w") as diff_file: comparison_args["output"] = output_file comparison_args["diff_file"] = diff_file file_path, baseline_path = files_to_compare[i] logger.info(f"About to compare {file_path} and {baseline_path}") - if FileComparison(file_path, baseline_path, **comparison_args).filesDiffer(): + if FileComparison(file_path, baseline_path, + **comparison_args).filesDiffer(): differing_files += [files_to_compare[i]] output_file.write("The files are different.\n") else: @@ -813,14 +893,18 @@ def main(): if difference_found: write( - output_file, "\nCompared %d pairs of files of which %d are different.\n" % - (len(files_to_compare), len(all_differing_files))) + output_file, + "\nCompared %d pairs of files of which %d are different.\n" + % (len(files_to_compare), len(all_differing_files))) for file_path, base_path in all_differing_files: - write(output_file, "\t" + file_path + " and " + base_path + "\n") + write(output_file, + "\t" + file_path + " and " + base_path + "\n") return 1 else: - write(output_file, - "\nThe root files and the %d pairs of files compared are similar.\n" % len(files_to_compare)) + write( + output_file, + "\nThe root files and the %d pairs of files compared are similar.\n" + % len(files_to_compare)) return difference_found diff --git a/geos_ats_package/geos_ats/machine_utilities.py b/geos_ats_package/geos_ats/machine_utilities.py index 5e66b48..803ddb6 100644 --- a/geos_ats_package/geos_ats/machine_utilities.py +++ b/geos_ats_package/geos_ats/machine_utilities.py @@ -9,7 +9,9 @@ def CheckForEarlyTimeOut(test, retval, fraction): if not retval: return retval, fraction else: - if (config.max_retry > 0) and (config.retry_err_regexp != "") and (not hasattr(test, "checkstart")): + if (config.max_retry + > 0) and (config.retry_err_regexp + != "") and (not hasattr(test, "checkstart")): sourceFile = getattr(test, "errname") if os.path.exists(sourceFile): test.checkstart = 1 diff --git a/geos_ats_package/geos_ats/machines/batchGeosatsMoab.py b/geos_ats_package/geos_ats/machines/batchGeosatsMoab.py index 4f453cf..2af0b2d 100644 --- a/geos_ats_package/geos_ats/machines/batchGeosatsMoab.py +++ b/geos_ats_package/geos_ats/machines/batchGeosatsMoab.py @@ -1,9 +1,9 @@ #BATS:batchGeosatsMoab batchGeosatsMoab BatchGeosatsMoab -1 -from ats import machines, configuration, log, atsut, times, AtsTest # type: ignore[import] +from ats import machines, configuration, log, atsut, 
times, AtsTest # type: ignore[import] import subprocess, sys, os, shlex, time, socket, re -import utils, batchTemplate # type: ignore[import] -from batch import BatchMachine # type: ignore[import] +import utils, batchTemplate # type: ignore[import] +from batch import BatchMachine # type: ignore[import] import logging debug = configuration.debug @@ -19,9 +19,11 @@ def init(self): super(BatchGeosatsMoab, self).init() if "SLURM_NNODES" in os.environ.keys(): - self.ppn = int(os.getenv("SLURM_TASKS_PER_NODE", "1").split("(")[0]) + self.ppn = int( + os.getenv("SLURM_TASKS_PER_NODE", "1").split("(")[0]) elif "SLURM_JOB_NUM_NODES" in os.environ.keys(): - self.ppn = int(os.getenv("SLURM_JOB_CPUS_PER_NODE", "1").split("(")[0]) + self.ppn = int( + os.getenv("SLURM_JOB_CPUS_PER_NODE", "1").split("(")[0]) else: self.ppn = 0 @@ -43,9 +45,11 @@ def load(self, testlist): if t.groupSerialNumber == 1: testCase = getattr(t, "geos_atsTestCase", None) if testCase: - batchFilename = os.path.join(testCase.dirnamefull, "batch_%s.msub" % testCase.name) + batchFilename = os.path.join( + testCase.dirnamefull, "batch_%s.msub" % testCase.name) self.writeSubmitScript(batchFilename, testCase) - self.jobid = self.submitBatchScript(testCase.name, batchFilename) + self.jobid = self.submitBatchScript( + testCase.name, batchFilename) def writeSubmitScript(self, batchFilename, testCase): @@ -148,7 +152,9 @@ def submitBatchScript(self, testname, batchFilename): if config and config.batch_dryrun: return - p = subprocess.Popen(["msub", batchFilename], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + p = subprocess.Popen(["msub", batchFilename], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) out = p.communicate()[0] if p.returncode: diff --git a/geos_ats_package/geos_ats/machines/bgqos_0_ASQ.py b/geos_ats_package/geos_ats/machines/bgqos_0_ASQ.py index 6c470db..bbce9f8 100644 --- a/geos_ats_package/geos_ats/machines/bgqos_0_ASQ.py +++ b/geos_ats_package/geos_ats/machines/bgqos_0_ASQ.py @@ -1,10 +1,10 @@ #ATS:bgqos_0_ASQ machines.bgqos_0_ASQ bgqos_0_ASQMachine 16 -from ats import machines, debug, atsut # type: ignore[import] +from ats import machines, debug, atsut # type: ignore[import] from ats import log, terminal from ats import configuration -from ats.atsut import RUNNING, TIMEDOUT, SKIPPED, BATCHED, INVALID, PASSED, FAILED, CREATED, FILTERED, HALTED, EXPECTED # type: ignore[import] -import utils # type: ignore[import] +from ats.atsut import RUNNING, TIMEDOUT, SKIPPED, BATCHED, INVALID, PASSED, FAILED, CREATED, FILTERED, HALTED, EXPECTED # type: ignore[import] +import utils # type: ignore[import] import time import sys @@ -81,14 +81,16 @@ def examineOptions(self, options): if len(self.allNodeList) == 0: self.removeSrunStep = True else: - self.stepId, self.nodeStepNumDic = utils.setStepNumWithNode(len(self.allNodeList)) + self.stepId, self.nodeStepNumDic = utils.setStepNumWithNode( + len(self.allNodeList)) for oneNode in self.allNodeList: self.nodeProcAvailDic[oneNode] = self.npMax self.stepInUse = self.stepToUse # Let's check if there exists a srun process if len(self.allNodeList) > 0: - srunDefunct = utils.checkForSrunDefunct(self.allNodeList[0]) + srunDefunct = utils.checkForSrunDefunct( + self.allNodeList[0]) self.numberMaxProcessors -= srunDefunct self.nodeProcAvailDic[self.allNodeList[0]] -= srunDefunct @@ -107,7 +109,8 @@ def getResults(self): return results def label(self): - return "BG/Q %d nodes %d processors per node." % (self.numNodes, self.npMax) + return "BG/Q %d nodes %d processors per node." 
% (self.numNodes, + self.npMax) def calculateCommandList(self, test): """Prepare for run of executable using a suitable command. First we get the plain command @@ -168,9 +171,11 @@ def canRun(self, test): test.numberOfNodesNeeded += 1 if self.removeSrunStep: - test.requiredNP = max(test.np, self.npMax * test.numberOfNodesNeeded) + test.requiredNP = max(test.np, + self.npMax * test.numberOfNodesNeeded) if test.requiredNP > self.numberMaxProcessors: - return "Too many processors required, %d (limit is %d)" % (test.requiredNP, self.numberMaxProcessors) + return "Too many processors required, %d (limit is %d)" % ( + test.requiredNP, self.numberMaxProcessors) def canRunNow(self, test): "Is this machine able to run this test now? Return True/False" @@ -187,14 +192,15 @@ def noteLaunch(self, test): if not self.removeSrunStep: if test.srunRelativeNode < 0: - self.nodeProcAvailDic = utils.removeFromUsedTotalDicNoSrun(self.nodeProcAvailDic, self.nodeStepNumDic, - self.npMax, test.np, self.allNodeList) + self.nodeProcAvailDic = utils.removeFromUsedTotalDicNoSrun( + self.nodeProcAvailDic, self.nodeStepNumDic, self.npMax, + test.np, self.allNodeList) else: - self.nodeProcAvailDic = utils.removeFromUsedTotalDic(self.nodeProcAvailDic, self.nodeStepNumDic, - self.npMax, test.step, test.np, - test.numberOfNodesNeeded, test.numNodesToUse, - test.srunRelativeNode, self.stepId, - self.allNodeList) + self.nodeProcAvailDic = utils.removeFromUsedTotalDic( + self.nodeProcAvailDic, self.nodeStepNumDic, self.npMax, + test.step, test.np, test.numberOfNodesNeeded, + test.numNodesToUse, test.srunRelativeNode, self.stepId, + self.allNodeList) self.npBusy += max(test.np, 1) else: # this is necessary when srun exclusive is used. @@ -220,8 +226,11 @@ def noteEnd(self, test): self.npBusy -= max(test.np, test.numberOfNodesNeeded * self.npMax) if debug(): - log("Finished %s, #total proc in use = %d" % (test.name, self.npBusy), echo=True) - self.scheduler.schedule("Finished %s, #total proc in use = %d" % (test.name, self.npBusy)) + log("Finished %s, #total proc in use = %d" % + (test.name, self.npBusy), + echo=True) + self.scheduler.schedule("Finished %s, #total proc in use = %d" % + (test.name, self.npBusy)) self.numberTestsRunning = self.npBusy @@ -230,7 +239,10 @@ def periodicReport(self): # Let's also write out the tests that are waiting .... super(bgqos_0_ASQMachine, self).periodicReport() - currentEligible = [t.name for t in self.scheduler.testlist() if t.status is atsut.CREATED] + currentEligible = [ + t.name for t in self.scheduler.testlist() + if t.status is atsut.CREATED + ] if len(currentEligible) > 1: terminal("WAITING:", ", ".join(currentEligible[:5]), "... 
(more)") @@ -243,12 +255,13 @@ def kill(self, test): if test.status is RUNNING or test.status is TIMEDOUT: try: - retcode = subprocess.call("scancel" + " -n " + test.jobname, shell=True) + retcode = subprocess.call("scancel" + " -n " + test.jobname, + shell=True) if retcode < 0: - log("---- kill() in bgqos_0_ASQ.py, command= scancel -n %s failed with return code -%d ----" % - (test.jobname, retcode), + log("---- kill() in bgqos_0_ASQ.py, command= scancel -n %s failed with return code -%d ----" + % (test.jobname, retcode), echo=True) except OSError as e: - log("---- kill() in bgqos_0_ASQ.py, execution of command failed (scancel -n %s) failed: %s----" % - (test.jobname, e), + log("---- kill() in bgqos_0_ASQ.py, execution of command failed (scancel -n %s) failed: %s----" + % (test.jobname, e), echo=True) diff --git a/geos_ats_package/geos_ats/machines/darwin.py b/geos_ats_package/geos_ats/machines/darwin.py index af7b8b5..952c345 100644 --- a/geos_ats_package/geos_ats/machines/darwin.py +++ b/geos_ats_package/geos_ats/machines/darwin.py @@ -1,6 +1,6 @@ #ATS:darwin machines.darwin DarwinMachine 16 -from openmpi import OpenmpiMachine # type: ignore[import] +from openmpi import OpenmpiMachine # type: ignore[import] class DarwinMachine(OpenmpiMachine): diff --git a/geos_ats_package/geos_ats/machines/geosAtsSlurmProcessorScheduled.py b/geos_ats_package/geos_ats/machines/geosAtsSlurmProcessorScheduled.py index 1770402..481aeee 100644 --- a/geos_ats_package/geos_ats/machines/geosAtsSlurmProcessorScheduled.py +++ b/geos_ats_package/geos_ats/machines/geosAtsSlurmProcessorScheduled.py @@ -8,7 +8,7 @@ from geos_ats.scheduler import scheduler from geos_ats.machine_utilities import CheckForEarlyTimeOut -from slurmProcessorScheduled import SlurmProcessorScheduled # type: ignore[import] +from slurmProcessorScheduled import SlurmProcessorScheduled # type: ignore[import] import subprocess import logging @@ -25,7 +25,9 @@ def init(self): # CPUs (%c) is actually threads, so multiply sockets (%X) X cores (%Y) # to get actual number of processors (we ignore hyprethreading). sinfoCmd = 'sinfo -o"%X %Y"' - proc = subprocess.Popen(sinfoCmd, shell=True, stdout=subprocess.PIPE) + proc = subprocess.Popen(sinfoCmd, + shell=True, + stdout=subprocess.PIPE) stdout_value = proc.communicate()[0] (sockets, cores) = stdout_value.split('\n')[1].split() self.npMaxH = int(sockets) * int(cores) @@ -36,12 +38,14 @@ def init(self): self.scheduler = scheduler() def label(self): - return "GeosAtsSlurmProcessorScheduled: %d nodes, %d processors per node." % (self.numNodes, self.npMax) + return "GeosAtsSlurmProcessorScheduled: %d nodes, %d processors per node." % ( + self.numNodes, self.npMax) def checkForTimeOut(self, test): """ Check the time elapsed since test's start time. If greater then the timelimit, return true, else return false. test's end time is set if time elapsed exceeds time limit. 
Also return true if retry string if found.""" - retval, fraction = super(GeosAtsSlurmProcessorScheduled, self).checkForTimeOut(test) + retval, fraction = super(GeosAtsSlurmProcessorScheduled, + self).checkForTimeOut(test) return CheckForEarlyTimeOut(test, retval, fraction) diff --git a/geos_ats_package/geos_ats/machines/lassen.py b/geos_ats_package/geos_ats/machines/lassen.py index 1bdd93b..86fcbd9 100644 --- a/geos_ats_package/geos_ats/machines/lassen.py +++ b/geos_ats_package/geos_ats/machines/lassen.py @@ -1,11 +1,11 @@ #ATS:SequentialMachine SELF lassenMachine 1 #ATS:lassen SELF lassenMachine 1 -from ats import machines # type: ignore[import] +from ats import machines # type: ignore[import] from ats import machines, debug, atsut from ats import log, terminal from ats import configuration -from ats.atsut import RUNNING, TIMEDOUT # type: ignore[import] +from ats.atsut import RUNNING, TIMEDOUT # type: ignore[import] from ats import AtsTest import os import subprocess @@ -27,7 +27,8 @@ def examineOptions(self, options): super(lassenMachine, self).examineOptions(options) # Get total cpu cores available, and convert to number of gpus! - self.numberMaxProcessors = int(os.getenv("LSB_MAX_NUM_PROCESSORS", "0")) - 1 + self.numberMaxProcessors = int(os.getenv("LSB_MAX_NUM_PROCESSORS", + "0")) - 1 self.numberMaxGPUS = self.numberMaxProcessors / 10 self.numberTestsRunningMax = self.numberMaxProcessors @@ -61,7 +62,11 @@ def calculateCommandList(self, test): commandList = [] ngpu = test.ngpu - commandList += ["lrun", "-n", "%d" % test.np, "--pack", "-g", "%d" % ngpu] + commandList += [ + "lrun", "-n", + "%d" % test.np, "--pack", "-g", + "%d" % ngpu + ] commandList += basicCommands return commandList @@ -114,6 +119,8 @@ def kill(self, test): try: retcode = subprocess.call("jskill all", shell=True) if retcode < 0: - log("command= %s failed with return code %d" % ("jskill all", retcode), echo=True) + log("command= %s failed with return code %d" % + ("jskill all", retcode), + echo=True) except: logger.info("Killing job") diff --git a/geos_ats_package/geos_ats/machines/nersc.py b/geos_ats_package/geos_ats/machines/nersc.py index 1eb2443..7512b88 100644 --- a/geos_ats_package/geos_ats/machines/nersc.py +++ b/geos_ats_package/geos_ats/machines/nersc.py @@ -8,7 +8,7 @@ from geos_ats.scheduler import scheduler from geos_ats.machine_utilities import CheckForEarlyTimeOut -from slurmProcessorScheduled import SlurmProcessorScheduled # type: ignore[import] +from slurmProcessorScheduled import SlurmProcessorScheduled # type: ignore[import] import subprocess import logging @@ -25,7 +25,9 @@ def init(self): # CPUs (%c) is actually threads, so multiply sockets (%X) X cores (%Y) # to get actual number of processors (we ignore hyprethreading). sinfoCmd = 'sinfo -o"%X %Y"' - proc = subprocess.Popen(sinfoCmd, shell=True, stdout=subprocess.PIPE) + proc = subprocess.Popen(sinfoCmd, + shell=True, + stdout=subprocess.PIPE) stdout_value = proc.communicate()[0] (sockets, cores) = stdout_value.split('\n')[1].split() self.npMaxH = int(sockets) * int(cores) @@ -36,7 +38,8 @@ def init(self): self.scheduler = scheduler() def label(self): - return "Nersc: %d nodes, %d processors per node." % (self.numNodes, self.npMax) + return "Nersc: %d nodes, %d processors per node." % (self.numNodes, + self.npMax) def checkForTimeOut(self, test): """ Check the time elapsed since test's start time. 
If greater diff --git a/geos_ats_package/geos_ats/machines/openmpi.py b/geos_ats_package/geos_ats/machines/openmpi.py index 842564d..dae89a5 100644 --- a/geos_ats_package/geos_ats/machines/openmpi.py +++ b/geos_ats_package/geos_ats/machines/openmpi.py @@ -1,12 +1,12 @@ #ATS:openmpi machines.openmpi OpenmpiMachine 16 import os -import ats # type: ignore[import] +import ats # type: ignore[import] from ats import machines from ats import terminal from ats import log import shlex -from ats.atsut import RUNNING, TIMEDOUT # type: ignore[import] +from ats.atsut import RUNNING, TIMEDOUT # type: ignore[import] from ats import AtsTest import logging @@ -127,7 +127,8 @@ def calculateCommandList(self, test): if self.precommand: import time timeNow = time.strftime('%H%M%S', time.localtime()) - test.jobname = "t%d_%d%s%s" % (test.np, test.serialNumber, test.namebase, timeNow) + test.jobname = "t%d_%d%s%s" % (test.np, test.serialNumber, + test.namebase, timeNow) pre = self.precommand % {"np": test.np, "J": test.jobname} commandList = pre.split() else: @@ -149,7 +150,8 @@ def canRun(self, test): def canRunNow(self, test): "Is this machine able to run this test now? Return True/False" np = max(test.np, 1) - return ((self.numtests < self.maxtests) and (self.numProcsAvailable >= np)) + return ((self.numtests < self.maxtests) + and (self.numProcsAvailable >= np)) def noteLaunch(self, test): """A test has been launched.""" @@ -166,10 +168,12 @@ def noteEnd(self, test): def periodicReport(self): "Report on current status of tasks" terminal("-" * 80) - terminal("CURRENTLY RUNNING %d of %d tests." % (self.numtests, self.maxtests)) + terminal("CURRENTLY RUNNING %d of %d tests." % + (self.numtests, self.maxtests)) terminal("-" * 80) terminal("CURRENTLY UTILIZING %d processors (max %d)." % - (self.numberMaxProcessors - self.numProcsAvailable, self.numberMaxProcessors)) + (self.numberMaxProcessors - self.numProcsAvailable, + self.numberMaxProcessors)) terminal("-" * 80) def kill(self, test): diff --git a/geos_ats_package/geos_ats/machines/summit.py b/geos_ats_package/geos_ats/machines/summit.py index 07b2c58..092f44c 100644 --- a/geos_ats_package/geos_ats/machines/summit.py +++ b/geos_ats_package/geos_ats/machines/summit.py @@ -1,11 +1,11 @@ #ATS:SequentialMachine machines.summit summitMachine 1 #ATS:summit machines.summit summitMachine 1 -from ats import machines # type: ignore[import] +from ats import machines # type: ignore[import] from ats import machines, debug, atsut from ats import log, terminal from ats import configuration -from ats.atsut import RUNNING, TIMEDOUT # type: ignore[import] +from ats.atsut import RUNNING, TIMEDOUT # type: ignore[import] from ats import AtsTest import os import subprocess @@ -27,7 +27,8 @@ def examineOptions(self, options): super(summitMachine, self).examineOptions(options) # Get total cpu cores available, and convert to number of gpus! 
- self.numberMaxProcessors = int(os.getenv("LSB_MAX_NUM_PROCESSORS", "0")) - 1 + self.numberMaxProcessors = int(os.getenv("LSB_MAX_NUM_PROCESSORS", + "0")) - 1 self.numberMaxGPUS = (self.numberMaxProcessors / 42) * 6 self.numberTestsRunningMax = self.numberMaxProcessors @@ -61,7 +62,11 @@ def calculateCommandList(self, test): commandList = [] ngpu = test.ngpu - commandList += ["jsrun", "--np", "%d" % test.np, "-g", "%d" % ngpu, "-c", "1", "-b", "rs"] + commandList += [ + "jsrun", "--np", + "%d" % test.np, "-g", + "%d" % ngpu, "-c", "1", "-b", "rs" + ] commandList += basicCommands return commandList @@ -114,6 +119,8 @@ def kill(self, test): try: retcode = subprocess.call("jskill all", shell=True) if retcode < 0: - log("command= %s failed with return code %d" % ("jskill all", retcode), echo=True) + log("command= %s failed with return code %d" % + ("jskill all", retcode), + echo=True) except: self.logger.debug("Terminating job`") diff --git a/geos_ats_package/geos_ats/main.py b/geos_ats_package/geos_ats/main.py index 2f37158..736942d 100644 --- a/geos_ats_package/geos_ats/main.py +++ b/geos_ats_package/geos_ats/main.py @@ -12,7 +12,8 @@ report_actions = ("run", "rerun", "report", "continue") # Setup the logger -logging.basicConfig(level=logging.DEBUG, format='(%(asctime)s %(module)s:%(lineno)d) %(message)s') +logging.basicConfig(level=logging.DEBUG, + format='(%(asctime)s %(module)s:%(lineno)d) %(message)s') logger = logging.getLogger('geos_ats') # Job records @@ -108,17 +109,23 @@ def handleShutdown(signum, frame): def handle_salloc_relaunch(options, originalargv, configOverride): tests = [ options.action in test_actions, options.salloc, options.machine - in ("SlurmProcessorScheduled", "GeosAtsSlurmProcessorScheduled"), "SLURM_JOB_ID" not in os.environ + in ("SlurmProcessorScheduled", "GeosAtsSlurmProcessorScheduled"), + "SLURM_JOB_ID" not in os.environ ] if all(tests): if options.sallocOptions != "": sallocCommand = ["salloc"] + options.sallocOptions.split(" ") else: - sallocCommand = ["salloc", "-ppdebug", "--exclusive", "-N", "%d" % options.numNodes] + sallocCommand = [ + "salloc", "-ppdebug", "--exclusive", "-N", + "%d" % options.numNodes + ] if "testmodifier" in configOverride: if configOverride["testmodifier"] == "memcheck": - p = subprocess.Popen(['sinfo', '-o', '%l', '-h', '-ppdebug'], stdout=subprocess.PIPE) + p = subprocess.Popen( + ['sinfo', '-o', '%l', '-h', '-ppdebug'], + stdout=subprocess.PIPE) out, err = p.communicate() tarray = out.split(":") seconds = tarray.pop() @@ -131,7 +138,8 @@ def handle_salloc_relaunch(options, originalargv, configOverride): days, hours = hours.split('-') except ValueError as e: logger.debug(e) - limit = min(360, (24 * int(days) + int(hours)) * 60 + int(minutes)) + limit = min(360, (24 * int(days) + int(hours)) * 60 + + int(minutes)) sallocCommand.extend(["-t", "%d" % limit]) # generate a "unique" name for the salloc job so we can remove it later @@ -146,7 +154,9 @@ def handle_salloc_relaunch(options, originalargv, configOverride): command = sallocCommand # omit --workingDir on relaunch, as we have already changed directories - relaunchargv = [x for x in originalargv if not x.startswith("--workingDir")] + relaunchargv = [ + x for x in originalargv if not x.startswith("--workingDir") + ] command += relaunchargv command += ["--logs=%s" % options.logs] p = subprocess.Popen(command) @@ -185,7 +195,9 @@ def create_log_directory(options): if os.path.islink(basename): os.remove(basename) else: - logger.error(f"unable to replace {basename} with a symlink to 
{options.logs}") + logger.error( + f"unable to replace {basename} with a symlink to {options.logs}" + ) if not os.path.exists(basename): os.symlink(options.logs, basename) @@ -207,7 +219,8 @@ def check_timing_file(options, config): if options.action in ["run", "rerun", "continue"]: if config.timing_file: if not os.path.isfile(config.timing_file): - logger.warning(f'Timing file does not exist {config.timing_file}') + logger.warning( + f'Timing file does not exist {config.timing_file}') return from geos_ats import configuration_record @@ -215,7 +228,8 @@ def check_timing_file(options, config): for line in filep: if not line.startswith('#'): tokens = line.split() - configuration_record.globalTestTimings[tokens[0]] = int(tokens[1]) + configuration_record.globalTestTimings[ + tokens[0]] = int(tokens[1]) def append_test_end_step(machine): @@ -236,7 +250,9 @@ def check_working_dir(workingDir): if os.path.isdir(workingDir): os.chdir(workingDir) else: - logger.error(f"The requested working dir does not appear to exist: {workingDir}") + logger.error( + f"The requested working dir does not appear to exist: {workingDir}" + ) quit() @@ -265,22 +281,30 @@ def infoParagraph(title, paragraphs): def info(args): - from geos_ats import (common_utilities, configuration_record, test_steps, suite_settings, test_case, test_modifier) + from geos_ats import (common_utilities, configuration_record, test_steps, + suite_settings, test_case, test_modifier) infoLabels = lambda *x: suite_settings.infoLabels(suite_settings.__file__) infoOwners = lambda *x: suite_settings.infoOwners(suite_settings.__file__) menu = common_utilities.InfoTopic("geos_ats info menu") - menu.addTopic("teststep", "Reference on all the TestStep", test_steps.infoTestSteps) - menu.addTopic("testcase", "Reference on the TestCase", test_case.infoTestCase) + menu.addTopic("teststep", "Reference on all the TestStep", + test_steps.infoTestSteps) + menu.addTopic("testcase", "Reference on the TestCase", + test_case.infoTestCase) menu.addTopic("labels", "List of labels", infoLabels) menu.addTopic("owners", "List of owners", infoOwners) - menu.addTopic("config", "Reference on config options", configuration_record.infoConfig) - menu.addTopic("actions", "Description of the command line action options", - lambda *x: infoOptions("command line actions", command_line_parsers.action_ptions)) - menu.addTopic("checks", "Description of the command line check options", - lambda *x: infoOptions("command line checks", command_line_parsers.check_options)) - menu.addTopic("modifiers", "List of test modifiers", test_modifier.infoTestModifier) + menu.addTopic("config", "Reference on config options", + configuration_record.infoConfig) + menu.addTopic( + "actions", "Description of the command line action options", + lambda *x: infoOptions("command line actions", command_line_parsers. 
+ action_ptions)) + menu.addTopic( + "checks", "Description of the command line check options", lambda *x: + infoOptions("command line checks", command_line_parsers.check_options)) + menu.addTopic("modifiers", "List of test modifiers", + test_modifier.infoTestModifier) # menu.addTopic("testconfig", "Information on the testconfig.py file", # lambda *x: infoParagraph("testconfig", command_line_parsers.test_config_info)) menu.process(args) @@ -301,7 +325,8 @@ def report(manager): with open(configuration_record.config.report_text_file, "w") as filep: reporter.report(filep) if configuration_record.config.report_text_echo: - with open(configuration_record.config.report_text_file, "r") as filep: + with open(configuration_record.config.report_text_file, + "r") as filep: sys.stdout.write(filep.read()) if configuration_record.config.report_html: @@ -317,7 +342,8 @@ def report(manager): reporter = reporting.ReportTiming(testcases) if not configuration_record.config.report_timing_overwrite: try: - with open(configuration_record.config.timing_file, "r") as filep: + with open(configuration_record.config.timing_file, + "r") as filep: reporter.getOldTiming(filep) except IOError as e: logger.debug(e) @@ -333,7 +359,8 @@ def summary(manager, alog, short=False): return if hasattr(manager.machine, "getNumberOfProcessors"): - totalNumberOfProcessors = getattr(manager.machine, "getNumberOfProcessors", None)() + totalNumberOfProcessors = getattr(manager.machine, + "getNumberOfProcessors", None)() else: totalNumberOfProcessors = 1 reporter = reporting.ReportTextPeriodic(manager.testlist) @@ -388,8 +415,10 @@ def main(): if os.path.isdir(options.machine_dir): search_path = options.machine_dir else: - logger.error(f'Target machine dir does not exist: {options.machine_dir}') - logger.error('geos_ats will continue searching in the default path') + logger.error( + f'Target machine dir does not exist: {options.machine_dir}') + logger.error( + 'geos_ats will continue searching in the default path') if not search_path: search_path = os.path.dirname(machines.__file__) @@ -421,11 +450,14 @@ def main(): # Check the report location if options.logs: - config.report_html_file = os.path.join(options.logs, 'test_results.html') - config.report_text_file = os.path.join(options.logs, 'test_results.txt') + config.report_html_file = os.path.join(options.logs, + 'test_results.html') + config.report_text_file = os.path.join(options.logs, + 'test_results.txt') config.report_ini_file = os.path.join(options.logs, 'test_results.ini') - ats_files = check_ats_targets(options, testcases, configOverride, originalargv) + ats_files = check_ats_targets(options, testcases, configOverride, + originalargv) build_ats_arguments(options, ats_files, originalargv, config) # Additional setup tasks @@ -444,7 +476,7 @@ def main(): geos_atsStartTime = time.time() # Note: the sys.argv is read here by default - import ats # type: ignore[import] + import ats # type: ignore[import] ats.manager.init() logger.debug('Copying options to the geos_ats config record file') config.copy_values(ats.manager.machine) @@ -464,7 +496,8 @@ def main(): else: ats.AtsTest.glue(testcases="all") - from geos_ats import (common_utilities, suite_settings, test_case, test_steps, user_utilities) + from geos_ats import (common_utilities, suite_settings, test_case, + test_steps, user_utilities) # Set ats options append_geos_ats_summary(ats.manager) @@ -498,7 +531,10 @@ def main(): # clean if options.action == "veryclean": common_utilities.removeLogDirectories(os.getcwd()) - files = 
[config.report_html_file, config.report_ini_file, config.report_text_file] + files = [ + config.report_html_file, config.report_ini_file, + config.report_text_file + ] for f in files: if os.path.exists(f): os.remove(f) @@ -512,11 +548,12 @@ def main(): # return 0 if all tests passed, 1 otherwise try: if options.failIfTestsFail: - with open(os.path.join(options.logs, "test_results.html"), 'r') as f: + with open(os.path.join(options.logs, "test_results.html"), + 'r') as f: contents = ''.join(f.readlines()).split("DETAILED RESULTS")[1] messages = [ - "class=\"red\">FAIL", "class=\"yellow\">SKIPPED", "class=\"reddish\">FAIL", - "class=\"yellow\">NOT RUN" + "class=\"red\">FAIL", "class=\"yellow\">SKIPPED", + "class=\"reddish\">FAIL", "class=\"yellow\">NOT RUN" ] result = any([m in contents for m in messages]) except IOError as e: diff --git a/geos_ats_package/geos_ats/reporting.py b/geos_ats_package/geos_ats/reporting.py index 60ff3e4..ec9f3dc 100644 --- a/geos_ats_package/geos_ats/reporting.py +++ b/geos_ats_package/geos_ats/reporting.py @@ -5,7 +5,7 @@ import re from geos_ats.configuration_record import config import sys -import ats # type: ignore[import] +import ats # type: ignore[import] from configparser import ConfigParser import logging @@ -31,8 +31,9 @@ UNEXPECTEDPASS = 14 # A tuple of test status values. -STATUS = (FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, FAILCHECKMINOR, TIMEOUT, NOTRUN, FILTERED, RUNNING, - INPROGRESS, PASS, EXPECTEDFAIL, SKIP, BATCH, NOTBUILT) +STATUS = (FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, FAILCHECKMINOR, + TIMEOUT, NOTRUN, FILTERED, RUNNING, INPROGRESS, PASS, EXPECTEDFAIL, + SKIP, BATCH, NOTBUILT) STATUS_NOTDONE = (NOTRUN, RUNNING, INPROGRESS, BATCH) @@ -62,7 +63,8 @@ def getOldTiming(self, fp): def report(self, fp): for testcase in self.reportcases: if testcase.status in [PASS, TIMEOUT]: - self.timings[testcase.testcase.name] = int(testcase.testcase.status.totalTime()) + self.timings[testcase.testcase.name] = int( + testcase.testcase.status.totalTime()) output = "" for key in sorted(self.timings): output += "%s %d\n" % (key, self.timings[key]) @@ -78,7 +80,9 @@ def __init__(self, testcases): # A dictionary where the key is a status, and the value is a sequence of ReportTestCases self.reportcaseResults = {} for status in STATUS: - self.reportcaseResults[status] = [t for t in self.reportcases if t.status == status] + self.reportcaseResults[status] = [ + t for t in self.reportcases if t.status == status + ] self.displayName = {} self.displayName[FAILRUN] = "FAILRUN" @@ -104,7 +108,8 @@ def report(self, fp): configParser = ConfigParser() configParser.add_section("Info") - configParser.set("Info", "Time", time.strftime("%a, %d %b %Y %H:%M:%S")) + configParser.set("Info", "Time", + time.strftime("%a, %d %b %Y %H:%M:%S")) try: platform = socket.gethostname() except: @@ -120,7 +125,8 @@ def report(self, fp): if len(line_split) != 2: extraNotations += "\"" + line.strip() + "\"" continue - configParser.set("Info", line_split[0].strip(), line_split[1].strip()) + configParser.set("Info", line_split[0].strip(), + line_split[1].strip()) if extraNotations != "": configParser.set("Info", "Extra Notations", extraNotations) @@ -139,17 +145,23 @@ def report(self, fp): configParser.set("Custodians", testName, owner) if config.report_doc_link: - linkToDocumentation = os.path.join(config.report_doc_dir, testName, testName + ".html") + linkToDocumentation = os.path.join(config.report_doc_dir, + testName, + testName + ".html") if 
os.path.exists(linkToDocumentation): - configParser.set("Documentation", testName, linkToDocumentation) + configParser.set("Documentation", testName, + linkToDocumentation) else: if not reportcaseResult.testcase.nodoc: undocumentedTests.append(testName) - linkToDocumentation = getowner(testName, reportcaseResult.testcase) + linkToDocumentation = getowner(testName, + reportcaseResult.testcase) testNames = sorted(testNames) - configParser.set("Results", self.displayName[status], ";".join(testNames)) + configParser.set("Results", self.displayName[status], + ";".join(testNames)) undocumentedTests = sorted(undocumentedTests) - configParser.set("Documentation", "undocumented", ";".join(undocumentedTests)) + configParser.set("Documentation", "undocumented", + ";".join(undocumentedTests)) configParser.write(fp) @@ -164,7 +176,9 @@ def __init__(self, testcases): # A dictionary where the key is a status, and the value is a sequence of ReportTestCases self.reportcaseResults = {} for status in STATUS: - self.reportcaseResults[status] = [t for t in self.reportcases if t.status == status] + self.reportcaseResults[status] = [ + t for t in self.reportcases if t.status == status + ] self.displayName = {} self.displayName[FAILRUN] = "FAIL RUN" @@ -185,10 +199,13 @@ def __init__(self, testcases): def report(self, fp): """Write out the text report to the give file pointer""" - self.writeSummary(fp, (FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, FAILCHECKMINOR, TIMEOUT, NOTRUN, - INPROGRESS, FILTERED, PASS, EXPECTEDFAIL, SKIP, BATCH, NOTBUILT)) + self.writeSummary( + fp, (FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, + FAILCHECKMINOR, TIMEOUT, NOTRUN, INPROGRESS, FILTERED, PASS, + EXPECTEDFAIL, SKIP, BATCH, NOTBUILT)) self.writeLongest(fp, 5) - self.writeDetails(fp, (FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, FAILCHECKMINOR, TIMEOUT, FILTERED)) + self.writeDetails(fp, (FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, + FAILCHECK, FAILCHECKMINOR, TIMEOUT, FILTERED)) def writeSummary(self, fp, statuses=STATUS): """The summary groups each TestCase by its status.""" @@ -212,8 +229,10 @@ def writeSummary(self, fp, statuses=STATUS): def writeDetails(self, fp, - statuses=(FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, FAILCHECKMINOR, INPROGRESS), - columns=("Status", "TestCase", "Elapsed", "Resources", "TestStep", "OutFile")): + statuses=(FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, + FAILCHECK, FAILCHECKMINOR, INPROGRESS), + columns=("Status", "TestCase", "Elapsed", "Resources", + "TestStep", "OutFile")): """This function provides more information about each of the test cases""" from geos_ats import common_utilities @@ -303,13 +322,17 @@ def __init__(self, atstests): ReportText.__init__(self, testcases) def report(self, startTime, totalProcessors=None): - self.writeSummary(sys.stdout, - (FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, FAILCHECKMINOR, TIMEOUT, NOTRUN, - INPROGRESS, FILTERED, RUNNING, PASS, EXPECTEDFAIL, SKIP, BATCH, NOTBUILT)) + self.writeSummary( + sys.stdout, (FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, + FAILCHECKMINOR, TIMEOUT, NOTRUN, INPROGRESS, FILTERED, + RUNNING, PASS, EXPECTEDFAIL, SKIP, BATCH, NOTBUILT)) self.writeUtilization(sys.stdout, startTime, totalProcessors) self.writeLongest(sys.stdout) - self.writeDetails(sys.stdout, (FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, FAILCHECKMINOR, RUNNING), - ("Status", "TestCase", "Directory", "Elapsed", "Resources", "TestStep")) + self.writeDetails(sys.stdout, + (FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, 
FAILCHECK, + FAILCHECKMINOR, RUNNING), + ("Status", "TestCase", "Directory", "Elapsed", + "Resources", "TestStep")) def writeUtilization(self, fp, startTime, totalProcessors=None): """Machine utilization is reported""" @@ -322,13 +345,18 @@ def writeUtilization(self, fp, startTime, totalProcessors=None): if totalResourcesUsed > 0: fp.write('\n') - fp.write(f"\n TOTAL TIME : {ats.times.hms( totaltime )}") - fp.write(f"\n TOTAL PROCESSOR-TIME : {ats.times.hms(totalResourcesUsed )}") + fp.write( + f"\n TOTAL TIME : {ats.times.hms( totaltime )}") + fp.write( + f"\n TOTAL PROCESSOR-TIME : {ats.times.hms(totalResourcesUsed )}" + ) if totalProcessors: availableResources = totalProcessors * totaltime utilization = totalResourcesUsed / availableResources * 100.0 - fp.write(f" AVAIL PROCESSOR-TIME : {ats.times.hms(availableResources )}") + fp.write( + f" AVAIL PROCESSOR-TIME : {ats.times.hms(availableResources )}" + ) fp.write(f" RESOURCE UTILIZATION : {utilization:5.3g}%") @@ -398,12 +426,16 @@ def initializeReportGroups(self): testdir[dirname] = [] testdir[dirname].append(reportcase) - self.groups = [ReportGroup(key, value) for key, value in testdir.items()] + self.groups = [ + ReportGroup(key, value) for key, value in testdir.items() + ] # place groups into a dictionary keyed on the group status self.groupResults = {} for status in STATUS: - self.groupResults[status] = [g for g in self.groups if g.status == status] + self.groupResults[status] = [ + g for g in self.groups if g.status == status + ] def report(self, refresh=0): # potentially regenerate the html documentation for the test suite. @@ -413,7 +445,8 @@ def report(self, refresh=0): sp = open(self.html_filename, 'w') if refresh: - if not any(g.status in (RUNNING, NOTRUN, INPROGRESS) for g in self.groups): + if not any(g.status in (RUNNING, NOTRUN, INPROGRESS) + for g in self.groups): refresh = 0 self.writeHeader(sp, refresh) @@ -427,7 +460,8 @@ def report(self, refresh=0): else: groupColumns = ("Name", "Status") - testcaseColumns = ("Status", "Name", "TestStep", "Age", "Elapsed", "Resources", "Output") + testcaseColumns = ("Status", "Name", "TestStep", "Age", "Elapsed", + "Resources", "Output") # write the details self.writeTable(sp, groupColumns, testcaseColumns) @@ -454,10 +488,13 @@ def generateDocumentation(self): if filetime > newest: newest = filetime if os.path.getmtime(testdocfile) > newest: - logger.info(f"HTML documentation found in {os.path.relpath(testdocfile)}. Not regenerating.") + logger.info( + f"HTML documentation found in {os.path.relpath(testdocfile)}. Not regenerating." 
+ ) return - logger.info("Generating HTML documentation files (running 'atddoc')...") + logger.info( + "Generating HTML documentation files (running 'atddoc')...") retcode = True try: geos_atsdir = os.path.realpath(os.path.dirname(__file__)) @@ -467,9 +504,12 @@ def generateDocumentation(self): except OSError as e: logger.debug(e) if retcode: - logger.info(f" Failed to create HTML documentation in {config.report_doc_dir}") + logger.info( + f" Failed to create HTML documentation in {config.report_doc_dir}" + ) else: - logger.info(f" HTML documentation created in {config.report_doc_dir}") + logger.info( + f" HTML documentation created in {config.report_doc_dir}") def writeRowHeader(self, sp, groupColumns, testcaseColumns): header = f""" @@ -574,7 +614,8 @@ def writeTable(self, sp, groupColumns, testcaseColumns): if col == "Status": statusDisplay = self.displayName[testcase.status] - retries = getattr(testcase.testcase.atsGroup, "retries", 0) + retries = getattr(testcase.testcase.atsGroup, + "retries", 0) if retries > 0: statusDisplay += "
retry: %d" % retries header += f'\n{statusDisplay}' @@ -586,22 +627,27 @@ def writeTable(self, sp, groupColumns, testcaseColumns): if config.report_doc_link: docfound = False # first check for the full problem name, with the domain extension - testhtml = os.path.join(config.report_doc_dir, test.name, testcase.testcase.name + ".html") + testhtml = os.path.join( + config.report_doc_dir, test.name, + testcase.testcase.name + ".html") if os.path.exists(testhtml): docfound = True else: # next check for the full problem name without the domain extension - testhtml = os.path.join(config.report_doc_dir, test.name, - testcase.testcase.name + ".html") + testhtml = os.path.join( + config.report_doc_dir, test.name, + testcase.testcase.name + ".html") if os.path.exists(testhtml): docfound = True else: # final check for any of the input file names for step in testcase.testcase.steps: if getattr(step.p, "deck", None): - [inputname, suffix] = getattr(step.p, "deck").rsplit('.', 1) - testhtml = os.path.join(config.report_doc_dir, test.name, - inputname + ".html") + [inputname, suffix] = getattr( + step.p, "deck").rsplit('.', 1) + testhtml = os.path.join( + config.report_doc_dir, + test.name, inputname + ".html") if os.path.exists(testhtml): # match with the first input file docfound = True @@ -661,11 +707,13 @@ def writeTable(self, sp, groupColumns, testcaseColumns): header += "\n" seen = {} - for stepnum, step in enumerate(testcase.testcase.steps): + for stepnum, step in enumerate( + testcase.testcase.steps): paths = testcase.testcase.resultPaths(step) for p in paths: # if p has already been accounted for, doesn't exist, or is an empty file, don't print it. - if (((p in seen) or not os.path.exists(p)) or (os.stat(p)[6] == 0)): + if (((p in seen) or not os.path.exists(p)) + or (os.stat(p)[6] == 0)): continue header += f"\n{os.path.basename(p)}
" seen[p] = 1 @@ -820,7 +868,8 @@ def writeSummary(self, sp): caseref = case.name retries = 0 for test in case.testcases: - retries += getattr(test.testcase.atsGroup, "retries", 0) + retries += getattr(test.testcase.atsGroup, "retries", + 0) if retries > 0: haveRetry = True casename += '*' @@ -892,7 +941,7 @@ def report(self, fp): import time start = time.time() - sleeptime = 60 # interval to check (seconds) + sleeptime = 60 # interval to check (seconds) while True: notdone = [] @@ -904,9 +953,11 @@ def report(self, fp): if notdone: rr = ReportText(self.testcases) - rr.writeSummary(sys.stdout, - (FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, FAILCHECKMINOR, TIMEOUT, NOTRUN, - INPROGRESS, FILTERED, PASS, EXPECTEDFAIL, SKIP, BATCH, NOTBUILT)) + rr.writeSummary( + sys.stdout, + (FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, + FAILCHECKMINOR, TIMEOUT, NOTRUN, INPROGRESS, FILTERED, + PASS, EXPECTEDFAIL, SKIP, BATCH, NOTBUILT)) time.sleep(sleeptime) else: break @@ -922,8 +973,8 @@ class ReportTestCase(object): def __init__(self, testcase): - self.testcase = testcase # test_case - self.status = None # One of the STATUS values (e.g. FAILRUN, PASS, etc.) + self.testcase = testcase # test_case + self.status = None # One of the STATUS values (e.g. FAILRUN, PASS, etc.) self.laststep = None self.diffage = None self.elapsed = 0.0 @@ -958,14 +1009,14 @@ def __init__(self, testcase): self.resources += np * dt outcome = "EXPECTEDFAIL" self.status = EXPECTEDFAIL - break # don't continue past an expected failure + break # don't continue past an expected failure if outcome == "UNEX": dt = endTime - startTime self.elapsed += dt self.resources += np * dt outcome = "UNEXPECTEDPASS" self.status = UNEXPECTEDPASS - break # don't continue past an unexpected pass + break # don't continue past an unexpected pass elif outcome == "SKIP": self.status = SKIP break @@ -1023,7 +1074,9 @@ def __init__(self, testcase): try: with open(step.p.stdout, 'r') as fp: for line in fp: - if re.search(config.report_notbuilt_regexp, line): + if re.search( + config.report_notbuilt_regexp, + line): self.status = NOTBUILT break except: @@ -1071,7 +1124,8 @@ def _getStepInfo(self, teststep): np = getattr(teststep.p, "np", 1) - if status in ("SKIP", "FILT", "INIT", "PASS", "FAIL", "TIME", "EXEC", "BACH", "EXPT", "UNEX"): + if status in ("SKIP", "FILT", "INIT", "PASS", "FAIL", "TIME", "EXEC", + "BACH", "EXPT", "UNEX"): return (status, np, startTime, endTime) else: return ("SKIP", np, startTime, endTime) @@ -1098,7 +1152,8 @@ def getowner(dirname, testcase=None): owner = "" if not config.report_doc_link: try: - atdfile = os.path.join(config.report_doc_dir, dirname, dirname + ".atd") + atdfile = os.path.join(config.report_doc_dir, dirname, + dirname + ".atd") with open(atdfile, "r") as fp: for line in fp: match = re.search("CUSTODIAN:: +(.*)$", line) diff --git a/geos_ats_package/geos_ats/rules.py b/geos_ats_package/geos_ats/rules.py index 542e547..a6c5c13 100644 --- a/geos_ats_package/geos_ats/rules.py +++ b/geos_ats_package/geos_ats/rules.py @@ -103,7 +103,8 @@ class SetupRules(Rule): def __init__(self, toggles, minToggle=0, maxToggle=None): self.setupMin = minToggle self.setupMax = maxToggle - Rule.__init__(self, SetupRules.numToggles, SetupRules.numCombinations, toggles) + Rule.__init__(self, SetupRules.numToggles, SetupRules.numCombinations, + toggles) def refresh(self): mtoggles = self.toggles[self.setupMin:self.setupMax] @@ -112,14 +113,16 @@ def refresh(self): self.isTenthCycle = mtoggles[1] self.baseName = "foo%i" % 
self.id - self.baseName = "%s%s" % (self.baseName, "_001" if underscoredName else "") + self.baseName = "%s%s" % (self.baseName, + "_001" if underscoredName else "") self.repStrings["@@BASE@@"] = self.baseName self.inputDeck = "%s.in" % self.baseName self.repStrings["@@DECK@@"] = self.inputDeck self.restartBaseName = "%s_001" % self.baseName - self.restartName = "%s_%s" % (self.restartBaseName, "00010" if self.isTenthCycle else "00000") + self.restartName = "%s_%s" % (self.restartBaseName, "00010" + if self.isTenthCycle else "00000") self.repStrings["@@RF@@"] = self.restartName super(SetupRules, self).refresh() @@ -141,12 +144,13 @@ class CommandLineRules(Rule): def __init__(self, toggles, minToggle=0, maxToggle=None): self.clMin = minToggle self.clMax = maxToggle - Rule.__init__(self, CommandLineRules.numToggles, CommandLineRules.numCombinations, toggles) + Rule.__init__(self, CommandLineRules.numToggles, + CommandLineRules.numCombinations, toggles) def refresh(self): mtoggles = self.toggles[self.clMin:self.clMax] - self.probDefined = mtoggles[0] # use the -prob flag - self.restartDefined = mtoggles[1] # use the -rf flag + self.probDefined = mtoggles[0] # use the -prob flag + self.restartDefined = mtoggles[1] # use the -rf flag # self.prob = "-prob %s" % "@@BASE@@" if self.probDefined else "" # self.rf = "-rf %s" % "@@RF@@" if self.restartDefined else "" @@ -163,16 +167,22 @@ def main(): generator = GenRules(SetupRules) for rule in generator: - vals = (rule.GetInputDeckName(), rule.GetInitialRestartName(), rule.GetPosition()) - logger.debug(rule.replaceString("InputDeck: %s\tRestartFile: %s\tPos: %f" % vals)) + vals = (rule.GetInputDeckName(), rule.GetInitialRestartName(), + rule.GetPosition()) + logger.debug( + rule.replaceString("InputDeck: %s\tRestartFile: %s\tPos: %f" % + vals)) DeclareCompoundRuleClass("SetupCommand", SetupRules, CommandLineRules) logger.debug(SetupCommand.numCombinations) generator = GenRules(SetupCommand) logger.debug("compound:") for rule in generator: - vals = (rule.GetInputDeckName(), rule.GetInitialRestartName(), rule.GetPosition(), rule.prob, rule.rf) - logger.debug(rule.replaceString("InputDeck: %s\tRestartFile: %s\tPos: %f\t%s\t%s" % vals)) + vals = (rule.GetInputDeckName(), rule.GetInitialRestartName(), + rule.GetPosition(), rule.prob, rule.rf) + logger.debug( + rule.replaceString( + "InputDeck: %s\tRestartFile: %s\tPos: %f\t%s\t%s" % vals)) return diff --git a/geos_ats_package/geos_ats/scheduler.py b/geos_ats_package/geos_ats/scheduler.py index 108deab..5b47487 100644 --- a/geos_ats_package/geos_ats/scheduler.py +++ b/geos_ats_package/geos_ats/scheduler.py @@ -4,9 +4,9 @@ import time from geos_ats.configuration_record import config from geos_ats.common_utilities import Log -from ats.log import log # type: ignore[import] -from ats.atsut import PASSED, FAILED, CREATED, EXPECTED, TIMEDOUT # type: ignore[import] -from ats.schedulers import StandardScheduler # type: ignore[import] +from ats.log import log # type: ignore[import] +from ats.atsut import PASSED, FAILED, CREATED, EXPECTED, TIMEDOUT # type: ignore[import] +from ats.schedulers import StandardScheduler # type: ignore[import] class GeosAtsScheduler(StandardScheduler): @@ -33,16 +33,22 @@ def testEnded(self, test): g.recordOutput() if not hasattr(g, "retries"): g.retries = 0 - if test.status in [FAILED, TIMEDOUT] and g.retries < config.max_retry: + if test.status in [FAILED, TIMEDOUT + ] and g.retries < config.max_retry: with open(test.geos_atsTestCase.errname) as f: erroutput = f.read() if 
re.search(config.retry_err_regexp, erroutput): f.close() - os.rename(test.geos_atsTestCase.errname, "%s.%d" % (test.geos_atsTestCase.errname, g.retries)) - os.rename(test.geos_atsTestCase.outname, "%s.%d" % (test.geos_atsTestCase.outname, g.retries)) + os.rename( + test.geos_atsTestCase.errname, "%s.%d" % + (test.geos_atsTestCase.errname, g.retries)) + os.rename( + test.geos_atsTestCase.outname, "%s.%d" % + (test.geos_atsTestCase.outname, g.retries)) g.retries += 1 for t in g: t.status = CREATED - Log(f"# retry test={test.geos_atsTestCase.name} ({g.retries}/{config.max_retry})") + Log(f"# retry test={test.geos_atsTestCase.name} ({g.retries}/{config.max_retry})" + ) return self.groups.remove(g) diff --git a/geos_ats_package/geos_ats/suite_settings.py b/geos_ats_package/geos_ats/suite_settings.py index 824079f..594bc51 100644 --- a/geos_ats_package/geos_ats/suite_settings.py +++ b/geos_ats_package/geos_ats/suite_settings.py @@ -3,7 +3,7 @@ testLabels = [ "geos", - "auto", # label used when the tests were automatically converted. Will be deprecated. + "auto", # label used when the tests were automatically converted. Will be deprecated. ] testOwners = [("corbett5", "Ben Corbett")] diff --git a/geos_ats_package/geos_ats/test_builder.py b/geos_ats_package/geos_ats/test_builder.py index fe2fb45..12c7aaa 100644 --- a/geos_ats_package/geos_ats/test_builder.py +++ b/geos_ats_package/geos_ats/test_builder.py @@ -63,7 +63,9 @@ def collect_block_names(fname): tree = etree.parse(actual_fname, parser=parser) root = tree.getroot() for child in root.getchildren(): - results[child.tag] = [grandchild.tag for grandchild in child.getchildren()] + results[child.tag] = [ + grandchild.tag for grandchild in child.getchildren() + ] # Collect block names in included files for included_root in root.findall('Included'): @@ -122,14 +124,16 @@ def generate_geos_tests(decks: Iterable[TestDeck]): checks.append('restart') steps.append( geos(deck=xml_file, - name="{:d}to{:d}".format(deck.restart_step, deck.check_step), + name="{:d}to{:d}".format(deck.restart_step, + deck.check_step), np=N, ngpu=N, x_partitions=nx, y_partitions=ny, z_partitions=nz, - restart_file=os.path.join(testcase_name, - "{}_restart_{:09d}".format(base_name, deck.restart_step)), + restart_file=os.path.join( + testcase_name, "{}_restart_{:09d}".format( + base_name, deck.restart_step)), baseline_pattern=f"{base_name}_restart_[0-9]+\.root", allow_rebaseline=False, restartcheck_params=restartcheck_params)) @@ -138,7 +142,8 @@ def generate_geos_tests(decks: Iterable[TestDeck]): AtsTest.stick(checks=','.join(checks)) AtsTest.stick(solvers=','.join(xml_blocks.get('Solvers', []))) AtsTest.stick(outputs=','.join(xml_blocks.get('Outputs', []))) - AtsTest.stick(constitutive_models=','.join(xml_blocks.get('Constitutive', []))) + AtsTest.stick(constitutive_models=','.join( + xml_blocks.get('Constitutive', []))) TestCase(name=testcase_name, desc=deck.description, label="auto", diff --git a/geos_ats_package/geos_ats/test_case.py b/geos_ats_package/geos_ats/test_case.py index fa317d3..89b055c 100644 --- a/geos_ats_package/geos_ats/test_case.py +++ b/geos_ats_package/geos_ats/test_case.py @@ -1,4 +1,4 @@ -import ats # type: ignore[import] +import ats # type: ignore[import] import os import sys import shutil @@ -38,8 +38,8 @@ def __init__(self, enabled=True, duration="1h", ppn=0, altname=None): logger.error(e) Error("bad time specification: %s" % duration) - self.ppn = ppn # processor per node - self.altname = altname # alternate name to use when launcing the batch job + 
self.ppn = ppn # processor per node + self.altname = altname # alternate name to use when launcing the batch job class TestCase(object): @@ -55,7 +55,14 @@ def __init__(self, name, desc, label=None, labels=None, steps=[], **kw): Log(str(e)) raise Exception(e) - def initialize(self, name, desc, label=None, labels=None, steps=[], batch=Batch(enabled=False), **kw): + def initialize(self, + name, + desc, + label=None, + labels=None, + steps=[], + batch=Batch(enabled=False), + **kw): self.name = name self.desc = desc @@ -89,7 +96,9 @@ def initialize(self, name, desc, label=None, labels=None, steps=[], batch=Batch( self.errname = os.path.join(self.path, "%s.err" % self.name) self.dictionary["name"] = self.name self.dictionary["output_directory"] = self.path - self.dictionary["baseline_dir"] = os.path.join(os.getcwd(), BASELINE_PATH, self.dirname) + self.dictionary["baseline_dir"] = os.path.join(os.getcwd(), + BASELINE_PATH, + self.dirname) self.dictionary["testcase_out"] = self.outname self.dictionary["testcase_err"] = self.errname self.dictionary["testcase_name"] = self.name @@ -119,7 +128,8 @@ def initialize(self, name, desc, label=None, labels=None, steps=[], batch=Batch( else: self.independent = self.dictionary.get("independent", False) if self.independent not in (True, False): - Error("independent must be either True or False: %s" % str(self.independent)) + Error("independent must be either True or False: %s" % + str(self.independent)) # check for depends self.depends = self.dictionary.get("depends", None) @@ -152,7 +162,8 @@ def initialize(self, name, desc, label=None, labels=None, steps=[], batch=Batch( extraStep.insertStep(newSteps) self.steps = newSteps else: - Log("# SKIP test=%s : testmodifier=%s" % (self.name, config.testmodifier)) + Log("# SKIP test=%s : testmodifier=%s" % + (self.name, config.testmodifier)) self.status = reporting.SKIP return @@ -166,7 +177,8 @@ def initialize(self, name, desc, label=None, labels=None, steps=[], batch=Batch( npMax = self.findMaxNumberOfProcessors() if config.filter_maxprocessors != -1: if npMax > config.filter_maxprocessors: - Log("# FILTER test=%s : max processors(%d > %d)" % (self.name, npMax, config.filter_maxprocessors)) + Log("# FILTER test=%s : max processors(%d > %d)" % + (self.name, npMax, config.filter_maxprocessors)) self.status = reporting.FILTERED return @@ -176,22 +188,26 @@ def initialize(self, name, desc, label=None, labels=None, steps=[], batch=Batch( # filter based on not enough resources if action in ("run", "rerun", "continue"): tests = [ - not ats.tests.AtsTest.getOptions().get("testmode"), not self.batch.enabled, + not ats.tests.AtsTest.getOptions().get("testmode"), + not self.batch.enabled, hasattr(ats.manager.machine, "getNumberOfProcessors") ] if all(tests): - totalNumberOfProcessors = getattr(ats.manager.machine, "getNumberOfProcessors")() + totalNumberOfProcessors = getattr(ats.manager.machine, + "getNumberOfProcessors")() if npMax > totalNumberOfProcessors: - Log("# SKIP test=%s : not enough processors to run (%d > %d)" % - (self.name, npMax, totalNumberOfProcessors)) + Log("# SKIP test=%s : not enough processors to run (%d > %d)" + % (self.name, npMax, totalNumberOfProcessors)) self.status = reporting.SKIP return # If the machine doesn't specify a number of GPUs then it has none. 
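As a rough, standalone illustration of the capability probe used in this hunk (not part of the patch itself), the sketch below looks up optional machine methods with getattr and a fallback, so a machine that reports no GPUs never trips the GPU filter; FakeMachine is a made-up stand-in for an ATS machine object.

class FakeMachine:

    def getNumberOfProcessors(self):
        return 36
    # Deliberately no getNumberOfGPUS method, so the fallback below is used.


def should_skip(machine, np_max, ngpu_max):
    total_procs = getattr(machine, "getNumberOfProcessors", lambda: 1)()
    # A machine that does not report GPUs is treated as having "unlimited" GPUs.
    total_gpus = getattr(machine, "getNumberOfGPUS", lambda: 1e90)()
    return np_max > total_procs or ngpu_max > total_gpus


assert should_skip(FakeMachine(), np_max=48, ngpu_max=0)
assert not should_skip(FakeMachine(), np_max=8, ngpu_max=4)
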
- totalNumberOfGPUs = getattr(ats.manager.machine, "getNumberOfGPUS", lambda: 1e90)() + totalNumberOfGPUs = getattr(ats.manager.machine, + "getNumberOfGPUS", lambda: 1e90)() if ngpuMax > totalNumberOfGPUs: - Log("# SKIP test=%s : not enough gpus to run (%d > %d)" % (self.name, ngpuMax, totalNumberOfGPUs)) + Log("# SKIP test=%s : not enough gpus to run (%d > %d)" % + (self.name, ngpuMax, totalNumberOfGPUs)) self.status = reporting.SKIP return @@ -199,7 +215,9 @@ def initialize(self, name, desc, label=None, labels=None, steps=[], batch=Batch( if action in ("run", "rerun", "continue"): checkoption = ats.tests.AtsTest.getOptions().get("checkoption") if checkoption == "none": - self.steps = [step for step in self.steps if not step.isCheck()] + self.steps = [ + step for step in self.steps if not step.isCheck() + ] elif action == "check": self.steps = [step for step in self.steps if step.isCheck()] @@ -312,7 +330,7 @@ def _remove(path): else: os.remove(p) except OSError: - pass # so that two simultaneous clean operations don't fail + pass # so that two simultaneous clean operations don't fail # clean self.testClean() @@ -381,8 +399,8 @@ def testCreate(self): if self.depends: priorTestCase = TESTS.get(self.depends, None) if priorTestCase is None: - Log("Warning: Test %s depends on testcase %s, which is not scheduled to run" % - (self.name, self.depends)) + Log("Warning: Test %s depends on testcase %s, which is not scheduled to run" + % (self.name, self.depends)) else: if priorTestCase.steps: atsTest = getattr(priorTestCase.steps[-1], "atsTest", None) @@ -395,7 +413,8 @@ def testCreate(self): args = step.makeArgs() # set the label - label = "%s/%s_%d_%s" % (self.dirname, self.name, stepnum + 1, step.label()) + label = "%s/%s_%d_%s" % (self.dirname, self.name, stepnum + 1, + step.label()) # call either 'test' or 'testif' if atsTest is None: @@ -419,7 +438,8 @@ def testCreate(self): np=np, ngpu=ngpu, label=label, - serial=(not step.useMPI() and not config.script_launch), + serial=(not step.useMPI() + and not config.script_launch), independent=self.independent, batch=self.batch.enabled, **kw) @@ -435,7 +455,8 @@ def testCreate(self): self.status.addStep(atsTest) # set the expected result - if step.expectedResult() == "FAIL" or step.expectedResult() is False: + if step.expectedResult() == "FAIL" or step.expectedResult( + ) is False: atsTest.expectedResult = ats.FAILED # The ATS does not permit tests to depend on failed tests. # therefore we need to break here @@ -479,9 +500,13 @@ def testRebaseline(self): if config.rebaseline_ask: while 1: if config.rebaseline_undo: - logger.info(f"Are you sure you want to undo the rebaseline for TestCase '{self.name}'?", flush=True) + logger.info( + f"Are you sure you want to undo the rebaseline for TestCase '{self.name}'?", + flush=True) else: - logger.info(f"Are you sure you want to rebaseline TestCase '{self.name}'?", flush=True) + logger.info( + f"Are you sure you want to rebaseline TestCase '{self.name}'?", + flush=True) x = input('[y/n] ') x = x.strip() @@ -504,7 +529,8 @@ def testRebaselineFailed(self): self.testRebaseline() def testList(self): - Log("# test=%s : labels=%s" % (self.name.ljust(32), " ".join(self.labels))) + Log("# test=%s : labels=%s" % + (self.name.ljust(32), " ".join(self.labels))) def testReport(self): self.status = test_caseStatus(self) @@ -523,7 +549,8 @@ def handleLabels(self, label, labels): for x in self.labels: if x not in testLabels: - Error(f"unknown label {x}. run 'geos_ats -i labels' for a list") + Error( + f"unknown label {x}. 
run 'geos_ats -i labels' for a list") class test_caseStatus(object): @@ -549,7 +576,8 @@ def writeStatusFile(self): def testKey(self, step): np = getattr(step.p, "np", 1) - key = str((np, step.label(), step.executable(), step.makeArgsForStatusKey())) + key = str( + (np, step.label(), step.executable(), step.makeArgsForStatusKey())) return key def testData(self, test): @@ -708,16 +736,25 @@ def infoTestCase(*args): table = TextTable(3) table.addRow("name", "required", "The name of the test problem") table.addRow("desc", "required", "A brief description") - table.addRow("label", "required", "A string or sequence of strings to tag the TestCase. See info topic 'labels'") - table.addRow("owner", "optional", - "A string or sequence of strings of test owners for this TestCase. See info topic 'owners'") table.addRow( - "batch", "optional", "A Batch object. Batch(enabled=True, duration='1h', ppn=0, altname=None)." + "label", "required", + "A string or sequence of strings to tag the TestCase. See info topic 'labels'" + ) + table.addRow( + "owner", "optional", + "A string or sequence of strings of test owners for this TestCase. See info topic 'owners'" + ) + table.addRow( + "batch", "optional", + "A Batch object. Batch(enabled=True, duration='1h', ppn=0, altname=None)." " ppn is short for processors per node (0 means to use the global default)." " altname will be used for the batch job's name if supplied, otherwise the full name of the test case is used." ), - table.addRow("depends", "optional", "The name of a testcase that this testcase depends") - table.addRow("steps", "required", "A sequence of TestSteps objects. See info topic 'teststeps'") + table.addRow("depends", "optional", + "The name of a testcase that this testcase depends") + table.addRow( + "steps", "required", + "A sequence of TestSteps objects. See info topic 'teststeps'") table.printTable() diff --git a/geos_ats_package/geos_ats/test_modifier.py b/geos_ats_package/geos_ats/test_modifier.py index c6d7cb2..b625be4 100644 --- a/geos_ats_package/geos_ats/test_modifier.py +++ b/geos_ats_package/geos_ats/test_modifier.py @@ -1,4 +1,4 @@ -import ats # type: ignore[import] +import ats # type: ignore[import] from geos_ats import common_utilities from geos_ats.configuration_record import config from geos_ats import test_steps diff --git a/geos_ats_package/geos_ats/test_steps.py b/geos_ats_package/geos_ats/test_steps.py index 129e162..e62a08a 100644 --- a/geos_ats_package/geos_ats/test_steps.py +++ b/geos_ats_package/geos_ats/test_steps.py @@ -1,5 +1,5 @@ import os -import ats # type: ignore[import] +import ats # type: ignore[import] import glob import shutil import sys @@ -89,37 +89,56 @@ class TestStepBase(object): " clean may be a string or a list of strings. The strings may contain" " wildcard characters."), TestParam( - "timelimit", "maximum time the step is allowed to run before it is considerend a TIMEOUT." + "timelimit", + "maximum time the step is allowed to run before it is considerend a TIMEOUT." 
" Specified as a string such as: 1h30m, 60m, etc.", "None"), - TestParam("stdout", "If set, the stdout will be placed in the named file, in the TestCase directory", None), - TestParam("stderr", "If set, the stderr will be placed in the named file, in the TestCase directory", None), + TestParam( + "stdout", + "If set, the stdout will be placed in the named file, in the TestCase directory", + None), + TestParam( + "stderr", + "If set, the stderr will be placed in the named file, in the TestCase directory", + None), TestParam("expectedResult", "'PASS' or 'FAIL'", "'PASS'"), - TestParam("delayed", "Whether execution of the step will be delayed", "False"), - TestParam("minor", "Whether failure of this step is minor issue", "False"), + TestParam("delayed", "Whether execution of the step will be delayed", + "False"), + TestParam("minor", "Whether failure of this step is minor issue", + "False"), ) commonParams = { "name": - TestParam("name", "Used to give other params default values.", "The name of the TestCase"), + TestParam("name", "Used to give other params default values.", + "The name of the TestCase"), "deck": - TestParam("deck", "Name of the input file. Setting deck to False means no deck is used.", ".in"), + TestParam( + "deck", + "Name of the input file. Setting deck to False means no deck is used.", + ".in"), "np": TestParam("np", "The number of processors to run on.", 1), "ngpu": TestParam("ngpu", "The number of gpus to run on when available.", 0), "check": TestParam( - "check", "True or False. determines whether the default checksteps will " + "check", + "True or False. determines whether the default checksteps will " "be automatically be added after this step.", "True"), "baseline_dir": - TestParam("baseline_dir", "subdirectory of config.testbaseline_dir where the test " - "baselines are located.", ""), + TestParam( + "baseline_dir", + "subdirectory of config.testbaseline_dir where the test " + "baselines are located.", ""), "output_directory": - TestParam("output_directory", "subdirectory where the test log, params, rin, and " - "timehistory files are located.", ""), + TestParam( + "output_directory", + "subdirectory where the test log, params, rin, and " + "timehistory files are located.", ""), "rebaseline": TestParam( - "rebaseline", "additional files to rebaseline during the rebaseline action." + "rebaseline", + "additional files to rebaseline during the rebaseline action." " rebaseline may be a string or a list of strings."), "timehistfile": TestParam("timehistfile", "name of the file containing all the" @@ -129,12 +148,14 @@ class TestStepBase(object): "//"), "allow_rebaseline": TestParam( - "allow_rebaseline", "True if the second file should be re-baselined during a rebaseline action." + "allow_rebaseline", + "True if the second file should be re-baselined during a rebaseline action." " False if the second file should not be rebaselined.", "True"), "testcase_name": TestParam("testcase_name", "The name of the testcase"), "testcase_out": - TestParam("testcase_out", "The file where stdout for the testcase is accumulated"), + TestParam("testcase_out", + "The file where stdout for the testcase is accumulated"), } # namespace to place the params. 
@@ -263,7 +284,9 @@ def _remove(self, paths, noclean): else: os.remove(p) except OSError as e: - logger.debug(e) # so that two simultaneous clean operations don't fail + logger.debug( + e + ) # so that two simultaneous clean operations don't fail def getCheckOption(self): return ats.tests.AtsTest.getOptions().get("checkoption") @@ -384,15 +407,21 @@ class geos(TestStepBase): command = "geosx [-i ] [-r ] [-x ] [-y ] [-z ] [-s ] [-n ] [-o ] [ --suppress-pinned ] " params = TestStepBase.defaultParams + ( - TestStepBase.commonParams["name"], TestStepBase.commonParams["deck"], TestStepBase.commonParams["np"], - TestStepBase.commonParams["ngpu"], TestStepBase.commonParams["check"], - TestStepBase.commonParams["baseline_dir"], TestStepBase.commonParams["output_directory"], + TestStepBase.commonParams["name"], TestStepBase.commonParams["deck"], + TestStepBase.commonParams["np"], TestStepBase.commonParams["ngpu"], + TestStepBase.commonParams["check"], + TestStepBase.commonParams["baseline_dir"], + TestStepBase.commonParams["output_directory"], TestParam("restart_file", "The name of the restart file."), - TestParam("x_partitions", "The number of partitions in the x direction."), - TestParam("y_partitions", "The number of partitions in the y direction."), + TestParam("x_partitions", + "The number of partitions in the x direction."), + TestParam("y_partitions", + "The number of partitions in the y direction."), TestParam("z_partitions", - "The number of partitions in the z direction."), TestParam("schema_level", "The schema level."), - TestParam("suppress-pinned", "Option to suppress use of pinned memory for MPI buffers."), + "The number of partitions in the z direction."), + TestParam("schema_level", "The schema level."), + TestParam("suppress-pinned", + "Option to suppress use of pinned memory for MPI buffers."), TestParam("trace_data_migration", "Trace host-device data migration.")) checkstepnames = ["restartcheck"] @@ -521,17 +550,29 @@ class restartcheck(CheckTestStepBase): command = """restartcheck [-r RELATIVE] [-a ABSOLUTE] [-o OUTPUT] [-e EXCLUDE [EXCLUDE ...]] [-w] file_pattern baseline_pattern""" params = TestStepBase.defaultParams + CheckTestStepBase.checkParams + ( - TestStepBase.commonParams["deck"], TestStepBase.commonParams["name"], TestStepBase.commonParams["np"], - TestStepBase.commonParams["allow_rebaseline"], TestStepBase.commonParams["baseline_dir"], + TestStepBase.commonParams["deck"], TestStepBase.commonParams["name"], + TestStepBase.commonParams["np"], + TestStepBase.commonParams["allow_rebaseline"], + TestStepBase.commonParams["baseline_dir"], TestStepBase.commonParams["output_directory"], - TestParam("file_pattern", "Regex pattern to match file written out by geos."), - TestParam("baseline_pattern", "Regex pattern to match file to compare against."), - TestParam("rtol", - "Relative tolerance, default is 0.0."), TestParam("atol", "Absolute tolerance, default is 0.0."), - TestParam("exclude", "Regular expressions matching groups to exclude from the check, default is None."), - TestParam("warnings_are_errors", "Treat warnings as errors, default is True."), - TestParam("suppress_output", "Whether to write output to stdout, default is True."), - TestParam("skip_missing", "Whether to skip missing values in target or baseline files, default is False.")) + TestParam("file_pattern", + "Regex pattern to match file written out by geos."), + TestParam("baseline_pattern", + "Regex pattern to match file to compare against."), + TestParam("rtol", "Relative tolerance, default is 0.0."), + 
TestParam("atol", "Absolute tolerance, default is 0.0."), + TestParam( + "exclude", + "Regular expressions matching groups to exclude from the check, default is None." + ), + TestParam("warnings_are_errors", + "Treat warnings as errors, default is True."), + TestParam("suppress_output", + "Whether to write output to stdout, default is True."), + TestParam( + "skip_missing", + "Whether to skip missing values in target or baseline files, default is False." + )) def __init__(self, restartcheck_params, **kw): """ @@ -564,12 +605,15 @@ def update(self, dictionary): self.requireParam("output_directory") if self.p.file_pattern is None: - self.p.file_pattern = getGeosProblemName(self.p.deck, self.p.name) + r"_restart_[0-9]+\.root" + self.p.file_pattern = getGeosProblemName( + self.p.deck, self.p.name) + r"_restart_[0-9]+\.root" if self.p.baseline_pattern is None: self.p.baseline_pattern = self.p.file_pattern - self.restart_file_regex = os.path.join(self.p.output_directory, self.p.file_pattern) - self.restart_baseline_regex = os.path.join(self.p.baseline_dir, self.p.baseline_pattern) + self.restart_file_regex = os.path.join(self.p.output_directory, + self.p.file_pattern) + self.restart_baseline_regex = os.path.join(self.p.baseline_dir, + self.p.baseline_pattern) if self.p.allow_rebaseline is None: self.p.allow_rebaseline = True @@ -610,8 +654,9 @@ def rebaseline(self): root_file_path = findMaxMatchingFile(self.restart_file_regex) if root_file_path is None: - raise IOError("File not found matching the pattern %s in directory %s." % - (self.restart_file_regex, os.getcwd())) + raise IOError( + "File not found matching the pattern %s in directory %s." % + (self.restart_file_regex, os.getcwd())) baseline_dir = os.path.dirname(self.restart_baseline_regex) root_baseline_path = findMaxMatchingFile(self.restart_baseline_regex) @@ -626,13 +671,21 @@ def rebaseline(self): os.makedirs(baseline_dir, exist_ok=True) # Copy the root file into the baseline directory. - shutil.copy2(root_file_path, os.path.join(baseline_dir, os.path.basename(root_file_path))) + shutil.copy2( + root_file_path, + os.path.join(baseline_dir, os.path.basename(root_file_path))) # Copy the directory holding the data files into the baseline directory. 
data_dir_path = os.path.splitext(root_file_path)[0] - shutil.copytree(data_dir_path, os.path.join(baseline_dir, os.path.basename(data_dir_path))) + shutil.copytree( + data_dir_path, + os.path.join(baseline_dir, os.path.basename(data_dir_path))) def resultPaths(self): - return [os.path.join(self.p.output_directory, "%s.restartcheck" % os.path.splitext(self.p.file_pattern)[0])] + return [ + os.path.join( + self.p.output_directory, + "%s.restartcheck" % os.path.splitext(self.p.file_pattern)[0]) + ] def clean(self): self._clean(self.resultPaths()) @@ -651,16 +704,22 @@ class curvecheck(CheckTestStepBase): command = """curve_check.py [-h] [-c CURVE [CURVE ...]] [-t TOLERANCE] [-w] [-o OUTPUT] [-n N_COLUMN] [-u {milliseconds,seconds,minutes,hours,days,years}] filename baseline""" params = TestStepBase.defaultParams + CheckTestStepBase.checkParams + ( - TestStepBase.commonParams["deck"], TestStepBase.commonParams["name"], TestStepBase.commonParams["np"], - TestStepBase.commonParams["allow_rebaseline"], TestStepBase.commonParams["baseline_dir"], + TestStepBase.commonParams["deck"], TestStepBase.commonParams["name"], + TestStepBase.commonParams["np"], + TestStepBase.commonParams["allow_rebaseline"], + TestStepBase.commonParams["baseline_dir"], TestStepBase.commonParams["output_directory"], - TestParam("filename", "Name of the target curve file written by GEOS."), + TestParam("filename", + "Name of the target curve file written by GEOS."), TestParam("curves", "A list of parameter, setname value pairs."), TestParam( "tolerance", "Curve check tolerance (||x-y||/N), can be specified as a single value or a list of floats corresponding to the curves." - ), TestParam("warnings_are_errors", "Treat warnings as errors, default is True."), - TestParam("script_instructions", "A list of (path, function, value, setname) entries"), + ), + TestParam("warnings_are_errors", + "Treat warnings as errors, default is True."), + TestParam("script_instructions", + "A list of (path, function, value, setname) entries"), TestParam("time_units", "Time units to use for plots.")) def __init__(self, curvecheck_params, **kw): @@ -711,7 +770,8 @@ def update(self, dictionary): self.requireParam("output_directory") self.baseline_file = os.path.join(self.p.baseline_dir, self.p.filename) - self.target_file = os.path.join(self.p.output_directory, self.p.filename) + self.target_file = os.path.join(self.p.output_directory, + self.p.filename) self.figure_root = os.path.join(self.p.output_directory, 'curve_check') if self.p.allow_rebaseline is None: @@ -865,7 +925,9 @@ def all(): for s in steps: stepclass = globals()[s] doc = getattr(stepclass, "doc", None) - topic.addTopic(s, textwrap.dedent(doc).strip(), lambda ss=s: infoTestStep(ss)) + topic.addTopic(s, + textwrap.dedent(doc).strip(), + lambda ss=s: infoTestStep(ss)) topic.process(args) diff --git a/geos_ats_package/geos_ats/user_utilities.py b/geos_ats_package/geos_ats/user_utilities.py index 775e3a1..24ebd89 100644 --- a/geos_ats_package/geos_ats/user_utilities.py +++ b/geos_ats_package/geos_ats/user_utilities.py @@ -1,4 +1,4 @@ -import ats # type: ignore[import] +import ats # type: ignore[import] import os diff --git a/geosx_mesh_doctor/checks/check_fractures.py b/geosx_mesh_doctor/checks/check_fractures.py index b2c241b..30d83d3 100644 --- a/geosx_mesh_doctor/checks/check_fractures.py +++ b/geosx_mesh_doctor/checks/check_fractures.py @@ -18,17 +18,13 @@ vtkCell, ) from vtkmodules.vtkCommonCore import ( - vtkPoints, -) + vtkPoints, ) from vtkmodules.vtkIOXML import ( - 
vtkXMLMultiBlockDataReader, -) + vtkXMLMultiBlockDataReader, ) from vtkmodules.util.numpy_support import ( - vtk_to_numpy, -) + vtk_to_numpy, ) from vtk_utils import ( - vtk_iter, -) + vtk_iter, ) @dataclass(frozen=True) @@ -47,7 +43,9 @@ class Result: errors: Sequence[tuple[int, int, int]] -def __read_multiblock(vtk_input_file: str, matrix_name: str, fracture_name: str) -> Tuple[vtkUnstructuredGrid, vtkUnstructuredGrid]: +def __read_multiblock( + vtk_input_file: str, matrix_name: str, + fracture_name: str) -> Tuple[vtkUnstructuredGrid, vtkUnstructuredGrid]: reader = vtkXMLMultiBlockDataReader() reader.SetFileName(vtk_input_file) reader.Update() @@ -62,30 +60,39 @@ def __read_multiblock(vtk_input_file: str, matrix_name: str, fracture_name: str) return matrix, fracture -def format_collocated_nodes(fracture_mesh: vtkUnstructuredGrid) -> Sequence[Iterable[int]]: +def format_collocated_nodes( + fracture_mesh: vtkUnstructuredGrid) -> Sequence[Iterable[int]]: """ Extract the collocated nodes information from the mesh and formats it in a python way. :param fracture_mesh: The mesh of the fracture (with 2d cells). :return: An iterable over all the buckets of collocated nodes. """ - collocated_nodes: numpy.ndarray = vtk_to_numpy(fracture_mesh.GetPointData().GetArray("collocated_nodes")) + collocated_nodes: numpy.ndarray = vtk_to_numpy( + fracture_mesh.GetPointData().GetArray("collocated_nodes")) if len(collocated_nodes.shape) == 1: - collocated_nodes: numpy.ndarray = collocated_nodes.reshape((collocated_nodes.shape[0], 1)) - generator = (tuple(sorted(bucket[bucket > -1])) for bucket in collocated_nodes) + collocated_nodes: numpy.ndarray = collocated_nodes.reshape( + (collocated_nodes.shape[0], 1)) + generator = (tuple(sorted(bucket[bucket > -1])) + for bucket in collocated_nodes) return tuple(generator) -def __check_collocated_nodes_positions(matrix_points: Sequence[Tuple[float, float, float]], - fracture_points: Sequence[Tuple[float, float, float]], - g2l: Sequence[int], - collocated_nodes: Iterable[Iterable[int]]) -> Collection[Tuple[int, Iterable[int], Iterable[Tuple[float, float, float]]]]: +def __check_collocated_nodes_positions( + matrix_points: Sequence[Tuple[float, float, float]], + fracture_points: Sequence[Tuple[float, float, float]], g2l: Sequence[int], + collocated_nodes: Iterable[Iterable[int]] +) -> Collection[Tuple[int, Iterable[int], Iterable[Tuple[float, float, + float]]]]: issues = [] for li, bucket in enumerate(collocated_nodes): - matrix_nodes = (fracture_points[li], ) + tuple(map(lambda gi: matrix_points[g2l[gi]], bucket)) + matrix_nodes = (fracture_points[li], ) + tuple( + map(lambda gi: matrix_points[g2l[gi]], bucket)) m = numpy.array(matrix_nodes) rank: int = numpy.linalg.matrix_rank(m) if rank > 1: - issues.append((li, bucket, tuple(map(lambda gi: matrix_points[g2l[gi]], bucket)))) + issues.append((li, bucket, + tuple(map(lambda gi: matrix_points[g2l[gi]], + bucket)))) return issues @@ -100,8 +107,7 @@ def my_iter(ccc): def __check_neighbors(matrix: vtkUnstructuredGrid, - fracture: vtkUnstructuredGrid, - g2l: Sequence[int], + fracture: vtkUnstructuredGrid, g2l: Sequence[int], collocated_nodes: Sequence[Iterable[int]]): fracture_nodes: Set[int] = set() for bucket in collocated_nodes: @@ -119,7 +125,8 @@ def __check_neighbors(matrix: vtkUnstructuredGrid, if point_ids <= fracture_nodes: fracture_faces.add(point_ids) # Finding the cells - for c in tqdm(range(fracture.GetNumberOfCells()), desc="Finding neighbor cell pairs"): + for c in tqdm(range(fracture.GetNumberOfCells()), + 
desc="Finding neighbor cell pairs"): cell: vtkCell = fracture.GetCell(c) cns: Set[FrozenSet[int]] = set() # subset of collocated_nodes point_ids = frozenset(vtk_iter(cell.GetPointIds())) @@ -134,15 +141,19 @@ def __check_neighbors(matrix: vtkUnstructuredGrid, if f in fracture_faces: found += 1 if found != 2: - logging.warning(f"Something went wrong since we should have found 2 fractures faces (we found {found}) for collocated nodes {cns}.") + logging.warning( + f"Something went wrong since we should have found 2 fractures faces (we found {found}) for collocated nodes {cns}." + ) def __check(vtk_input_file: str, options: Options) -> Result: - matrix, fracture = __read_multiblock(vtk_input_file, options.matrix_name, options.fracture_name) + matrix, fracture = __read_multiblock(vtk_input_file, options.matrix_name, + options.fracture_name) matrix_points: vtkPoints = matrix.GetPoints() fracture_points: vtkPoints = fracture.GetPoints() - collocated_nodes: Sequence[Iterable[int]] = format_collocated_nodes(fracture) + collocated_nodes: Sequence[Iterable[int]] = format_collocated_nodes( + fracture) assert matrix.GetPointData().GetGlobalIds() and matrix.GetCellData().GetGlobalIds() and \ fracture.GetPointData().GetGlobalIds() and fracture.GetCellData().GetGlobalIds() @@ -152,10 +163,9 @@ def __check(vtk_input_file: str, options: Options) -> Result: g2l[glo] = loc g2l.flags.writeable = False - issues = __check_collocated_nodes_positions(vtk_to_numpy(matrix.GetPoints().GetData()), - vtk_to_numpy(fracture.GetPoints().GetData()), - g2l, - collocated_nodes) + issues = __check_collocated_nodes_positions( + vtk_to_numpy(matrix.GetPoints().GetData()), + vtk_to_numpy(fracture.GetPoints().GetData()), g2l, collocated_nodes) assert len(issues) == 0 __check_neighbors(matrix, fracture, g2l, collocated_nodes) @@ -165,7 +175,8 @@ def __check(vtk_input_file: str, options: Options) -> Result: for duplicate in filter(lambda i: i > -1, duplicates): p0 = matrix_points.GetPoint(g2l[duplicate]) p1 = fracture_points.GetPoint(i) - if numpy.linalg.norm(numpy.array(p1) - numpy.array(p0)) > options.tolerance: + if numpy.linalg.norm(numpy.array(p1) - + numpy.array(p0)) > options.tolerance: errors.append((i, g2l[duplicate], duplicate)) return Result(errors=errors) diff --git a/geosx_mesh_doctor/checks/collocated_nodes.py b/geosx_mesh_doctor/checks/collocated_nodes.py index 7a5273e..6bcbb62 100644 --- a/geosx_mesh_doctor/checks/collocated_nodes.py +++ b/geosx_mesh_doctor/checks/collocated_nodes.py @@ -24,8 +24,10 @@ class Options: @dataclass(frozen=True) class Result: - nodes_buckets: Iterable[Iterable[int]] # Each bucket contains the duplicated node indices. - wrong_support_elements: Collection[int] # Element indices with support node indices appearing more than once. + nodes_buckets: Iterable[ + Iterable[int]] # Each bucket contains the duplicated node indices. + wrong_support_elements: Collection[ + int] # Element indices with support node indices appearing more than once. def __check(mesh, options: Options) -> Result: @@ -37,7 +39,8 @@ def __check(mesh, options: Options) -> Result: locator.InitPointInsertion(output, points.GetBounds()) # original ids to/from filtered ids. 
- filtered_to_original = numpy.ones(points.GetNumberOfPoints(), dtype=int) * -1 + filtered_to_original = numpy.ones(points.GetNumberOfPoints(), + dtype=int) * -1 rejected_points = defaultdict(list) point_id = reference(0) @@ -66,7 +69,8 @@ def __check(mesh, options: Options) -> Result: for c in range(mesh.GetNumberOfCells()): cell = mesh.GetCell(c) num_points_per_cell = cell.GetNumberOfPoints() - if len({cell.GetPointId(i) for i in range(num_points_per_cell)}) != num_points_per_cell: + if len({cell.GetPointId(i) + for i in range(num_points_per_cell)}) != num_points_per_cell: wrong_support_elements.append(c) return Result(nodes_buckets=tmp, diff --git a/geosx_mesh_doctor/checks/element_volumes.py b/geosx_mesh_doctor/checks/element_volumes.py index 4dfd917..c25c38a 100644 --- a/geosx_mesh_doctor/checks/element_volumes.py +++ b/geosx_mesh_doctor/checks/element_volumes.py @@ -14,9 +14,7 @@ vtkMeshQuality, ) from vtkmodules.util.numpy_support import ( - vtk_to_numpy, -) - + vtk_to_numpy, ) from . import vtk_utils @@ -39,7 +37,8 @@ def __check(mesh, options: Options) -> Result: cs.ComputeSumOff() cs.ComputeVertexCountOff() cs.ComputeVolumeOn() - volume_array_name = "__MESH_DOCTOR_VOLUME-" + str(uuid.uuid4()) # Making the name unique + volume_array_name = "__MESH_DOCTOR_VOLUME-" + str( + uuid.uuid4()) # Making the name unique cs.SetVolumeArrayName(volume_array_name) cs.SetInputData(mesh) @@ -50,19 +49,23 @@ def __check(mesh, options: Options) -> Result: mq.SetTetQualityMeasureToVolume() mq.SetHexQualityMeasureToVolume() - if hasattr(mq, "SetPyramidQualityMeasureToVolume"): # This feature is quite recent + if hasattr(mq, "SetPyramidQualityMeasureToVolume" + ): # This feature is quite recent mq.SetPyramidQualityMeasureToVolume() SUPPORTED_TYPES.append(VTK_PYRAMID) mq.SetWedgeQualityMeasureToVolume() SUPPORTED_TYPES.append(VTK_WEDGE) else: - logging.warning("Your \"pyvtk\" version does not bring pyramid nor wedge support with vtkMeshQuality. Using the fallback solution.") + logging.warning( + "Your \"pyvtk\" version does not bring pyramid nor wedge support with vtkMeshQuality. Using the fallback solution." + ) mq.SetInputData(mesh) mq.Update() volume = cs.GetOutput().GetCellData().GetArray(volume_array_name) - quality = mq.GetOutput().GetCellData().GetArray("Quality") # Name is imposed by vtk. + quality = mq.GetOutput().GetCellData().GetArray( + "Quality") # Name is imposed by vtk. assert volume is not None assert quality is not None diff --git a/geosx_mesh_doctor/checks/fix_elements_orderings.py b/geosx_mesh_doctor/checks/fix_elements_orderings.py index 61dd034..2ed9e69 100644 --- a/geosx_mesh_doctor/checks/fix_elements_orderings.py +++ b/geosx_mesh_doctor/checks/fix_elements_orderings.py @@ -8,8 +8,7 @@ ) from vtkmodules.vtkCommonCore import ( - vtkIdList, -) + vtkIdList, ) from . 
import vtk_utils from .vtk_utils import ( @@ -52,13 +51,16 @@ def __check(mesh, options: Options) -> Result: cells.GetCellAtId(cell_idx, support_point_ids) new_support_point_ids = [] for i, v in enumerate(new_ordering): - new_support_point_ids.append(support_point_ids.GetId(new_ordering[i])) - cells.ReplaceCellAtId(cell_idx, to_vtk_id_list(new_support_point_ids)) + new_support_point_ids.append( + support_point_ids.GetId(new_ordering[i])) + cells.ReplaceCellAtId(cell_idx, + to_vtk_id_list(new_support_point_ids)) else: unchanged_cell_types.add(cell_type) is_written_error = vtk_utils.write_mesh(output_mesh, options.vtk_output) - return Result(output=options.vtk_output.output if not is_written_error else "", - unchanged_cell_types=frozenset(unchanged_cell_types)) + return Result( + output=options.vtk_output.output if not is_written_error else "", + unchanged_cell_types=frozenset(unchanged_cell_types)) def check(vtk_input_file: str, options: Options) -> Result: diff --git a/geosx_mesh_doctor/checks/generate_cube.py b/geosx_mesh_doctor/checks/generate_cube.py index f8625f5..765b371 100644 --- a/geosx_mesh_doctor/checks/generate_cube.py +++ b/geosx_mesh_doctor/checks/generate_cube.py @@ -5,8 +5,7 @@ import numpy from vtkmodules.vtkCommonCore import ( - vtkPoints, -) + vtkPoints, ) from vtkmodules.vtkCommonDataModel import ( VTK_HEXAHEDRON, vtkCellArray, @@ -15,13 +14,11 @@ vtkUnstructuredGrid, ) from vtkmodules.util.numpy_support import ( - numpy_to_vtk, -) + numpy_to_vtk, ) from . import vtk_utils from .vtk_utils import ( - VtkOutput, -) + VtkOutput, ) from .generate_global_ids import __build_global_ids @@ -87,7 +84,8 @@ def build_rectilinear_blocks_mesh(xyzs: Iterable[XYZ]) -> vtkUnstructuredGrid: cells = vtkCellArray() cells.AllocateExact(num_cells, num_cells * 8) - m = (0, 1, 3, 2, 4, 5, 7, 6) # VTK_VOXEL and VTK_HEXAHEDRON do not share the same ordering. + m = (0, 1, 3, 2, 4, 5, 7, 6 + ) # VTK_VOXEL and VTK_HEXAHEDRON do not share the same ordering. 
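The tuple m above re-orders a VTK_VOXEL's corner ids into VTK_HEXAHEDRON order; a tiny self-contained illustration of how such a permutation is applied to a cell's point list.

VOXEL_TO_HEX = (0, 1, 3, 2, 4, 5, 7, 6)


def voxel_to_hexahedron(voxel_point_ids):
    # Re-order the eight voxel corners into hexahedron ordering.
    return [voxel_point_ids[k] for k in VOXEL_TO_HEX]


assert voxel_to_hexahedron([10, 11, 12, 13, 14, 15, 16, 17]) == \
    [10, 11, 13, 12, 14, 15, 17, 16]
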
offset = 0 for rg in rgs: for i in range(rg.GetNumberOfCells()): @@ -105,7 +103,8 @@ def build_rectilinear_blocks_mesh(xyzs: Iterable[XYZ]) -> vtkUnstructuredGrid: return mesh -def __add_fields(mesh: vtkUnstructuredGrid, fields: Iterable[FieldInfo]) -> vtkUnstructuredGrid: +def __add_fields(mesh: vtkUnstructuredGrid, + fields: Iterable[FieldInfo]) -> vtkUnstructuredGrid: for field_info in fields: if field_info.support == "CELLS": data = mesh.GetCellData() @@ -121,6 +120,7 @@ def __add_fields(mesh: vtkUnstructuredGrid, fields: Iterable[FieldInfo]) -> vtkU def __build(options: Options): + def build_coordinates(positions, num_elements): result = [] it = zip(zip(positions, positions[1:]), num_elements) @@ -129,20 +129,28 @@ def build_coordinates(positions, num_elements): while True: start, stop = coords end_point = False - tmp = numpy.linspace(start=start, stop=stop, num=n+end_point, endpoint=end_point) + tmp = numpy.linspace(start=start, + stop=stop, + num=n + end_point, + endpoint=end_point) coords, n = next(it) result.append(tmp) except StopIteration: end_point = True - tmp = numpy.linspace(start=start, stop=stop, num=n+end_point, endpoint=end_point) + tmp = numpy.linspace(start=start, + stop=stop, + num=n + end_point, + endpoint=end_point) result.append(tmp) return numpy.concatenate(result) + x = build_coordinates(options.xs, options.nxs) y = build_coordinates(options.ys, options.nys) z = build_coordinates(options.zs, options.nzs) - cube = build_rectilinear_blocks_mesh((XYZ(x, y, z),)) + cube = build_rectilinear_blocks_mesh((XYZ(x, y, z), )) cube = __add_fields(cube, options.fields) - __build_global_ids(cube, options.generate_cells_global_ids, options.generate_points_global_ids) + __build_global_ids(cube, options.generate_cells_global_ids, + options.generate_points_global_ids) return cube diff --git a/geosx_mesh_doctor/checks/generate_fractures.py b/geosx_mesh_doctor/checks/generate_fractures.py index 22fbadc..44a0640 100644 --- a/geosx_mesh_doctor/checks/generate_fractures.py +++ b/geosx_mesh_doctor/checks/generate_fractures.py @@ -44,8 +44,7 @@ to_vtk_id_list, ) from .vtk_polyhedron import ( - FaceStream, -) + FaceStream, ) class FracturePolicy(Enum): @@ -69,21 +68,27 @@ class Result: @dataclass(frozen=True) class FractureInfo: - node_to_cells: Mapping[int, Iterable[int]] # For each _fracture_ node, gives all the cells that use this node. - face_nodes: Iterable[Collection[int]] # For each fracture face, returns the nodes of this face + node_to_cells: Mapping[int, Iterable[ + int]] # For each _fracture_ node, gives all the cells that use this node. + face_nodes: Iterable[Collection[ + int]] # For each fracture face, returns the nodes of this face -def build_node_to_cells(mesh: vtkUnstructuredGrid, - face_nodes: Iterable[Iterable[int]]) -> Mapping[int, Iterable[int]]: - node_to_cells: Dict[int, Set[int]] = defaultdict(set) # TODO normally, just a list and not a set should be enough. +def build_node_to_cells( + mesh: vtkUnstructuredGrid, + face_nodes: Iterable[Iterable[int]]) -> Mapping[int, Iterable[int]]: + node_to_cells: Dict[int, Set[int]] = defaultdict( + set) # TODO normally, just a list and not a set should be enough. 
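A condensed sketch of the node-to-cells mapping being built here, with a made-up cell_points dictionary standing in for the VTK cell queries done in the real function.

from collections import defaultdict


def node_to_cells(cell_points, fracture_nodes):
    mapping = defaultdict(set)
    for cell_id, points in cell_points.items():
        # Only nodes belonging to the fracture are kept in the mapping.
        for node in set(points) & fracture_nodes:
            mapping[node].add(cell_id)
    return mapping


assert node_to_cells({0: (1, 2, 3), 1: (2, 3, 4)}, {2, 4}) == {2: {0, 1}, 4: {1}}
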
fracture_nodes: Set[int] = set() for fns in face_nodes: for n in fns: fracture_nodes.add(n) - for cell_id in tqdm(range(mesh.GetNumberOfCells()), desc="Computing the node to cells mapping"): - cell_points: FrozenSet[int] = frozenset(vtk_iter(mesh.GetCell(cell_id).GetPointIds())) + for cell_id in tqdm(range(mesh.GetNumberOfCells()), + desc="Computing the node to cells mapping"): + cell_points: FrozenSet[int] = frozenset( + vtk_iter(mesh.GetCell(cell_id).GetPointIds())) intersection: Iterable[int] = cell_points & fracture_nodes for node in intersection: node_to_cells[node].add(cell_id) @@ -91,47 +96,60 @@ def build_node_to_cells(mesh: vtkUnstructuredGrid, return node_to_cells -def __build_fracture_info_from_fields(mesh: vtkUnstructuredGrid, - f: Sequence[int], - field_values: FrozenSet[int]) -> FractureInfo: +def __build_fracture_info_from_fields( + mesh: vtkUnstructuredGrid, f: Sequence[int], + field_values: FrozenSet[int]) -> FractureInfo: cells_to_faces: Dict[int, List[int]] = defaultdict(list) # For each face of each cell, we search for the unique neighbor cell (if it exists). # Then, if the 2 values of the two cells match the field requirements, # we store the cell and its local face index: this is indeed part of the surface that we'll need to be split. cell: vtkCell - for cell_id in tqdm(range(mesh.GetNumberOfCells()), desc="Computing the cell to faces mapping"): + for cell_id in tqdm(range(mesh.GetNumberOfCells()), + desc="Computing the cell to faces mapping"): if f[cell_id] not in field_values: # No need to consider a cell if its field value is not in the target range. continue cell = mesh.GetCell(cell_id) for i in range(cell.GetNumberOfFaces()): neighbor_cell_ids = vtkIdList() - mesh.GetCellNeighbors(cell_id, cell.GetFace(i).GetPointIds(), neighbor_cell_ids) + mesh.GetCellNeighbors(cell_id, + cell.GetFace(i).GetPointIds(), + neighbor_cell_ids) assert neighbor_cell_ids.GetNumberOfIds() < 2 - for j in range(neighbor_cell_ids.GetNumberOfIds()): # It's 0 or 1... + for j in range( + neighbor_cell_ids.GetNumberOfIds()): # It's 0 or 1... neighbor_cell_id = neighbor_cell_ids.GetId(j) - if f[neighbor_cell_id] != f[cell_id] and f[neighbor_cell_id] in field_values: - cells_to_faces[cell_id].append(i) # TODO add this (cell_is, face_id) information to the fracture_info? + if f[neighbor_cell_id] != f[cell_id] and f[ + neighbor_cell_id] in field_values: + cells_to_faces[cell_id].append( + i + ) # TODO add this (cell_is, face_id) information to the fracture_info? face_nodes: List[Collection[int]] = list() - face_nodes_hashes: Set[FrozenSet[int]] = set() # A temporary not to add multiple times the same face. - for cell_id, faces_ids in tqdm(cells_to_faces.items(), desc="Extracting the faces of the fractures"): + face_nodes_hashes: Set[FrozenSet[int]] = set( + ) # A temporary not to add multiple times the same face. 
+ for cell_id, faces_ids in tqdm( + cells_to_faces.items(), + desc="Extracting the faces of the fractures"): cell = mesh.GetCell(cell_id) for face_id in faces_ids: - fn: Collection[int] = tuple(vtk_iter(cell.GetFace(face_id).GetPointIds())) + fn: Collection[int] = tuple( + vtk_iter(cell.GetFace(face_id).GetPointIds())) fnh = frozenset(fn) if fnh not in face_nodes_hashes: face_nodes_hashes.add(fnh) face_nodes.append(fn) - node_to_cells: Mapping[int, Iterable[int]] = build_node_to_cells(mesh, face_nodes) + node_to_cells: Mapping[int, Iterable[int]] = build_node_to_cells( + mesh, face_nodes) return FractureInfo(node_to_cells=node_to_cells, face_nodes=face_nodes) -def __build_fracture_info_from_internal_surfaces(mesh: vtkUnstructuredGrid, - f: Sequence[int], - field_values: FrozenSet[int]) -> FractureInfo: +def __build_fracture_info_from_internal_surfaces( + mesh: vtkUnstructuredGrid, f: Sequence[int], + field_values: FrozenSet[int]) -> FractureInfo: node_to_cells: Dict[int, List[int]] = {} face_nodes: List[Collection[int]] = [] - for cell_id in tqdm(range(mesh.GetNumberOfCells()), desc="Computing the face to nodes mapping"): + for cell_id in tqdm(range(mesh.GetNumberOfCells()), + desc="Computing the face to nodes mapping"): cell = mesh.GetCell(cell_id) if cell.GetCellDimension() == 2: if f[cell_id] in field_values: @@ -142,7 +160,8 @@ def __build_fracture_info_from_internal_surfaces(mesh: vtkUnstructuredGrid, nodes.append(point_id) face_nodes.append(tuple(nodes)) - for cell_id in tqdm(range(mesh.GetNumberOfCells()), desc="Computing the node to cells mapping"): + for cell_id in tqdm(range(mesh.GetNumberOfCells()), + desc="Computing the node to cells mapping"): cell = mesh.GetCell(cell_id) if cell.GetCellDimension() == 3: for v in range(cell.GetNumberOfPoints()): @@ -160,12 +179,14 @@ def build_fracture_info(mesh: vtkUnstructuredGrid, if cell_data.HasArray(field): f = vtk_to_numpy(cell_data.GetArray(field)) else: - raise ValueError(f"Cell field {field} does not exist in mesh, nothing done") + raise ValueError( + f"Cell field {field} does not exist in mesh, nothing done") if options.policy == FracturePolicy.FIELD: return __build_fracture_info_from_fields(mesh, f, field_values) elif options.policy == FracturePolicy.INTERNAL_SURFACES: - return __build_fracture_info_from_internal_surfaces(mesh, f, field_values) + return __build_fracture_info_from_internal_surfaces( + mesh, f, field_values) def build_cell_to_cell_graph(mesh: vtkUnstructuredGrid, @@ -197,7 +218,8 @@ def build_cell_to_cell_graph(mesh: vtkUnstructuredGrid, for cell_id in tqdm(cells, desc="Computing the cell to cell graph"): cell: vtkCell = mesh.GetCell(cell_id) for face_id in range(cell.GetNumberOfFaces()): - face_hash: FrozenSet[int] = frozenset(vtk_iter(cell.GetFace(face_id).GetPointIds())) + face_hash: FrozenSet[int] = frozenset( + vtk_iter(cell.GetFace(face_id).GetPointIds())) if face_hash not in face_hashes: face_to_cells[face_hash].append(cell_id) @@ -205,14 +227,16 @@ def build_cell_to_cell_graph(mesh: vtkUnstructuredGrid, # and should be connected in the final cell to cell graph. 
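A small sketch of the face-sharing graph built in this hunk (assuming networkx, which the module already imports); cell_faces is an invented mapping from cell id to the frozensets of node ids of its faces.

from collections import defaultdict

import networkx


def cell_graph(cell_faces):
    face_to_cells = defaultdict(list)
    for cell_id, faces in cell_faces.items():
        for face in faces:
            face_to_cells[face].append(cell_id)
    graph = networkx.Graph()
    graph.add_nodes_from(cell_faces)
    # A face shared by exactly two cells connects those two cells.
    graph.add_edges_from(cells for cells in face_to_cells.values()
                         if len(cells) == 2)
    return graph


g = cell_graph({0: (frozenset({1, 2, 3, 4}), ), 1: (frozenset({1, 2, 3, 4}), )})
assert list(g.edges()) == [(0, 1)]
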
cell_to_cell = networkx.Graph() cell_to_cell.add_nodes_from(cells) - cell_to_cell.add_edges_from(filter(lambda cs: len(cs) == 2, face_to_cells.values())) + cell_to_cell.add_edges_from( + filter(lambda cs: len(cs) == 2, face_to_cells.values())) return cell_to_cell -def __identify_split(num_points: int, - cell_to_cell: networkx.Graph, - node_to_cells: Mapping[int, Iterable[int]]) -> Mapping[int, Mapping[int, int]]: +def __identify_split( + num_points: int, cell_to_cell: networkx.Graph, + node_to_cells: Mapping[int, + Iterable[int]]) -> Mapping[int, Mapping[int, int]]: """ For each cell, compute the node indices replacements. :param num_points: Number of points in the whole mesh (not the fracture). @@ -229,6 +253,7 @@ class NewIndex: Note that the first time an index is met, the index itself is returned: we do not want to change an index if we do not have to. """ + def __init__(self, num_nodes: int): self.__current_last_index = num_nodes - 1 self.__seen: Set[int] = set() @@ -243,9 +268,13 @@ def __call__(self, index: int) -> int: build_new_index = NewIndex(num_points) result: Dict[int, Dict[int, int]] = defaultdict(dict) - for node, cells in tqdm(sorted(node_to_cells.items()), # Iteration over `sorted` nodes to have a predictable result for tests. - desc="Identifying the node splits"): - for connected_cells in networkx.connected_components(cell_to_cell.subgraph(cells)): + for node, cells in tqdm( + sorted( + node_to_cells.items() + ), # Iteration over `sorted` nodes to have a predictable result for tests. + desc="Identifying the node splits"): + for connected_cells in networkx.connected_components( + cell_to_cell.subgraph(cells)): # Each group of connect cells need around `node` must consider the same `node`. # Separate groups must have different (duplicated) nodes. new_index: int = build_new_index(node) @@ -254,8 +283,7 @@ def __call__(self, index: int) -> int: return result -def __copy_fields(old_mesh: vtkUnstructuredGrid, - new_mesh: vtkUnstructuredGrid, +def __copy_fields(old_mesh: vtkUnstructuredGrid, new_mesh: vtkUnstructuredGrid, collocated_nodes: Sequence[int]) -> None: """ Copies the fields from the old mesh to the new one. @@ -294,8 +322,10 @@ def __copy_fields(old_mesh: vtkUnstructuredGrid, new_mesh.GetPointData().AddArray(tmp) -def __perform_split(old_mesh: vtkUnstructuredGrid, - cell_to_node_mapping: Mapping[int, Mapping[int, int]]) -> vtkUnstructuredGrid: +def __perform_split( + old_mesh: vtkUnstructuredGrid, + cell_to_node_mapping: Mapping[int, Mapping[int, + int]]) -> vtkUnstructuredGrid: """ Split the main 3d mesh based on the node duplication information contained in @p cell_to_node_mapping :param old_mesh: The main 3d mesh. 
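The NewIndex helper above hands back the original index the first time a node is met and a fresh index past the current end of the point list on every later occurrence; a standalone restatement with a usage check (the body elided by the hunk is reconstructed here as an assumption).

class NewIndex:

    def __init__(self, num_nodes):
        self.__last_index = num_nodes - 1
        self.__seen = set()

    def __call__(self, index):
        if index in self.__seen:
            # Already met: allocate a brand new (duplicated) index.
            self.__last_index += 1
            return self.__last_index
        self.__seen.add(index)
        return index


build_new_index = NewIndex(5)
assert [build_new_index(3), build_new_index(3), build_new_index(3)] == [3, 5, 6]
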
@@ -338,7 +368,8 @@ def __perform_split(old_mesh: vtkUnstructuredGrid, new_mesh.SetPoints(new_points) new_mesh.Allocate(old_mesh.GetNumberOfCells()) - for c in tqdm(range(old_mesh.GetNumberOfCells()), desc="Performing the mesh split"): + for c in tqdm(range(old_mesh.GetNumberOfCells()), + desc="Performing the mesh split"): node_mapping: Mapping[int, int] = cell_to_node_mapping.get(c, {}) cell: vtkCell = old_mesh.GetCell(c) cell_type: int = cell.GetCellType() @@ -347,20 +378,24 @@ def __perform_split(old_mesh: vtkUnstructuredGrid, face_stream = vtkIdList() old_mesh.GetFaceStream(c, face_stream) new_face_nodes: List[List[int]] = [] - for face_nodes in FaceStream.build_from_vtk_id_list(face_stream).face_nodes: + for face_nodes in FaceStream.build_from_vtk_id_list( + face_stream).face_nodes: new_point_ids = [] for current_point_id in face_nodes: - new_point_id: int = node_mapping.get(current_point_id, current_point_id) + new_point_id: int = node_mapping.get( + current_point_id, current_point_id) new_point_ids.append(new_point_id) new_face_nodes.append(new_point_ids) - new_mesh.InsertNextCell(cell_type, to_vtk_id_list(FaceStream(new_face_nodes).dump())) + new_mesh.InsertNextCell( + cell_type, to_vtk_id_list(FaceStream(new_face_nodes).dump())) else: # For the standard cells, we extract the point ids of the cell directly. # Then the values will be (potentially) overwritten in place, before being sent back into the cell. cell_point_ids: vtkIdList = cell.GetPointIds() for i in range(cell_point_ids.GetNumberOfIds()): current_point_id: int = cell_point_ids.GetId(i) - new_point_id: int = node_mapping.get(current_point_id, current_point_id) + new_point_id: int = node_mapping.get(current_point_id, + current_point_id) cell_point_ids.SetId(i, new_point_id) new_mesh.InsertNextCell(cell_type, cell_point_ids) @@ -369,9 +404,10 @@ def __perform_split(old_mesh: vtkUnstructuredGrid, return new_mesh -def __generate_fracture_mesh(mesh_points: vtkPoints, - fracture_info: FractureInfo, - cell_to_node_mapping: Mapping[int, Mapping[int, int]]) -> vtkUnstructuredGrid: +def __generate_fracture_mesh( + mesh_points: vtkPoints, fracture_info: FractureInfo, + cell_to_node_mapping: Mapping[int, Mapping[int, + int]]) -> vtkUnstructuredGrid: """ Generates the mesh of the fracture. :param mesh_points: The points of the main 3d mesh. 
@@ -381,7 +417,8 @@ def __generate_fracture_mesh(mesh_points: vtkPoints, """ logging.info("Generating the meshes") - is_node_duplicated = numpy.zeros(mesh_points.GetNumberOfPoints(), dtype=bool) # defaults to False + is_node_duplicated = numpy.zeros(mesh_points.GetNumberOfPoints(), + dtype=bool) # defaults to False for node_mapping in cell_to_node_mapping.values(): for i, o in node_mapping.items(): if not is_node_duplicated[i]: @@ -402,20 +439,27 @@ def __generate_fracture_mesh(mesh_points: vtkPoints, # tmp = [] # for dfns in discarded_face_nodes: # tmp.append(", ".join(map(str, dfns))) - msg: str = "(" + '), ('.join(map(lambda dfns: ", ".join(map(str, dfns)), discarded_face_nodes)) + ")" + msg: str = "(" + '), ('.join( + map(lambda dfns: ", ".join(map(str, dfns)), + discarded_face_nodes)) + ")" # logging.info(f"The {len(tmp)} faces made of nodes ({'), ('.join(tmp)}) were/was discarded from the fracture mesh because none of their/its nodes were duplicated.") # print(f"The {len(tmp)} faces made of nodes ({'), ('.join(tmp)}) were/was discarded from the fracture mesh because none of their/its nodes were duplicated.") - print(f"The faces made of nodes [{msg}] were/was discarded from the fracture mesh because none of their/its nodes were duplicated.") + print( + f"The faces made of nodes [{msg}] were/was discarded from the fracture mesh because none of their/its nodes were duplicated." + ) - fracture_nodes_tmp = numpy.ones(mesh_points.GetNumberOfPoints(), dtype=int) * -1 + fracture_nodes_tmp = numpy.ones(mesh_points.GetNumberOfPoints(), + dtype=int) * -1 for ns in face_nodes: for n in ns: fracture_nodes_tmp[n] = n - fracture_nodes: Collection[int] = tuple(filter(lambda n: n > -1, fracture_nodes_tmp)) + fracture_nodes: Collection[int] = tuple( + filter(lambda n: n > -1, fracture_nodes_tmp)) num_points: int = len(fracture_nodes) points = vtkPoints() points.SetNumberOfPoints(num_points) - node_3d_to_node_2d: Dict[int, int] = {} # Building the node mapping, from 3d mesh nodes to 2d fracture nodes. + node_3d_to_node_2d: Dict[int, int] = { + } # Building the node mapping, from 3d mesh nodes to 2d fracture nodes. for i, n in enumerate(fracture_nodes): coords: Tuple[float, float, float] = mesh_points.GetPoint(n) points.SetPoint(i, coords) @@ -437,31 +481,37 @@ def __generate_fracture_mesh(mesh_points: vtkPoints, buckets[k].update((i, o)) assert set(buckets.keys()) == set(range(num_points)) - max_collocated_nodes: int = max(map(len, buckets.values())) if buckets.values() else 0 - collocated_nodes = numpy.ones((num_points, max_collocated_nodes), dtype=int) * -1 + max_collocated_nodes: int = max(map( + len, buckets.values())) if buckets.values() else 0 + collocated_nodes = numpy.ones( + (num_points, max_collocated_nodes), dtype=int) * -1 for i, bucket in buckets.items(): for j, val in enumerate(bucket): collocated_nodes[i, j] = val array = numpy_to_vtk(collocated_nodes, array_type=VTK_ID_TYPE) array.SetName("collocated_nodes") - fracture_mesh = vtkUnstructuredGrid() # We could be using vtkPolyData, but it's not supported by GEOS for now. + fracture_mesh = vtkUnstructuredGrid( + ) # We could be using vtkPolyData, but it's not supported by GEOS for now. 
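The collocated_nodes array above stores ragged buckets of node ids in a rectangular integer array padded with -1; a compact sketch of that padding, assuming numpy and made-up buckets.

import numpy


def pad_buckets(buckets):
    width = max(map(len, buckets)) if buckets else 0
    padded = numpy.ones((len(buckets), width), dtype=int) * -1
    for i, bucket in enumerate(buckets):
        for j, value in enumerate(sorted(bucket)):
            padded[i, j] = value
    return padded


assert pad_buckets([{4}, {2, 7}]).tolist() == [[4, -1], [2, 7]]
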
fracture_mesh.SetPoints(points) if polygons.GetNumberOfCells() > 0: - fracture_mesh.SetCells([VTK_POLYGON] * polygons.GetNumberOfCells(), polygons) + fracture_mesh.SetCells([VTK_POLYGON] * polygons.GetNumberOfCells(), + polygons) fracture_mesh.GetPointData().AddArray(array) return fracture_mesh -def __split_mesh_on_fracture(mesh: vtkUnstructuredGrid, - options: Options) -> Tuple[vtkUnstructuredGrid, vtkUnstructuredGrid]: +def __split_mesh_on_fracture( + mesh: vtkUnstructuredGrid, + options: Options) -> Tuple[vtkUnstructuredGrid, vtkUnstructuredGrid]: fracture: FractureInfo = build_fracture_info(mesh, options) cell_to_cell: networkx.Graph = build_cell_to_cell_graph(mesh, fracture) - cell_to_node_mapping: Mapping[int, Mapping[int, int]] = __identify_split(mesh.GetNumberOfPoints(), - cell_to_cell, - fracture.node_to_cells) - output_mesh: vtkUnstructuredGrid = __perform_split(mesh, cell_to_node_mapping) - fractured_mesh: vtkUnstructuredGrid = __generate_fracture_mesh(mesh.GetPoints(), fracture, cell_to_node_mapping) + cell_to_node_mapping: Mapping[int, Mapping[int, int]] = __identify_split( + mesh.GetNumberOfPoints(), cell_to_cell, fracture.node_to_cells) + output_mesh: vtkUnstructuredGrid = __perform_split(mesh, + cell_to_node_mapping) + fractured_mesh: vtkUnstructuredGrid = __generate_fracture_mesh( + mesh.GetPoints(), fracture, cell_to_node_mapping) return output_mesh, fractured_mesh diff --git a/geosx_mesh_doctor/checks/generate_global_ids.py b/geosx_mesh_doctor/checks/generate_global_ids.py index 80474e2..4ecb48c 100644 --- a/geosx_mesh_doctor/checks/generate_global_ids.py +++ b/geosx_mesh_doctor/checks/generate_global_ids.py @@ -6,8 +6,7 @@ from . import vtk_utils from .vtk_utils import ( - VtkOutput, -) + VtkOutput, ) @dataclass(frozen=True) @@ -22,8 +21,7 @@ class Result: info: str -def __build_global_ids(mesh, - generate_cells_global_ids: bool, +def __build_global_ids(mesh, generate_cells_global_ids: bool, generate_points_global_ids: bool) -> None: """ Adds the global ids for cells and points in place into the mesh instance. @@ -54,7 +52,8 @@ def __build_global_ids(mesh, def __check(mesh, options: Options) -> Result: - __build_global_ids(mesh, options.generate_cells_global_ids, options.generate_points_global_ids) + __build_global_ids(mesh, options.generate_cells_global_ids, + options.generate_points_global_ids) vtk_utils.write_mesh(mesh, options.vtk_output) return Result(info=f"Mesh was written to {options.vtk_output.output}") diff --git a/geosx_mesh_doctor/checks/non_conformal.py b/geosx_mesh_doctor/checks/non_conformal.py index 43f26e2..83dd164 100644 --- a/geosx_mesh_doctor/checks/non_conformal.py +++ b/geosx_mesh_doctor/checks/non_conformal.py @@ -21,14 +21,11 @@ vtkUnstructuredGrid, ) from vtkmodules.vtkCommonTransforms import ( - vtkTransform, -) + vtkTransform, ) from vtkmodules.vtkFiltersCore import ( - vtkPolyDataNormals, -) + vtkPolyDataNormals, ) from vtkmodules.vtkFiltersGeometry import ( - vtkDataSetSurfaceFilter, -) + vtkDataSetSurfaceFilter, ) from vtkmodules.vtkFiltersModeling import ( vtkCollisionDetectionFilter, vtkLinearExtrusionFilter, @@ -40,8 +37,7 @@ from . import vtk_utils from .vtk_polyhedron import ( - vtk_iter, -) + vtk_iter, ) from . import triangle_distance @@ -67,32 +63,44 @@ class BoundaryMesh: Therefore, we reorient the polyhedron cells ourselves, so we're sure that they point outwards. And then we compute the boundary meshes for both meshes, given that the computing options are not identical. 
""" + def __init__(self, mesh: vtkUnstructuredGrid): """ Builds a boundary mesh. :param mesh: The 3d mesh. """ # Building the boundary meshes - boundary_mesh, __normals, self.__original_cells = BoundaryMesh.__build_boundary_mesh(mesh) - cells_to_reorient = filter(lambda c: mesh.GetCell(c).GetCellType() == VTK_POLYHEDRON, - map(self.__original_cells.GetValue, - range(self.__original_cells.GetNumberOfValues()))) + boundary_mesh, __normals, self.__original_cells = BoundaryMesh.__build_boundary_mesh( + mesh) + cells_to_reorient = filter( + lambda c: mesh.GetCell(c).GetCellType() == VTK_POLYHEDRON, + map(self.__original_cells.GetValue, + range(self.__original_cells.GetNumberOfValues()))) reoriented_mesh = reorient_mesh(mesh, cells_to_reorient) - self.re_boundary_mesh, re_normals, _ = BoundaryMesh.__build_boundary_mesh(reoriented_mesh, consistency=False) + self.re_boundary_mesh, re_normals, _ = BoundaryMesh.__build_boundary_mesh( + reoriented_mesh, consistency=False) num_cells = boundary_mesh.GetNumberOfCells() # Precomputing the underlying cell type - self.__is_underlying_cell_type_a_polyhedron = numpy.zeros(num_cells, dtype=bool) + self.__is_underlying_cell_type_a_polyhedron = numpy.zeros(num_cells, + dtype=bool) for ic in range(num_cells): - self.__is_underlying_cell_type_a_polyhedron[ic] = mesh.GetCell(self.__original_cells.GetValue(ic)).GetCellType() == VTK_POLYHEDRON + self.__is_underlying_cell_type_a_polyhedron[ic] = mesh.GetCell( + self.__original_cells.GetValue( + ic)).GetCellType() == VTK_POLYHEDRON # Precomputing the normals - self.__normals: numpy.ndarray = numpy.empty((num_cells, 3), dtype=numpy.double, order='C') # Do not modify the storage layout + self.__normals: numpy.ndarray = numpy.empty( + (num_cells, 3), dtype=numpy.double, + order='C') # Do not modify the storage layout for ic in range(num_cells): if self.__is_underlying_cell_type_a_polyhedron[ic]: self.__normals[ic, :] = re_normals.GetTuple3(ic) else: self.__normals[ic, :] = __normals.GetTuple3(ic) + @staticmethod - def __build_boundary_mesh(mesh: vtkUnstructuredGrid, consistency=True) -> Tuple[vtkUnstructuredGrid, Any, Any]: + def __build_boundary_mesh( + mesh: vtkUnstructuredGrid, + consistency=True) -> Tuple[vtkUnstructuredGrid, Any, Any]: """ From a 3d mesh, build the envelope meshes. :param mesh: The input 3d mesh. @@ -123,7 +131,8 @@ def __build_boundary_mesh(mesh: vtkUnstructuredGrid, consistency=True) -> Tuple[ assert normals assert normals.GetNumberOfComponents() == 3 assert normals.GetNumberOfTuples() == boundary_mesh.GetNumberOfCells() - original_cells = boundary_mesh.GetCellData().GetArray(original_cells_key) + original_cells = boundary_mesh.GetCellData().GetArray( + original_cells_key) assert original_cells return boundary_mesh, normals, original_cells @@ -184,7 +193,8 @@ def original_cells(self): return self.__original_cells -def build_poly_data_for_extrusion(i: int, boundary_mesh: BoundaryMesh) -> vtkPolyData: +def build_poly_data_for_extrusion(i: int, + boundary_mesh: BoundaryMesh) -> vtkPolyData: """ Creates a vtkPolyData containing the unique cell `i` of the boundary mesh. This operation is needed to use the vtk extrusion filter. 
@@ -211,7 +221,8 @@ def build_poly_data_for_extrusion(i: int, boundary_mesh: BoundaryMesh) -> vtkPol return polygon_poly_data -def are_points_conformal(point_tolerance: float, cell_i: vtkCell, cell_j: vtkCell) -> bool: +def are_points_conformal(point_tolerance: float, cell_i: vtkCell, + cell_j: vtkCell) -> bool: """ Checks if points of cell `i` matches, one by one, the points of cell `j`. :param point_tolerance: The point tolerance to consider that two points match. @@ -232,7 +243,8 @@ def are_points_conformal(point_tolerance: float, cell_i: vtkCell, cell_j: vtkCel for ip in range(cell_j.GetNumberOfPoints()): p = cell_j.GetPoints().GetPoint(ip) squared_dist = vtk_reference(0.) # unused - found_point = point_locator.FindClosestPointWithinRadius(point_tolerance, p, squared_dist) + found_point = point_locator.FindClosestPointWithinRadius( + point_tolerance, p, squared_dist) found_points.add(found_point) return found_points == set(range(cell_i.GetNumberOfPoints())) @@ -242,8 +254,11 @@ class Extruder: Computes and stores all the extrusions of the boundary faces. The main reason for this class is to be lazy and cache the extrusions. """ + def __init__(self, boundary_mesh: BoundaryMesh, face_tolerance: float): - self.__extrusions: List[vtkPolyData] = [None, ] * boundary_mesh.GetNumberOfCells() + self.__extrusions: List[vtkPolyData] = [ + None, + ] * boundary_mesh.GetNumberOfCells() self.__boundary_mesh = boundary_mesh self.__face_tolerance = face_tolerance @@ -271,14 +286,14 @@ def __getitem__(self, i) -> vtkPolyData: extrusion = self.__extrusions[i] if extrusion: return extrusion - extrusion = self.__extrude(build_poly_data_for_extrusion(i, self.__boundary_mesh), - self.__boundary_mesh.normals(i)) + extrusion = self.__extrude( + build_poly_data_for_extrusion(i, self.__boundary_mesh), + self.__boundary_mesh.normals(i)) self.__extrusions[i] = extrusion return extrusion -def are_faces_conformal_using_extrusions(extrusions: Extruder, - i: int, j: int, +def are_faces_conformal_using_extrusions(extrusions: Extruder, i: int, j: int, boundary_mesh: vtkUnstructuredGrid, point_tolerance: float) -> bool: """ @@ -308,12 +323,14 @@ def are_faces_conformal_using_extrusions(extrusions: Extruder, copied_cell_i = cell_i.NewInstance() copied_cell_i.DeepCopy(cell_i) - return are_points_conformal(point_tolerance, copied_cell_i, boundary_mesh.GetCell(j)) + return are_points_conformal(point_tolerance, copied_cell_i, + boundary_mesh.GetCell(j)) def are_faces_conformal_using_distances(i: int, j: int, boundary_mesh: vtkUnstructuredGrid, - face_tolerance: float, point_tolerance: float) -> bool: + face_tolerance: float, + point_tolerance: float) -> bool: """ Tests if two boundary faces are conformal, checking the minimal distance between triangulated surfaces. :param i: The cell index of the first cell. 
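
Aside (not part of the patch): are_points_conformal, reformatted above, matches each point of face j to a point of face i through a vtk point locator. A simplified, numpy-only restatement of that criterion, under the assumption that both faces' coordinates are available as arrays, is:

# Assumed, simplified restatement only: every point of face j must fall within
# point_tolerance of some point of face i, and together those hits must cover
# all points of face i.
import numpy

def faces_share_all_points(points_i: numpy.ndarray, points_j: numpy.ndarray,
                           point_tolerance: float) -> bool:
    found = set()
    for p in points_j:
        distances = numpy.linalg.norm(points_i - p, axis=1)
        closest = int(numpy.argmin(distances))
        if distances[closest] <= point_tolerance:
            found.add(closest)
    return found == set(range(len(points_i)))
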
@@ -345,7 +362,7 @@ def build_numpy_triangles(points_ids): __triangles = [] for __i in range(0, len(points_ids), 3): __t = [] - for __pi in points_ids[__i: __i + 3]: + for __pi in points_ids[__i:__i + 3]: __t.append(boundary_mesh.GetPoint(__pi)) __triangles.append(numpy.array(__t, dtype=float)) return __triangles @@ -381,7 +398,8 @@ def __check(mesh: vtkUnstructuredGrid, options: Options) -> Result: num_cells = boundary_mesh.GetNumberOfCells() # Computing the exact number of cells per node - num_cells_per_node = numpy.zeros(boundary_mesh.GetNumberOfPoints(), dtype=int) + num_cells_per_node = numpy.zeros(boundary_mesh.GetNumberOfPoints(), + dtype=int) for ic in range(boundary_mesh.GetNumberOfCells()): c = boundary_mesh.GetCell(ic) point_ids = c.GetPointIds() @@ -396,11 +414,14 @@ def __check(mesh: vtkUnstructuredGrid, options: Options) -> Result: # Precomputing the bounding boxes. # The options are important to directly interact with memory in C++. - bounding_boxes = numpy.empty((boundary_mesh.GetNumberOfCells(), 6), dtype=numpy.double, order="C") + bounding_boxes = numpy.empty((boundary_mesh.GetNumberOfCells(), 6), + dtype=numpy.double, + order="C") for i in range(boundary_mesh.GetNumberOfCells()): bb = vtkBoundingBox(boundary_mesh.bounds(i)) bb.Inflate(2 * options.face_tolerance) - assert bounding_boxes[i, :].data.contiguous # Do not modify the storage layout since vtk deals with raw memory here. + assert bounding_boxes[ + i, :].data.contiguous # Do not modify the storage layout since vtk deals with raw memory here. bb.GetBounds(bounding_boxes[i, :]) non_conformal_cells = [] @@ -413,16 +434,21 @@ def __check(mesh: vtkUnstructuredGrid, options: Options) -> Result: if j < i: continue # Discarding pairs that are not facing each others (with a threshold). - normal_i, normal_j = boundary_mesh.normals(i), boundary_mesh.normals(j) - if numpy.dot(normal_i, normal_j) > -cos_theta: # opposite directions only (can be facing or not) + normal_i, normal_j = boundary_mesh.normals( + i), boundary_mesh.normals(j) + if numpy.dot( + normal_i, normal_j + ) > -cos_theta: # opposite directions only (can be facing or not) continue # At this point, back-to-back and face-to-face pairs of elements are considered. - if not are_faces_conformal_using_extrusions(extrusions, i, j, boundary_mesh, options.point_tolerance): + if not are_faces_conformal_using_extrusions( + extrusions, i, j, boundary_mesh, options.point_tolerance): non_conformal_cells.append((i, j)) # Extracting the original 3d element index (and not the index of the boundary mesh). tmp = [] for i, j in non_conformal_cells: - tmp.append((boundary_mesh.original_cells.GetValue(i), boundary_mesh.original_cells.GetValue(j))) + tmp.append((boundary_mesh.original_cells.GetValue(i), + boundary_mesh.original_cells.GetValue(j))) return Result(non_conformal_cells=tmp) diff --git a/geosx_mesh_doctor/checks/reorient_mesh.py b/geosx_mesh_doctor/checks/reorient_mesh.py index efb664b..047103b 100644 --- a/geosx_mesh_doctor/checks/reorient_mesh.py +++ b/geosx_mesh_doctor/checks/reorient_mesh.py @@ -27,11 +27,9 @@ vtkTetra, ) from vtkmodules.vtkFiltersCore import ( - vtkTriangleFilter, -) + vtkTriangleFilter, ) from .vtk_utils import ( - to_vtk_id_list, -) + to_vtk_id_list, ) from .vtk_polyhedron import ( FaceStream, @@ -74,10 +72,12 @@ def __compute_volume(mesh_points: vtkPoints, face_stream: FaceStream) -> float: # (The basis of all the tetra being the triangles of the envelope). # We could take any point, not only the barycenter. 
# But in order to work with figure of the same magnitude, let's compute the barycenter. - tmp_barycenter = numpy.empty((face_stream.num_support_points, 3), dtype=float) + tmp_barycenter = numpy.empty((face_stream.num_support_points, 3), + dtype=float) for i, point_id in enumerate(face_stream.support_point_ids): tmp_barycenter[i, :] = mesh_points.GetPoint(point_id) - barycenter = tmp_barycenter[:, 0].mean(), tmp_barycenter[:, 1].mean(), tmp_barycenter[:, 2].mean() + barycenter = tmp_barycenter[:, 0].mean(), tmp_barycenter[:, 1].mean( + ), tmp_barycenter[:, 2].mean() # Looping on all the triangles of the envelope of the polyhedron, creating the matching tetra. # Then the volume of all the tetra are added to get the final polyhedron volume. cell_volume = 0. @@ -85,7 +85,8 @@ def __compute_volume(mesh_points: vtkPoints, face_stream: FaceStream) -> float: triangle = triangles.GetCell(i) assert triangle.GetCellType() == VTK_TRIANGLE p = triangle.GetPoints() - cell_volume += vtkTetra.ComputeVolume(barycenter, p.GetPoint(0), p.GetPoint(1), p.GetPoint(2)) + cell_volume += vtkTetra.ComputeVolume(barycenter, p.GetPoint(0), + p.GetPoint(1), p.GetPoint(2)) return cell_volume @@ -106,14 +107,18 @@ def __select_and_flip_faces(mesh_points: vtkPoints, color_to_nodes[color] += connected_components_indices # This implementation works even if there is one unique color. # Admittedly, there will be one face stream that won't be flipped. - fs: Tuple[FaceStream, FaceStream] = face_stream.flip_faces(color_to_nodes[0]), face_stream.flip_faces(color_to_nodes[1]) - volumes = __compute_volume(mesh_points, fs[0]), __compute_volume(mesh_points, fs[1]) + fs: Tuple[FaceStream, + FaceStream] = (face_stream.flip_faces(color_to_nodes[0]), + face_stream.flip_faces(color_to_nodes[1])) + volumes = __compute_volume(mesh_points, + fs[0]), __compute_volume(mesh_points, fs[1]) # We keep the flipped element for which the volume is largest # (i.e. positive, since they should be the opposite of each other). return fs[numpy.argmax(volumes)] -def __reorient_element(mesh_points: vtkPoints, face_stream_ids: vtkIdList) -> vtkIdList: +def __reorient_element(mesh_points: vtkPoints, + face_stream_ids: vtkIdList) -> vtkIdList: """ Considers a vtk face stream and flips the appropriate faces to get an element with normals directed outwards. :param mesh_points: The mesh points, needed to compute the volume. @@ -121,22 +126,28 @@ def __reorient_element(mesh_points: vtkPoints, face_stream_ids: vtkIdList) -> vt :return: The raw vtk face stream with faces properly flipped. """ face_stream = FaceStream.build_from_vtk_id_list(face_stream_ids) - face_graph = build_face_to_face_connectivity_through_edges(face_stream, add_compatibility=True) + face_graph = build_face_to_face_connectivity_through_edges( + face_stream, add_compatibility=True) # Removing the non-compatible connections to build the non-connected components. g = networkx.Graph() g.add_nodes_from(face_graph.nodes) - g.add_edges_from(filter(lambda uvd: uvd[2]["compatible"] == "+", face_graph.edges(data=True))) + g.add_edges_from( + filter(lambda uvd: uvd[2]["compatible"] == "+", + face_graph.edges(data=True))) connected_components = tuple(networkx.connected_components(g)) # Squashing all the connected nodes that need to receive the normal direction flip (or not) together. 
- quotient_graph = networkx.algorithms.quotient_graph(face_graph, connected_components) + quotient_graph = networkx.algorithms.quotient_graph( + face_graph, connected_components) # Coloring the new graph lets us know how which cluster of faces need to eventually receive the same flip. # W.r.t. the nature of our problem (a normal can be directed inwards or outwards), # two colors should be enough to color the face graph. # `colors` maps the nodes of each connected component to its color. - colors: Dict[FrozenSet[int], int] = networkx.algorithms.greedy_color(quotient_graph) + colors: Dict[FrozenSet[int], + int] = networkx.algorithms.greedy_color(quotient_graph) assert len(colors) in (1, 2) # We now compute the face stream which generates outwards normal vectors. - flipped_face_stream = __select_and_flip_faces(mesh_points, colors, face_stream) + flipped_face_stream = __select_and_flip_faces(mesh_points, colors, + face_stream) return to_vtk_id_list(flipped_face_stream.dump()) @@ -157,8 +168,12 @@ def reorient_mesh(mesh, cell_indices: Iterator[int]) -> vtkUnstructuredGrid: # I did not manage to call `output_mesh.CopyStructure(mesh)` because I could not modify the polyhedron in place. # Therefore, I insert the cells one by one... output_mesh.SetPoints(mesh.GetPoints()) - logging.info("Reorienting the polyhedron cells to enforce normals directed outward.") - with tqdm(total=needs_to_be_reoriented.sum(), desc="Reorienting polyhedra") as progress_bar: # For smoother progress, we only update on reoriented elements. + logging.info( + "Reorienting the polyhedron cells to enforce normals directed outward." + ) + with tqdm( + total=needs_to_be_reoriented.sum(), desc="Reorienting polyhedra" + ) as progress_bar: # For smoother progress, we only update on reoriented elements. for ic in range(num_cells): cell = mesh.GetCell(ic) cell_type = cell.GetCellType() @@ -166,7 +181,8 @@ def reorient_mesh(mesh, cell_indices: Iterator[int]) -> vtkUnstructuredGrid: face_stream_ids = vtkIdList() mesh.GetFaceStream(ic, face_stream_ids) if needs_to_be_reoriented[ic]: - new_face_stream_ids = __reorient_element(mesh.GetPoints(), face_stream_ids) + new_face_stream_ids = __reorient_element( + mesh.GetPoints(), face_stream_ids) else: new_face_stream_ids = face_stream_ids output_mesh.InsertNextCell(VTK_POLYHEDRON, new_face_stream_ids) diff --git a/geosx_mesh_doctor/checks/self_intersecting_elements.py b/geosx_mesh_doctor/checks/self_intersecting_elements.py index 0e98d4f..b439362 100644 --- a/geosx_mesh_doctor/checks/self_intersecting_elements.py +++ b/geosx_mesh_doctor/checks/self_intersecting_elements.py @@ -5,16 +5,10 @@ List, ) -from vtkmodules.vtkFiltersGeneral import ( - vtkCellValidator -) -from vtkmodules.vtkCommonCore import ( - vtkOutputWindow, - vtkFileOutputWindow -) +from vtkmodules.vtkFiltersGeneral import (vtkCellValidator) +from vtkmodules.vtkCommonCore import (vtkOutputWindow, vtkFileOutputWindow) from vtkmodules.util.numpy_support import ( - vtk_to_numpy, -) + vtk_to_numpy, ) from . import vtk_utils @@ -36,7 +30,8 @@ class Result: def __check(mesh, options: Options) -> Result: err_out = vtkFileOutputWindow() - err_out.SetFileName("/dev/null") # vtkCellValidator outputs loads for each cell... + err_out.SetFileName( + "/dev/null") # vtkCellValidator outputs loads for each cell... 
vtk_std_err_out = vtkOutputWindow() vtk_std_err_out.SetInstance(err_out) @@ -62,7 +57,8 @@ def __check(mesh, options: Options) -> Result: f.Update() output = f.GetOutput() - validity = output.GetCellData().GetArray("ValidityState") # Could not change name using the vtk interface. + validity = output.GetCellData().GetArray( + "ValidityState") # Could not change name using the vtk interface. assert validity is not None validity = vtk_to_numpy(validity) for i, v in enumerate(validity): @@ -79,12 +75,14 @@ def __check(mesh, options: Options) -> Result: non_convex_elements.append(i) if v & faces_are_oriented_incorrectly: faces_are_oriented_incorrectly_elements.append(i) - return Result(wrong_number_of_points_elements=wrong_number_of_points_elements, - intersecting_edges_elements=intersecting_edges_elements, - intersecting_faces_elements=intersecting_faces_elements, - non_contiguous_edges_elements=non_contiguous_edges_elements, - non_convex_elements=non_convex_elements, - faces_are_oriented_incorrectly_elements=faces_are_oriented_incorrectly_elements) + return Result( + wrong_number_of_points_elements=wrong_number_of_points_elements, + intersecting_edges_elements=intersecting_edges_elements, + intersecting_faces_elements=intersecting_faces_elements, + non_contiguous_edges_elements=non_contiguous_edges_elements, + non_convex_elements=non_convex_elements, + faces_are_oriented_incorrectly_elements= + faces_are_oriented_incorrectly_elements) def check(vtk_input_file: str, options: Options) -> Result: diff --git a/geosx_mesh_doctor/checks/supported_elements.py b/geosx_mesh_doctor/checks/supported_elements.py index 84c5fcb..f451b23 100644 --- a/geosx_mesh_doctor/checks/supported_elements.py +++ b/geosx_mesh_doctor/checks/supported_elements.py @@ -17,8 +17,7 @@ import numpy from vtkmodules.vtkCommonCore import ( - vtkIdList, -) + vtkIdList, ) from vtkmodules.vtkCommonDataModel import ( vtkCellTypes, vtkUnstructuredGrid, @@ -32,8 +31,7 @@ VTK_WEDGE, ) from vtkmodules.util.numpy_support import ( - vtk_to_numpy, -) + vtk_to_numpy, ) from . import vtk_utils from .vtk_utils import vtk_iter @@ -49,13 +47,16 @@ class Options: @dataclass(frozen=True) class Result: unsupported_std_elements_types: FrozenSet[int] # list of unsupported types - unsupported_polyhedron_elements: FrozenSet[int] # list of polyhedron elements that could not be converted to supported std elements + unsupported_polyhedron_elements: FrozenSet[ + int] # list of polyhedron elements that could not be converted to supported std elements -MESH: Optional[vtkUnstructuredGrid] = None # for multiprocessing, vtkUnstructuredGrid cannot be pickled. Let's use a global variable instead. +MESH: Optional[ + vtkUnstructuredGrid] = None # for multiprocessing, vtkUnstructuredGrid cannot be pickled. Let's use a global variable instead. class IsPolyhedronConvertible: + def __init__(self, mesh: vtkUnstructuredGrid): global MESH # for multiprocessing, vtkUnstructuredGrid cannot be pickled. Let's use a global variable instead. MESH = mesh @@ -78,18 +79,19 @@ def build_prism_graph(n: int, name: str) -> networkx.Graph: tet_graph = networkx.complete_graph(4) tet_graph.name = "Tetrahedron" pyr_graph = build_prism_graph(4, "Pyramid") - pyr_graph.remove_node(5) # Removing a node also removes its associated edges. + pyr_graph.remove_node( + 5) # Removing a node also removes its associated edges. 
self.__reference_graphs: Mapping[int, Iterable[networkx.Graph]] = { - 4: (tet_graph,), + 4: (tet_graph, ), 5: (pyr_graph, build_prism_graph(3, "Wedge")), - 6: (build_prism_graph(4, "Hexahedron"),), - 7: (build_prism_graph(5, "Prism5"),), - 8: (build_prism_graph(6, "Prism6"),), - 9: (build_prism_graph(7, "Prism7"),), - 10: (build_prism_graph(8, "Prism8"),), - 11: (build_prism_graph(9, "Prism9"),), - 12: (build_prism_graph(10, "Prism10"),), - 13: (build_prism_graph(11, "Prism11"),), + 6: (build_prism_graph(4, "Hexahedron"), ), + 7: (build_prism_graph(5, "Prism5"), ), + 8: (build_prism_graph(6, "Prism6"), ), + 9: (build_prism_graph(7, "Prism7"), ), + 10: (build_prism_graph(8, "Prism8"), ), + 11: (build_prism_graph(9, "Prism9"), ), + 12: (build_prism_graph(10, "Prism10"), ), + 13: (build_prism_graph(11, "Prism11"), ), } def __is_polyhedron_supported(self, face_stream) -> str: @@ -99,7 +101,8 @@ def __is_polyhedron_supported(self, face_stream) -> str: :param face_stream: The polyhedron. :return: The name of the supported type or an empty string. """ - cell_graph = build_face_to_face_connectivity_through_edges(face_stream, add_compatibility=True) + cell_graph = build_face_to_face_connectivity_through_edges( + face_stream, add_compatibility=True) for reference_graph in self.__reference_graphs[cell_graph.order()]: if networkx.is_isomorphic(reference_graph, cell_graph): return str(reference_graph.name) @@ -120,29 +123,29 @@ def __call__(self, ic: int) -> int: face_stream = FaceStream.build_from_vtk_id_list(pt_ids) converted_type_name = self.__is_polyhedron_supported(face_stream) if converted_type_name: - logging.debug(f"Polyhedron cell {ic} can be converted into \"{converted_type_name}\"") + logging.debug( + f"Polyhedron cell {ic} can be converted into \"{converted_type_name}\"" + ) return -1 else: - logging.debug(f"Polyhedron cell {ic} cannot be converted into any supported element.") + logging.debug( + f"Polyhedron cell {ic} cannot be converted into any supported element." + ) return ic def __check(mesh: vtkUnstructuredGrid, options: Options) -> Result: - if hasattr(mesh, "GetDistinctCellTypesArray"): # For more recent versions of vtk. + if hasattr( + mesh, + "GetDistinctCellTypesArray"): # For more recent versions of vtk. 
cell_types = set(vtk_to_numpy(mesh.GetDistinctCellTypesArray())) else: cell_types = vtkCellTypes() mesh.GetCellTypes(cell_types) cell_types = set(vtk_iter(cell_types)) supported_cell_types = { - VTK_HEXAGONAL_PRISM, - VTK_HEXAHEDRON, - VTK_PENTAGONAL_PRISM, - VTK_POLYHEDRON, - VTK_PYRAMID, - VTK_TETRA, - VTK_VOXEL, - VTK_WEDGE + VTK_HEXAGONAL_PRISM, VTK_HEXAHEDRON, VTK_PENTAGONAL_PRISM, + VTK_POLYHEDRON, VTK_PYRAMID, VTK_TETRA, VTK_VOXEL, VTK_WEDGE } unsupported_std_elements_types = cell_types - supported_cell_types @@ -150,12 +153,19 @@ def __check(mesh: vtkUnstructuredGrid, options: Options) -> Result: num_cells = mesh.GetNumberOfCells() result = numpy.ones(num_cells, dtype=int) * -1 with multiprocessing.Pool(processes=options.num_proc) as pool: - generator = pool.imap_unordered(IsPolyhedronConvertible(mesh), range(num_cells), chunksize=options.chunk_size) - for i, val in enumerate(tqdm(generator, total=num_cells, desc="Testing support for elements")): + generator = pool.imap_unordered(IsPolyhedronConvertible(mesh), + range(num_cells), + chunksize=options.chunk_size) + for i, val in enumerate( + tqdm(generator, + total=num_cells, + desc="Testing support for elements")): result[i] = val unsupported_polyhedron_elements = [i for i in result if i > -1] - return Result(unsupported_std_elements_types=frozenset(unsupported_std_elements_types), - unsupported_polyhedron_elements=frozenset(unsupported_polyhedron_elements)) + return Result(unsupported_std_elements_types=frozenset( + unsupported_std_elements_types), + unsupported_polyhedron_elements=frozenset( + unsupported_polyhedron_elements)) def check(vtk_input_file: str, options: Options) -> Result: diff --git a/geosx_mesh_doctor/checks/triangle_distance.py b/geosx_mesh_doctor/checks/triangle_distance.py index ef1f3c9..068acbb 100644 --- a/geosx_mesh_doctor/checks/triangle_distance.py +++ b/geosx_mesh_doctor/checks/triangle_distance.py @@ -6,7 +6,7 @@ from numpy.linalg import norm -def __div_clamp(num: float, den :float) -> float: +def __div_clamp(num: float, den: float) -> float: """ Computes the division `num / den`. and clamps the result between 0 and 1. If `den` is zero, the result of the division is set to 0. @@ -25,8 +25,9 @@ def __div_clamp(num: float, den :float) -> float: return tmp -def distance_between_two_segments(x0: numpy.ndarray, d0: numpy.ndarray, - x1: numpy.ndarray, d1: numpy.ndarray) -> Tuple[numpy.ndarray, numpy.ndarray]: +def distance_between_two_segments( + x0: numpy.ndarray, d0: numpy.ndarray, x1: numpy.ndarray, + d1: numpy.ndarray) -> Tuple[numpy.ndarray, numpy.ndarray]: """ Compute the minimum distance between two segments. :param x0: First point of segment 0. @@ -41,7 +42,8 @@ def distance_between_two_segments(x0: numpy.ndarray, d0: numpy.ndarray, # In the reference, the indices start at 1, while in this implementation, they start at 0. tmp: numpy.ndarray = x1 - x0 - D0: float = numpy.dot(d0, d0) # As such, this is D1 in the reference paper. + D0: float = numpy.dot(d0, + d0) # As such, this is D1 in the reference paper. D1: float = numpy.dot(d1, d1) R: float = numpy.dot(d0, d1) S0: float = numpy.dot(d0, tmp) @@ -60,14 +62,17 @@ def distance_between_two_segments(x0: numpy.ndarray, d0: numpy.ndarray, # Step 3: compute t1 for point on line 1 closest to point at t0. 
t1: float = __div_clamp(t0 * R - S1, D1) # Eq (10, right) - sol_1: numpy.ndarray = x1 + t1 * d1 # Eq (3) + sol_1: numpy.ndarray = x1 + t1 * d1 # Eq (3) t0: float = __div_clamp(t1 * R + S0, D0) # Eq (10, left) - sol_0: numpy.ndarray = x0 + t0 * d0 # Eq (4) + sol_0: numpy.ndarray = x0 + t0 * d0 # Eq (4) return sol_0, sol_1 -def __compute_nodes_to_triangle_distance(tri_0, edges_0, tri_1) -> Tuple[Union[float, None], Union[numpy.ndarray, None], Union[numpy.ndarray, None], bool]: +def __compute_nodes_to_triangle_distance( + tri_0, edges_0, tri_1 +) -> Tuple[Union[float, None], Union[numpy.ndarray, None], Union[numpy.ndarray, + None], bool]: """ Computes the distance from nodes of `tri_1` points onto `tri_0`. :param tri_0: First triangle. @@ -102,18 +107,24 @@ def __compute_nodes_to_triangle_distance(tri_0, edges_0, tri_1) -> Tuple[Union[f if point > -1: are_disjoint = True # But we must check that its projection is inside `tri_0`. - if numpy.dot(tri_1[point] - tri_0[0], numpy.cross(tri_0_normal, edges_0[0])) > 0: - if numpy.dot(tri_1[point] - tri_0[1], numpy.cross(tri_0_normal, edges_0[1])) > 0: - if numpy.dot(tri_1[point] - tri_0[2], numpy.cross(tri_0_normal, edges_0[2])) > 0: + if numpy.dot(tri_1[point] - tri_0[0], + numpy.cross(tri_0_normal, edges_0[0])) > 0: + if numpy.dot(tri_1[point] - tri_0[1], + numpy.cross(tri_0_normal, edges_0[1])) > 0: + if numpy.dot(tri_1[point] - tri_0[2], + numpy.cross(tri_0_normal, edges_0[2])) > 0: # It is! sol_0 = tri_1[point] - sol_1 = tri_1[point] + (tri_1_proj[point] / tri_0_normal_norm) * tri_0_normal + sol_1 = tri_1[point] + ( + tri_1_proj[point] / + tri_0_normal_norm) * tri_0_normal return norm(sol_1 - sol_0), sol_0, sol_1, are_disjoint return None, None, None, are_disjoint -def distance_between_two_triangles(tri_0: numpy.ndarray, - tri_1: numpy.ndarray) -> Tuple[float, numpy.ndarray, numpy.ndarray]: +def distance_between_two_triangles( + tri_0: numpy.ndarray, + tri_1: numpy.ndarray) -> Tuple[float, numpy.ndarray, numpy.ndarray]: """ Returns the minimum distance between two triangles, and the two points where this minimum occurs. If the two triangles touch, then distance is exactly 0. @@ -138,7 +149,8 @@ def distance_between_two_triangles(tri_0: numpy.ndarray, # Looping over all the pair of edges. for i, j in itertools.product(range(3), repeat=2): # Find the closest points on edges i and j. - sol_0, sol_1 = distance_between_two_segments(tri_0[i], edges_0[i], tri_1[j], edges_1[j]) + sol_0, sol_1 = distance_between_two_segments(tri_0[i], edges_0[i], + tri_1[j], edges_1[j]) # Computing the distance between the two solutions. delta_sol = sol_1 - sol_0 dist: float = numpy.dot(delta_sol, delta_sol) @@ -168,12 +180,14 @@ def distance_between_two_triangles(tri_0: numpy.ndarray, are_disjoint = True # No edge pair contained the closest points. # Checking the node/face situation. 
- distance, sol_0, sol_1, are_disjoint_tmp = __compute_nodes_to_triangle_distance(tri_0, edges_0, tri_1) + distance, sol_0, sol_1, are_disjoint_tmp = __compute_nodes_to_triangle_distance( + tri_0, edges_0, tri_1) if distance: return distance, sol_0, sol_1 are_disjoint = are_disjoint or are_disjoint_tmp - distance, sol_0, sol_1, are_disjoint_tmp = __compute_nodes_to_triangle_distance(tri_1, edges_1, tri_0) + distance, sol_0, sol_1, are_disjoint_tmp = __compute_nodes_to_triangle_distance( + tri_1, edges_1, tri_0) if distance: return distance, sol_0, sol_1 are_disjoint = are_disjoint or are_disjoint_tmp diff --git a/geosx_mesh_doctor/checks/vtk_polyhedron.py b/geosx_mesh_doctor/checks/vtk_polyhedron.py index e246a57..b42c48e 100644 --- a/geosx_mesh_doctor/checks/vtk_polyhedron.py +++ b/geosx_mesh_doctor/checks/vtk_polyhedron.py @@ -11,14 +11,12 @@ ) from vtkmodules.vtkCommonCore import ( - vtkIdList, -) + vtkIdList, ) import networkx from .vtk_utils import ( - vtk_iter, -) + vtk_iter, ) @dataclass(frozen=True) @@ -60,6 +58,7 @@ class FaceStream: """ Helper class to manipulate the vtk face streams. """ + def __init__(self, data: Sequence[Sequence[int]]): # self.__data contains the list of faces nodes, like it appears in vtk face streams. # Except that the additional size information is removed @@ -126,7 +125,9 @@ def flip_faces(self, face_indices): """ result = [] for face_index, face_nodes in enumerate(self.__data): - result.append(tuple(reversed(face_nodes)) if face_index in face_indices else face_nodes) + result.append( + tuple(reversed(face_nodes)) if face_index in + face_indices else face_nodes) return FaceStream(tuple(result)) def dump(self) -> Sequence[int]: @@ -149,7 +150,9 @@ def __repr__(self): return ",\n".join(result) -def build_face_to_face_connectivity_through_edges(face_stream: FaceStream, add_compatibility=False) -> networkx.Graph: +def build_face_to_face_connectivity_through_edges(face_stream: FaceStream, + add_compatibility=False + ) -> networkx.Graph: """ Given a face stream/polyhedron, builds the connections between the faces. Those connections happen when two faces share an edge. @@ -165,7 +168,8 @@ def build_face_to_face_connectivity_through_edges(face_stream: FaceStream, add_c for face_index, face_nodes in enumerate(face_stream.face_nodes): # Each edge is defined by two nodes. We do a small trick to loop on consecutive points. face_indices: Tuple[int, int] - for face_indices in zip(face_nodes, face_nodes[1:] + (face_nodes[0], )): + for face_indices in zip(face_nodes, + face_nodes[1:] + (face_nodes[0], )): edges_to_face_indices[frozenset(face_indices)].append(face_index) # We are doing here some small validations w.r.t. the connections of the faces # which may only make sense in the context of numerical simulations. 
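
Aside (not part of the patch): the hunk below only rewraps the edge-ordering test of build_face_to_face_connectivity_through_edges. As a standalone sketch, the test reduces to checking whether node_1 immediately follows node_0 in a face's cyclic node ordering:

# Illustrative only: +1 if the face walks the shared edge as (node_0, node_1),
# -1 if it walks it as (node_1, node_0). Two neighbouring faces whose orders
# multiply to +1 wind the edge the same way, i.e. their normals are not
# consistently oriented.
def edge_order(face_nodes: tuple, node_0: int, node_1: int) -> int:
    cyclic = face_nodes + (face_nodes[0], )
    return 1 if cyclic[cyclic.index(node_0) + 1] == node_1 else -1

assert edge_order((0, 1, 2, 3), 0, 1) == 1
assert edge_order((3, 2, 1, 0), 0, 1) == -1
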
@@ -194,11 +198,15 @@ def build_face_to_face_connectivity_through_edges(face_stream: FaceStream, add_c graph.add_nodes_from(range(face_stream.num_faces)) for edge, face_indices in edges_to_face_indices.items(): face_index_0, face_index_1 = face_indices - face_nodes_0 = face_stream[face_index_0] + (face_stream[face_index_0][0],) - face_nodes_1 = face_stream[face_index_1] + (face_stream[face_index_1][0],) + face_nodes_0 = face_stream[face_index_0] + ( + face_stream[face_index_0][0], ) + face_nodes_1 = face_stream[face_index_1] + ( + face_stream[face_index_1][0], ) node_0, node_1 = edge - order_0 = 1 if face_nodes_0[face_nodes_0.index(node_0) + 1] == node_1 else -1 - order_1 = 1 if face_nodes_1[face_nodes_1.index(node_0) + 1] == node_1 else -1 + order_0 = 1 if face_nodes_0[face_nodes_0.index(node_0) + + 1] == node_1 else -1 + order_1 = 1 if face_nodes_1[face_nodes_1.index(node_0) + + 1] == node_1 else -1 # Same order of nodes means that the normals of the faces # are _not_ both in the same "direction" (inward or outward). if order_0 * order_1 == 1: diff --git a/geosx_mesh_doctor/checks/vtk_utils.py b/geosx_mesh_doctor/checks/vtk_utils.py index 2604609..dc202de 100644 --- a/geosx_mesh_doctor/checks/vtk_utils.py +++ b/geosx_mesh_doctor/checks/vtk_utils.py @@ -9,11 +9,9 @@ ) from vtkmodules.vtkCommonCore import ( - vtkIdList, -) + vtkIdList, ) from vtkmodules.vtkCommonDataModel import ( - vtkUnstructuredGrid, -) + vtkUnstructuredGrid, ) from vtkmodules.vtkIOLegacy import ( vtkUnstructuredGridWriter, vtkUnstructuredGridReader, @@ -54,10 +52,14 @@ def vtk_iter(l) -> Iterator[Any]: def __read_vtk(vtk_input_file: str) -> Optional[vtkUnstructuredGrid]: reader = vtkUnstructuredGridReader() - logging.info(f"Testing file format \"{vtk_input_file}\" using legacy format reader...") + logging.info( + f"Testing file format \"{vtk_input_file}\" using legacy format reader..." + ) reader.SetFileName(vtk_input_file) if reader.IsFileUnstructuredGrid(): - logging.info(f"Reader matches. Reading file \"{vtk_input_file}\" using legacy format reader.") + logging.info( + f"Reader matches. Reading file \"{vtk_input_file}\" using legacy format reader." + ) reader.Update() return reader.GetOutput() else: @@ -67,10 +69,13 @@ def __read_vtk(vtk_input_file: str) -> Optional[vtkUnstructuredGrid]: def __read_vtu(vtk_input_file: str) -> Optional[vtkUnstructuredGrid]: reader = vtkXMLUnstructuredGridReader() - logging.info(f"Testing file format \"{vtk_input_file}\" using XML format reader...") + logging.info( + f"Testing file format \"{vtk_input_file}\" using XML format reader...") if reader.CanReadFile(vtk_input_file): reader.SetFileName(vtk_input_file) - logging.info(f"Reader matches. Reading file \"{vtk_input_file}\" using XML format reader.") + logging.info( + f"Reader matches. Reading file \"{vtk_input_file}\" using XML format reader." + ) reader.Update() return reader.GetOutput() else: @@ -86,8 +91,7 @@ def read_mesh(vtk_input_file: str) -> vtkUnstructuredGrid: :return: A unstructured grid. """ file_extension = os.path.splitext(vtk_input_file)[-1] - extension_to_reader = {".vtk": __read_vtk, - ".vtu": __read_vtu} + extension_to_reader = {".vtk": __read_vtk, ".vtu": __read_vtu} # Testing first the reader that should match if file_extension in extension_to_reader: output_mesh = extension_to_reader.pop(file_extension)(vtk_input_file) @@ -99,7 +103,9 @@ def read_mesh(vtk_input_file: str) -> vtkUnstructuredGrid: if output_mesh: return output_mesh # No reader did work. Dying. 
- logging.critical(f"Could not find the appropriate VTK reader for file \"{vtk_input_file}\". Dying...") + logging.critical( + f"Could not find the appropriate VTK reader for file \"{vtk_input_file}\". Dying..." + ) sys.exit(1) @@ -111,12 +117,14 @@ def __write_vtk(mesh: vtkUnstructuredGrid, output: str) -> int: return writer.Write() -def __write_vtu(mesh: vtkUnstructuredGrid, output: str, is_data_mode_binary: bool) -> int: +def __write_vtu(mesh: vtkUnstructuredGrid, output: str, + is_data_mode_binary: bool) -> int: logging.info(f"Writing mesh into file \"{output}\" using XML format.") writer = vtkXMLUnstructuredGridWriter() writer.SetFileName(output) writer.SetInputData(mesh) - writer.SetDataModeToBinary() if is_data_mode_binary else writer.SetDataModeToAscii() + writer.SetDataModeToBinary( + ) if is_data_mode_binary else writer.SetDataModeToAscii() return writer.Write() @@ -129,15 +137,19 @@ def write_mesh(mesh: vtkUnstructuredGrid, vtk_output: VtkOutput) -> int: :return: 0 in case of success. """ if os.path.exists(vtk_output.output): - logging.error(f"File \"{vtk_output.output}\" already exists, nothing done.") + logging.error( + f"File \"{vtk_output.output}\" already exists, nothing done.") return 1 file_extension = os.path.splitext(vtk_output.output)[-1] if file_extension == ".vtk": success_code = __write_vtk(mesh, vtk_output.output) elif file_extension == ".vtu": - success_code = __write_vtu(mesh, vtk_output.output, vtk_output.is_data_mode_binary) + success_code = __write_vtu(mesh, vtk_output.output, + vtk_output.is_data_mode_binary) else: # No writer found did work. Dying. - logging.critical(f"Could not find the appropriate VTK writer for extension \"{file_extension}\". Dying...") + logging.critical( + f"Could not find the appropriate VTK writer for extension \"{file_extension}\". Dying..." + ) sys.exit(1) return 0 if success_code else 2 # the Write member function return 1 in case of success, 0 otherwise. diff --git a/geosx_mesh_doctor/mesh_doctor.py b/geosx_mesh_doctor/mesh_doctor.py index f28cc7e..4e2de8e 100644 --- a/geosx_mesh_doctor/mesh_doctor.py +++ b/geosx_mesh_doctor/mesh_doctor.py @@ -4,7 +4,9 @@ min_python_version = (3, 7) assert sys.version_info >= min_python_version except AssertionError as e: - print(f"Please update python to at least version {'.'.join(map(str, min_python_version))}.") + print( + f"Please update python to at least version {'.'.join(map(str, min_python_version))}." + ) sys.exit(1) import logging diff --git a/geosx_mesh_doctor/parsing/__init__.py b/geosx_mesh_doctor/parsing/__init__.py index 0d06f73..7674469 100644 --- a/geosx_mesh_doctor/parsing/__init__.py +++ b/geosx_mesh_doctor/parsing/__init__.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from typing import Callable, Any - COLLOCATES_NODES = "collocated_nodes" ELEMENT_VOLUMES = "element_volumes" FIX_ELEMENTS_ORDERINGS = "fix_elements_orderings" diff --git a/geosx_mesh_doctor/parsing/cli_parsing.py b/geosx_mesh_doctor/parsing/cli_parsing.py index a2eb20e..69c9e51 100644 --- a/geosx_mesh_doctor/parsing/cli_parsing.py +++ b/geosx_mesh_doctor/parsing/cli_parsing.py @@ -47,23 +47,28 @@ def init_parser() -> argparse.ArgumentParser: An option may be missing because of an unloaded module. Increase verbosity (-{__VERBOSITY_FLAG}, -{__VERBOSITY_FLAG * 2}) to get full information. 
""" - formatter = lambda prog: argparse.RawTextHelpFormatter(prog, max_help_position=8) + formatter = lambda prog: argparse.RawTextHelpFormatter(prog, + max_help_position=8) parser = argparse.ArgumentParser(description='Inspects meshes for GEOSX.', epilog=textwrap.dedent(epilog_msg), formatter_class=formatter) # Nothing will be done with this verbosity/quiet input. # It's only here for the `--help` message. # `parse_verbosity` does the real parsing instead. - parser.add_argument('-' + __VERBOSITY_FLAG, - action='count', - default=2, - dest=__VERBOSE_KEY, - help=f"Use -{__VERBOSITY_FLAG} 'INFO', -{__VERBOSITY_FLAG * 2} for 'DEBUG'. Defaults to 'WARNING'.") - parser.add_argument('-' + __QUIET_FLAG, - action='count', - default=0, - dest=__QUIET_KEY, - help=f"Use -{__QUIET_FLAG} to reduce the verbosity of the output.") + parser.add_argument( + '-' + __VERBOSITY_FLAG, + action='count', + default=2, + dest=__VERBOSE_KEY, + help= + f"Use -{__VERBOSITY_FLAG} 'INFO', -{__VERBOSITY_FLAG * 2} for 'DEBUG'. Defaults to 'WARNING'." + ) + parser.add_argument( + '-' + __QUIET_FLAG, + action='count', + default=0, + dest=__QUIET_KEY, + help=f"Use -{__QUIET_FLAG} to reduce the verbosity of the output.") parser.add_argument('-i', '--vtk-input-file', metavar='VTK_MESH_FILE', diff --git a/geosx_mesh_doctor/parsing/collocated_nodes_parsing.py b/geosx_mesh_doctor/parsing/collocated_nodes_parsing.py index 421ae95..7721b9b 100644 --- a/geosx_mesh_doctor/parsing/collocated_nodes_parsing.py +++ b/geosx_mesh_doctor/parsing/collocated_nodes_parsing.py @@ -19,10 +19,13 @@ def convert(parsed_options) -> Options: def fill_subparser(subparsers) -> None: p = subparsers.add_parser(COLLOCATES_NODES, help="Checks if nodes are collocated.") - p.add_argument('--' + __TOLERANCE, - type=float, - required=True, - help="[float]: The absolute distance between two nodes for them to be considered collocated.") + p.add_argument( + '--' + __TOLERANCE, + type=float, + required=True, + help= + "[float]: The absolute distance between two nodes for them to be considered collocated." + ) def display_results(options: Options, result: Result): @@ -30,9 +33,12 @@ def display_results(options: Options, result: Result): for bucket in result.nodes_buckets: for node in bucket: all_collocated_nodes.append(node) - all_collocated_nodes: FrozenSet[int] = frozenset(all_collocated_nodes) # Surely useless + all_collocated_nodes: FrozenSet[int] = frozenset( + all_collocated_nodes) # Surely useless if all_collocated_nodes: - logging.error(f"You have {len(all_collocated_nodes)} collocated nodes (tolerance = {options.tolerance}).") + logging.error( + f"You have {len(all_collocated_nodes)} collocated nodes (tolerance = {options.tolerance})." 
+ ) logging.info("Here are all the buckets of collocated nodes.") tmp: List[str] = [] @@ -40,10 +46,13 @@ def display_results(options: Options, result: Result): tmp.append(f"({', '.join(map(str, bucket))})") logging.info(f"({', '.join(tmp)})") else: - logging.error(f"You have no collocated node (tolerance = {options.tolerance}).") + logging.error( + f"You have no collocated node (tolerance = {options.tolerance}).") if result.wrong_support_elements: tmp: str = ", ".join(map(str, result.wrong_support_elements)) - logging.error(f"You have {len(result.wrong_support_elements)} elements with duplicated support nodes.\n" + tmp) + logging.error( + f"You have {len(result.wrong_support_elements)} elements with duplicated support nodes.\n" + + tmp) else: logging.error("You have no element with duplicated support nodes.") diff --git a/geosx_mesh_doctor/parsing/element_volumes_parsing.py b/geosx_mesh_doctor/parsing/element_volumes_parsing.py index 3b19682..d5c4212 100644 --- a/geosx_mesh_doctor/parsing/element_volumes_parsing.py +++ b/geosx_mesh_doctor/parsing/element_volumes_parsing.py @@ -9,14 +9,20 @@ def fill_subparser(subparsers) -> None: - p = subparsers.add_parser(ELEMENT_VOLUMES, - help=f"Checks if the volumes of the elements are greater than \"{__MIN_VOLUME}\".") - p.add_argument('--' + __MIN_VOLUME, - type=float, - metavar=__MIN_VOLUME_DEFAULT, - default=__MIN_VOLUME_DEFAULT, - required=True, - help=f"[float]: The minimum acceptable volume. Defaults to {__MIN_VOLUME_DEFAULT}.") + p = subparsers.add_parser( + ELEMENT_VOLUMES, + help= + f"Checks if the volumes of the elements are greater than \"{__MIN_VOLUME}\"." + ) + p.add_argument( + '--' + __MIN_VOLUME, + type=float, + metavar=__MIN_VOLUME_DEFAULT, + default=__MIN_VOLUME_DEFAULT, + required=True, + help= + f"[float]: The minimum acceptable volume. Defaults to {__MIN_VOLUME_DEFAULT}." + ) def convert(parsed_options) -> Options: @@ -29,6 +35,9 @@ def convert(parsed_options) -> Options: def display_results(options: Options, result: Result): - logging.error(f"You have {len(result.element_volumes)} elements with volumes smaller than {options.min_volume}.") + logging.error( + f"You have {len(result.element_volumes)} elements with volumes smaller than {options.min_volume}." + ) if result.element_volumes: - logging.error("The elements indices and their volumes are:\n" + "\n".join(map(str, result.element_volumes))) + logging.error("The elements indices and their volumes are:\n" + + "\n".join(map(str, result.element_volumes))) diff --git a/geosx_mesh_doctor/parsing/fix_elements_orderings_parsing.py b/geosx_mesh_doctor/parsing/fix_elements_orderings_parsing.py index c105792..c264a35 100644 --- a/geosx_mesh_doctor/parsing/fix_elements_orderings_parsing.py +++ b/geosx_mesh_doctor/parsing/fix_elements_orderings_parsing.py @@ -15,7 +15,6 @@ from . 
import vtk_output_parsing, FIX_ELEMENTS_ORDERINGS - __CELL_TYPE_MAPPING = { "Hexahedron": VTK_HEXAHEDRON, "Prism5": VTK_PENTAGONAL_PRISM, @@ -38,17 +37,19 @@ def fill_subparser(subparsers) -> None: - p = subparsers.add_parser(FIX_ELEMENTS_ORDERINGS, - help="Reorders the support nodes for the given cell types.") + p = subparsers.add_parser( + FIX_ELEMENTS_ORDERINGS, + help="Reorders the support nodes for the given cell types.") for key, vtk_key in __CELL_TYPE_MAPPING.items(): tmp = list(range(__CELL_TYPE_SUPPORT_SIZE[vtk_key])) random.Random(4).shuffle(tmp) - p.add_argument('--' + key, - type=str, - metavar=",".join(map(str, tmp)), - default=None, - required=False, - help=f"[list of integers]: node permutation for \"{key}\".") + p.add_argument( + '--' + key, + type=str, + metavar=",".join(map(str, tmp)), + default=None, + required=False, + help=f"[list of integers]: node permutation for \"{key}\".") vtk_output_parsing.fill_vtk_output_subparser(p) @@ -77,7 +78,9 @@ def display_results(options: Options, result: Result): if result.output: logging.info(f"New mesh was written to file '{result.output}'") if result.unchanged_cell_types: - logging.info(f"Those vtk types were not reordered: [{', '.join(map(str, result.unchanged_cell_types))}].") + logging.info( + f"Those vtk types were not reordered: [{', '.join(map(str, result.unchanged_cell_types))}]." + ) else: logging.info("All the cells of the mesh were reordered.") else: diff --git a/geosx_mesh_doctor/parsing/generate_cube_parsing.py b/geosx_mesh_doctor/parsing/generate_cube_parsing.py index 41c0e04..e14bdbb 100644 --- a/geosx_mesh_doctor/parsing/generate_cube_parsing.py +++ b/geosx_mesh_doctor/parsing/generate_cube_parsing.py @@ -5,15 +5,18 @@ from . import vtk_output_parsing, generate_global_ids_parsing, GENERATE_CUBE from .generate_global_ids_parsing import GlobalIdsInfo - __X, __Y, __Z, __NX, __NY, __NZ = "x", "y", "z", "nx", "ny", "nz" __FIELDS = "fields" def convert(parsed_options) -> Options: + def check_discretizations(x, nx, title): if len(x) != len(nx) + 1: - raise ValueError(f"{title} information (\"{x}\" and \"{nx}\") does not have consistent size.") + raise ValueError( + f"{title} information (\"{x}\" and \"{nx}\") does not have consistent size." + ) + check_discretizations(parsed_options[__X], parsed_options[__NX], __X) check_discretizations(parsed_options[__Y], parsed_options[__NY], __Y) check_discretizations(parsed_options[__Z], parsed_options[__NZ], __Z) @@ -21,17 +24,21 @@ def check_discretizations(x, nx, title): def parse_fields(s): name, support, dim = s.split(":") if support not in ("CELLS", "POINTS"): - raise ValueError(f"Support {support} for field \"{name}\" must be one of \"CELLS\" or \"POINTS\".") + raise ValueError( + f"Support {support} for field \"{name}\" must be one of \"CELLS\" or \"POINTS\"." 
+ ) try: dim = int(dim) assert dim > 0 except ValueError: - raise ValueError(f"Dimension {dim} cannot be converted to an integer.") + raise ValueError( + f"Dimension {dim} cannot be converted to an integer.") except AssertionError: raise ValueError(f"Dimension {dim} must be a positive integer") return FieldInfo(name=name, support=support, dimension=dim) - gids: GlobalIdsInfo = generate_global_ids_parsing.convert_global_ids(parsed_options) + gids: GlobalIdsInfo = generate_global_ids_parsing.convert_global_ids( + parsed_options) return Options(vtk_output=vtk_output_parsing.convert(parsed_options), generate_cells_global_ids=gids.cells, @@ -60,25 +67,31 @@ def fill_subparser(subparsers) -> None: type=lambda s: tuple(map(float, s.split(":"))), metavar="0:1", help="[list of floats]: Z coordinates of the points.") - p.add_argument('--' + __NX, - type=lambda s: tuple(map(int, s.split(":"))), - metavar="2:2", - help="[list of integers]: Number of elements in the X direction.") - p.add_argument('--' + __NY, - type=lambda s: tuple(map(int, s.split(":"))), - metavar="1:1", - help="[list of integers]: Number of elements in the Y direction.") - p.add_argument('--' + __NZ, - type=lambda s: tuple(map(int, s.split(":"))), - metavar="4", - help="[list of integers]: Number of elements in the Z direction.") - p.add_argument('--' + __FIELDS, - type=str, - metavar="name:support:dim", - nargs="+", - required=False, - default=(), - help="Create fields on CELLS or POINTS, with given dimension (typically 1 or 3).") + p.add_argument( + '--' + __NX, + type=lambda s: tuple(map(int, s.split(":"))), + metavar="2:2", + help="[list of integers]: Number of elements in the X direction.") + p.add_argument( + '--' + __NY, + type=lambda s: tuple(map(int, s.split(":"))), + metavar="1:1", + help="[list of integers]: Number of elements in the Y direction.") + p.add_argument( + '--' + __NZ, + type=lambda s: tuple(map(int, s.split(":"))), + metavar="4", + help="[list of integers]: Number of elements in the Z direction.") + p.add_argument( + '--' + __FIELDS, + type=str, + metavar="name:support:dim", + nargs="+", + required=False, + default=(), + help= + "Create fields on CELLS or POINTS, with given dimension (typically 1 or 3)." + ) generate_global_ids_parsing.fill_generate_global_ids_subparser(p) vtk_output_parsing.fill_vtk_output_subparser(p) diff --git a/geosx_mesh_doctor/parsing/generate_fractures_parsing.py b/geosx_mesh_doctor/parsing/generate_fractures_parsing.py index 4789793..b3ef4a3 100644 --- a/geosx_mesh_doctor/parsing/generate_fractures_parsing.py +++ b/geosx_mesh_doctor/parsing/generate_fractures_parsing.py @@ -7,7 +7,7 @@ __POLICY = "policy" __FIELD_POLICY = "field" __INTERNAL_SURFACES_POLICY = "internal_surfaces" -__POLICIES = (__FIELD_POLICY, __INTERNAL_SURFACES_POLICY ) +__POLICIES = (__FIELD_POLICY, __INTERNAL_SURFACES_POLICY) __FIELD_NAME = "name" __FIELD_VALUES = "values" @@ -26,25 +26,37 @@ def convert_to_fracture_policy(s: str) -> FracturePolicy: return FracturePolicy.FIELD elif s == __INTERNAL_SURFACES_POLICY: return FracturePolicy.INTERNAL_SURFACES - raise ValueError(f"Policy {s} is not valid. Please use one of \"{', '.join(map(str, __POLICIES))}\".") + raise ValueError( + f"Policy {s} is not valid. Please use one of \"{', '.join(map(str, __POLICIES))}\"." + ) def fill_subparser(subparsers) -> None: - p = subparsers.add_parser(GENERATE_FRACTURES, - help="Splits the mesh to generate the faults and fractures. 
[EXPERIMENTAL]") - p.add_argument('--' + __POLICY, - type=convert_to_fracture_policy, - metavar=", ".join(__POLICIES), - required=True, - help=f"[string]: The criterion to define the surfaces that will be changed into fracture zones. " - f"Possible values are \"{', '.join(__POLICIES)}\"") - p.add_argument('--' + __FIELD_NAME, - type=str, - help=f"[string]: If the \"{__FIELD_POLICY}\" {__POLICY} is selected, defines which field will be considered to define the fractures. " - f"If the \"{__INTERNAL_SURFACES_POLICY}\" {__POLICY} is selected, defines the name of the attribute will be considered to identify the fractures. ") - p.add_argument('--' + __FIELD_VALUES, - type=str, - help=f"[list of comma separated integers]: If the \"{__FIELD_POLICY}\" {__POLICY} is selected, which changes of the field will be considered as a fracture. If the \"{__INTERNAL_SURFACES_POLICY}\" {__POLICY} is selected, list of the fracture attributes.") + p = subparsers.add_parser( + GENERATE_FRACTURES, + help= + "Splits the mesh to generate the faults and fractures. [EXPERIMENTAL]") + p.add_argument( + '--' + __POLICY, + type=convert_to_fracture_policy, + metavar=", ".join(__POLICIES), + required=True, + help= + f"[string]: The criterion to define the surfaces that will be changed into fracture zones. " + f"Possible values are \"{', '.join(__POLICIES)}\"") + p.add_argument( + '--' + __FIELD_NAME, + type=str, + help= + f"[string]: If the \"{__FIELD_POLICY}\" {__POLICY} is selected, defines which field will be considered to define the fractures. " + f"If the \"{__INTERNAL_SURFACES_POLICY}\" {__POLICY} is selected, defines the name of the attribute will be considered to identify the fractures. " + ) + p.add_argument( + '--' + __FIELD_VALUES, + type=str, + help= + f"[list of comma separated integers]: If the \"{__FIELD_POLICY}\" {__POLICY} is selected, which changes of the field will be considered as a fracture. If the \"{__INTERNAL_SURFACES_POLICY}\" {__POLICY} is selected, list of the fracture attributes." + ) vtk_output_parsing.fill_vtk_output_subparser(p) vtk_output_parsing.fill_vtk_output_subparser(p, prefix=__FRACTURE_PREFIX) @@ -52,9 +64,11 @@ def fill_subparser(subparsers) -> None: def convert(parsed_options) -> Options: policy = parsed_options[__POLICY] field = parsed_options[__FIELD_NAME] - field_values = frozenset(map(int, parsed_options[__FIELD_VALUES].split(","))) + field_values = frozenset( + map(int, parsed_options[__FIELD_VALUES].split(","))) vtk_output = vtk_output_parsing.convert(parsed_options) - vtk_fracture_output = vtk_output_parsing.convert(parsed_options, prefix=__FRACTURE_PREFIX) + vtk_fracture_output = vtk_output_parsing.convert(parsed_options, + prefix=__FRACTURE_PREFIX) return Options(policy=policy, field=field, field_values=field_values, diff --git a/geosx_mesh_doctor/parsing/generate_global_ids_parsing.py b/geosx_mesh_doctor/parsing/generate_global_ids_parsing.py index 730599a..68de1dc 100644 --- a/geosx_mesh_doctor/parsing/generate_global_ids_parsing.py +++ b/geosx_mesh_doctor/parsing/generate_global_ids_parsing.py @@ -5,7 +5,6 @@ from . import vtk_output_parsing, GENERATE_GLOBAL_IDS - __CELLS, __POINTS = "cells", "points" @@ -28,17 +27,19 @@ def convert(parsed_options) -> Options: def fill_generate_global_ids_subparser(p): - p.add_argument('--' + __CELLS, - action="store_true", - help=f"[bool]: Generate global ids for cells. Defaults to true.") + p.add_argument( + '--' + __CELLS, + action="store_true", + help=f"[bool]: Generate global ids for cells. 
Defaults to true.") p.add_argument('--no-' + __CELLS, action="store_false", dest=__CELLS, help=f"[bool]: Don't generate global ids for cells.") p.set_defaults(**{__CELLS: True}) - p.add_argument('--' + __POINTS, - action="store_true", - help=f"[bool]: Generate global ids for points. Defaults to true.") + p.add_argument( + '--' + __POINTS, + action="store_true", + help=f"[bool]: Generate global ids for points. Defaults to true.") p.add_argument('--no-' + __POINTS, action="store_false", dest=__POINTS, diff --git a/geosx_mesh_doctor/parsing/non_conformal_parsing.py b/geosx_mesh_doctor/parsing/non_conformal_parsing.py index 33625f6..2ebf0a6 100644 --- a/geosx_mesh_doctor/parsing/non_conformal_parsing.py +++ b/geosx_mesh_doctor/parsing/non_conformal_parsing.py @@ -1,8 +1,8 @@ import logging from typing import ( - FrozenSet, - List, + FrozenSet, + List, ) from checks.non_conformal import Options, Result @@ -25,19 +25,25 @@ def convert(parsed_options) -> Options: def fill_subparser(subparsers) -> None: - p = subparsers.add_parser(NON_CONFORMAL, - help="Detects non conformal elements. [EXPERIMENTAL]") - p.add_argument('--' + __ANGLE_TOLERANCE, - type=float, - metavar=__ANGLE_TOLERANCE_DEFAULT, - default=__ANGLE_TOLERANCE_DEFAULT, - help=f"[float]: angle tolerance in degrees. Defaults to {__ANGLE_TOLERANCE_DEFAULT}") - p.add_argument('--' + __POINT_TOLERANCE, - type=float, - help=f"[float]: tolerance for two points to be considered collocated.") - p.add_argument('--' + __FACE_TOLERANCE, - type=float, - help=f"[float]: tolerance for two faces to be considered \"touching\".") + p = subparsers.add_parser( + NON_CONFORMAL, help="Detects non conformal elements. [EXPERIMENTAL]") + p.add_argument( + '--' + __ANGLE_TOLERANCE, + type=float, + metavar=__ANGLE_TOLERANCE_DEFAULT, + default=__ANGLE_TOLERANCE_DEFAULT, + help= + f"[float]: angle tolerance in degrees. Defaults to {__ANGLE_TOLERANCE_DEFAULT}" + ) + p.add_argument( + '--' + __POINT_TOLERANCE, + type=float, + help=f"[float]: tolerance for two points to be considered collocated.") + p.add_argument( + '--' + __FACE_TOLERANCE, + type=float, + help=f"[float]: tolerance for two faces to be considered \"touching\"." + ) def display_results(options: Options, result: Result): @@ -45,4 +51,6 @@ def display_results(options: Options, result: Result): for i, j in result.non_conformal_cells: non_conformal_cells += i, j non_conformal_cells: FrozenSet[int] = frozenset(non_conformal_cells) - logging.error(f"You have {len(non_conformal_cells)} non conformal cells.\n{', '.join(map(str, sorted(non_conformal_cells)))}") + logging.error( + f"You have {len(non_conformal_cells)} non conformal cells.\n{', '.join(map(str, sorted(non_conformal_cells)))}" + ) diff --git a/geosx_mesh_doctor/parsing/self_intersecting_elements_parsing.py b/geosx_mesh_doctor/parsing/self_intersecting_elements_parsing.py index 70f5d6a..235dcdd 100644 --- a/geosx_mesh_doctor/parsing/self_intersecting_elements_parsing.py +++ b/geosx_mesh_doctor/parsing/self_intersecting_elements_parsing.py @@ -13,24 +13,35 @@ def convert(parsed_options) -> Options: tolerance = parsed_options[__TOLERANCE] if tolerance == 0: - logging.warning("Having tolerance set to 0 can induce lots of false positive results (adjacent faces may be considered intersecting).") + logging.warning( + "Having tolerance set to 0 can induce lots of false positive results (adjacent faces may be considered intersecting)." 
+ ) elif tolerance < 0: - raise ValueError(f"Negative tolerance ({tolerance}) in the {SELF_INTERSECTING_ELEMENTS} check is not allowed.") + raise ValueError( + f"Negative tolerance ({tolerance}) in the {SELF_INTERSECTING_ELEMENTS} check is not allowed." + ) return Options(tolerance=tolerance) def fill_subparser(subparsers) -> None: - p = subparsers.add_parser(SELF_INTERSECTING_ELEMENTS, - help="Checks if the faces of the elements are self intersecting.") - p.add_argument('--' + __TOLERANCE, - type=float, - required=False, - metavar=__TOLERANCE_DEFAULT, - default=__TOLERANCE_DEFAULT, - help=f"[float]: The tolerance in the computation. Defaults to your machine precision {__TOLERANCE_DEFAULT}.") + p = subparsers.add_parser( + SELF_INTERSECTING_ELEMENTS, + help="Checks if the faces of the elements are self intersecting.") + p.add_argument( + '--' + __TOLERANCE, + type=float, + required=False, + metavar=__TOLERANCE_DEFAULT, + default=__TOLERANCE_DEFAULT, + help= + f"[float]: The tolerance in the computation. Defaults to your machine precision {__TOLERANCE_DEFAULT}." + ) def display_results(options: Options, result: Result): - logging.error(f"You have {len(result.intersecting_faces_elements)} elements with self intersecting faces.") + logging.error( + f"You have {len(result.intersecting_faces_elements)} elements with self intersecting faces." + ) if result.intersecting_faces_elements: - logging.error("The elements indices are:\n" + ", ".join(map(str, result.intersecting_faces_elements))) + logging.error("The elements indices are:\n" + + ", ".join(map(str, result.intersecting_faces_elements))) diff --git a/geosx_mesh_doctor/parsing/supported_elements_parsing.py b/geosx_mesh_doctor/parsing/supported_elements_parsing.py index c68905b..3710ed8 100644 --- a/geosx_mesh_doctor/parsing/supported_elements_parsing.py +++ b/geosx_mesh_doctor/parsing/supported_elements_parsing.py @@ -8,7 +8,6 @@ __CHUNK_SIZE = "chunck_size" __NUM_PROC = "nproc" - __ALL_KEYWORDS = {__CHUNK_SIZE, __NUM_PROC} __CHUNK_SIZE_DEFAULT = 1 @@ -21,29 +20,45 @@ def convert(parsed_options) -> Options: def fill_subparser(subparsers) -> None: - p = subparsers.add_parser(SUPPORTED_ELEMENTS, - help="Check that all the elements of the mesh are supported by GEOSX.") - p.add_argument('--' + __CHUNK_SIZE, - type=int, - required=False, - metavar=__CHUNK_SIZE_DEFAULT, - default=__CHUNK_SIZE_DEFAULT, - help=f"[int]: Defaults chunk size for parallel processing to {__CHUNK_SIZE_DEFAULT}") - p.add_argument('--' + __NUM_PROC, - type=int, - required=False, - metavar=__NUM_PROC_DEFAULT, - default=__NUM_PROC_DEFAULT, - help=f"[int]: Number of threads used for parallel processing. Defaults to your CPU count {__NUM_PROC_DEFAULT}.") + p = subparsers.add_parser( + SUPPORTED_ELEMENTS, + help="Check that all the elements of the mesh are supported by GEOSX.") + p.add_argument( + '--' + __CHUNK_SIZE, + type=int, + required=False, + metavar=__CHUNK_SIZE_DEFAULT, + default=__CHUNK_SIZE_DEFAULT, + help= + f"[int]: Defaults chunk size for parallel processing to {__CHUNK_SIZE_DEFAULT}" + ) + p.add_argument( + '--' + __NUM_PROC, + type=int, + required=False, + metavar=__NUM_PROC_DEFAULT, + default=__NUM_PROC_DEFAULT, + help= + f"[int]: Number of threads used for parallel processing. Defaults to your CPU count {__NUM_PROC_DEFAULT}." 
+ ) def display_results(options: Options, result: Result): if result.unsupported_polyhedron_elements: - logging.error(f"There is/are {len(result.unsupported_polyhedron_elements)} polyhedra that may not be converted to supported elements.") - logging.error(f"The list of the unsupported polyhedra is\n{tuple(sorted(result.unsupported_polyhedron_elements))}.") + logging.error( + f"There is/are {len(result.unsupported_polyhedron_elements)} polyhedra that may not be converted to supported elements." + ) + logging.error( + f"The list of the unsupported polyhedra is\n{tuple(sorted(result.unsupported_polyhedron_elements))}." + ) else: - logging.info("All the polyhedra (if any) can be converted to supported elements.") + logging.info( + "All the polyhedra (if any) can be converted to supported elements." + ) if result.unsupported_std_elements_types: - logging.error(f"There are unsupported vtk standard element types. The list of those vtk types is {tuple(sorted(result.unsupported_std_elements_types))}.") + logging.error( + f"There are unsupported vtk standard element types. The list of those vtk types is {tuple(sorted(result.unsupported_std_elements_types))}." + ) else: - logging.info("All the standard vtk element types (if any) are supported.") \ No newline at end of file + logging.info( + "All the standard vtk element types (if any) are supported.") diff --git a/geosx_mesh_doctor/parsing/vtk_output_parsing.py b/geosx_mesh_doctor/parsing/vtk_output_parsing.py index 6e9b7d5..7dfbded 100644 --- a/geosx_mesh_doctor/parsing/vtk_output_parsing.py +++ b/geosx_mesh_doctor/parsing/vtk_output_parsing.py @@ -4,7 +4,6 @@ from checks.vtk_utils import VtkOutput - __OUTPUT_FILE = "output" __OUTPUT_BINARY_MODE = "data-mode" __OUTPUT_BINARY_MODE_VALUES = "binary", "ascii" @@ -27,19 +26,25 @@ def fill_vtk_output_subparser(parser, prefix="") -> None: type=str, required=True, help=f"[string]: The vtk output file destination.") - parser.add_argument('--' + __build_arg(prefix, __OUTPUT_BINARY_MODE), - type=str, - metavar=", ".join(__OUTPUT_BINARY_MODE_VALUES), - default=__OUTPUT_BINARY_MODE_DEFAULT, - help=f"""[string]: For ".vtu" output format, the data mode can be {" or ".join(__OUTPUT_BINARY_MODE_VALUES)}. Defaults to {__OUTPUT_BINARY_MODE_DEFAULT}.""") + parser.add_argument( + '--' + __build_arg(prefix, __OUTPUT_BINARY_MODE), + type=str, + metavar=", ".join(__OUTPUT_BINARY_MODE_VALUES), + default=__OUTPUT_BINARY_MODE_DEFAULT, + help= + f"""[string]: For ".vtu" output format, the data mode can be {" or ".join(__OUTPUT_BINARY_MODE_VALUES)}. 
Defaults to {__OUTPUT_BINARY_MODE_DEFAULT}.""" + ) def convert(parsed_options, prefix="") -> VtkOutput: output_key = __build_arg(prefix, __OUTPUT_FILE).replace("-", "_") - binary_mode_key = __build_arg(prefix, __OUTPUT_BINARY_MODE).replace("-", "_") + binary_mode_key = __build_arg(prefix, + __OUTPUT_BINARY_MODE).replace("-", "_") output = parsed_options[output_key] - if parsed_options[binary_mode_key] and os.path.splitext(output)[-1] == ".vtk": - logging.info("VTK data mode will be ignored for legacy file format \"vtk\".") - is_data_mode_binary: bool = parsed_options[binary_mode_key] == __OUTPUT_BINARY_MODE_DEFAULT - return VtkOutput(output=output, - is_data_mode_binary=is_data_mode_binary) + if parsed_options[binary_mode_key] and os.path.splitext( + output)[-1] == ".vtk": + logging.info( + "VTK data mode will be ignored for legacy file format \"vtk\".") + is_data_mode_binary: bool = parsed_options[ + binary_mode_key] == __OUTPUT_BINARY_MODE_DEFAULT + return VtkOutput(output=output, is_data_mode_binary=is_data_mode_binary) diff --git a/geosx_mesh_doctor/register.py b/geosx_mesh_doctor/register.py index a36001e..d00b4aa 100644 --- a/geosx_mesh_doctor/register.py +++ b/geosx_mesh_doctor/register.py @@ -6,7 +6,6 @@ import parsing from parsing import CheckHelper, cli_parsing - __HELPERS: Dict[str, Callable[[None], CheckHelper]] = dict() __CHECKS: Dict[str, Callable[[None], Any]] = dict() @@ -16,8 +15,10 @@ def __load_module_check(module_name: str, check_fct="check"): return getattr(module, check_fct) -def __load_module_check_helper(module_name: str, parsing_fct_suffix="_parsing"): - module = importlib.import_module("parsing." + module_name + parsing_fct_suffix) +def __load_module_check_helper(module_name: str, + parsing_fct_suffix="_parsing"): + module = importlib.import_module("parsing." + module_name + + parsing_fct_suffix) return CheckHelper(fill_subparser=module.fill_subparser, convert=module.convert, display_results=module.display_results) @@ -40,7 +41,8 @@ def __load_checks() -> Dict[str, Callable[[str, Any], Any]]: return loaded_checks -def register() -> Tuple[argparse.ArgumentParser, Dict[str, Callable[[str, Any], Any]], Dict[str, CheckHelper]]: +def register() -> Tuple[argparse.ArgumentParser, Dict[str, Callable[ + [str, Any], Any]], Dict[str, CheckHelper]]: """ Register all the parsing checks. Eventually initiate the registration of all the checks too. :return: The checks and the checks helpers. @@ -51,13 +53,11 @@ def register() -> Tuple[argparse.ArgumentParser, Dict[str, Callable[[str, Any], def closure_trick(cn: str): __HELPERS[check_name] = lambda: __load_module_check_helper(cn) __CHECKS[check_name] = lambda: __load_module_check(cn) + # Register the modules to load here. 
- for check_name in (parsing.COLLOCATES_NODES, - parsing.ELEMENT_VOLUMES, - parsing.FIX_ELEMENTS_ORDERINGS, - parsing.GENERATE_CUBE, - parsing.GENERATE_FRACTURES, - parsing.GENERATE_GLOBAL_IDS, + for check_name in (parsing.COLLOCATES_NODES, parsing.ELEMENT_VOLUMES, + parsing.FIX_ELEMENTS_ORDERINGS, parsing.GENERATE_CUBE, + parsing.GENERATE_FRACTURES, parsing.GENERATE_GLOBAL_IDS, parsing.NON_CONFORMAL, parsing.SELF_INTERSECTING_ELEMENTS, parsing.SUPPORTED_ELEMENTS): diff --git a/geosx_mesh_doctor/tests/test_cli_parsing.py b/geosx_mesh_doctor/tests/test_cli_parsing.py index 445b7c9..5bb92c9 100644 --- a/geosx_mesh_doctor/tests/test_cli_parsing.py +++ b/geosx_mesh_doctor/tests/test_cli_parsing.py @@ -9,8 +9,7 @@ import pytest from checks.vtk_utils import ( - VtkOutput, -) + VtkOutput, ) from checks.generate_fractures import ( FracturePolicy, @@ -37,13 +36,18 @@ def __generate_generate_fractures_parsing_test_data() -> Iterator[TestCase]: fracture_mesh: str = "fracture.vtu" cli_gen: str = f"generate_fractures --policy {{}} --name {field} --values 0,1 --output {main_mesh} --fracture-output {fracture_mesh}" - all_cli_args = cli_gen.format("field").split(), cli_gen.format("internal_surfaces").split(), cli_gen.format("dummy").split() + all_cli_args = cli_gen.format("field").split(), cli_gen.format( + "internal_surfaces").split(), cli_gen.format("dummy").split() policies = FracturePolicy.FIELD, FracturePolicy.INTERNAL_SURFACES, FracturePolicy.FIELD exceptions = False, False, True for cli_args, policy, exception in zip(all_cli_args, policies, exceptions): - options: Options = Options(policy=policy, field=field, field_values=frozenset((0, 1)), - vtk_output=VtkOutput(output=main_mesh, is_data_mode_binary=True), - vtk_fracture_output=VtkOutput(output=fracture_mesh, is_data_mode_binary=True)) + options: Options = Options( + policy=policy, + field=field, + field_values=frozenset((0, 1)), + vtk_output=VtkOutput(output=main_mesh, is_data_mode_binary=True), + vtk_fracture_output=VtkOutput(output=fracture_mesh, + is_data_mode_binary=True)) yield TestCase(cli_args, options, exception) @@ -63,7 +67,8 @@ def test_display_results(): display_results(None, None) -@pytest.mark.parametrize("test_case", __generate_generate_fractures_parsing_test_data()) +@pytest.mark.parametrize("test_case", + __generate_generate_fractures_parsing_test_data()) def test(test_case: TestCase): if test_case.exception: with pytest.raises(SystemExit): diff --git a/geosx_mesh_doctor/tests/test_collocated_nodes.py b/geosx_mesh_doctor/tests/test_collocated_nodes.py index 6936331..c29c40e 100644 --- a/geosx_mesh_doctor/tests/test_collocated_nodes.py +++ b/geosx_mesh_doctor/tests/test_collocated_nodes.py @@ -3,8 +3,7 @@ import pytest from vtkmodules.vtkCommonCore import ( - vtkPoints, -) + vtkPoints, ) from vtkmodules.vtkCommonDataModel import ( VTK_TETRA, vtkCellArray, diff --git a/geosx_mesh_doctor/tests/test_element_volumes.py b/geosx_mesh_doctor/tests/test_element_volumes.py index e37c22c..ef673f2 100644 --- a/geosx_mesh_doctor/tests/test_element_volumes.py +++ b/geosx_mesh_doctor/tests/test_element_volumes.py @@ -1,8 +1,7 @@ import numpy from vtkmodules.vtkCommonCore import ( - vtkPoints, -) + vtkPoints, ) from vtkmodules.vtkCommonDataModel import ( VTK_TETRA, vtkCellArray, @@ -41,7 +40,8 @@ def test_simple_tet(): assert len(result.element_volumes) == 1 assert result.element_volumes[0][0] == 0 - assert abs(result.element_volumes[0][1] - 1./6.) < 10 * numpy.finfo(float).eps + assert abs(result.element_volumes[0][1] - + 1. / 6.) 
< 10 * numpy.finfo(float).eps result = __check(mesh, Options(min_volume=0.)) diff --git a/geosx_mesh_doctor/tests/test_generate_cube.py b/geosx_mesh_doctor/tests/test_generate_cube.py index 4d93abd..cf0d9b0 100644 --- a/geosx_mesh_doctor/tests/test_generate_cube.py +++ b/geosx_mesh_doctor/tests/test_generate_cube.py @@ -2,20 +2,18 @@ def test_generate_cube(): - options = Options( - vtk_output=None, - generate_cells_global_ids=True, - generate_points_global_ids=False, - xs=(0, 5, 10), - ys=(0, 4, 8), - zs=(0, 1), - nxs=(5, 2), - nys=(1, 1), - nzs=(1,), - fields=( - FieldInfo(name="test", dimension=2, support="CELLS"), - ) - ) + options = Options(vtk_output=None, + generate_cells_global_ids=True, + generate_points_global_ids=False, + xs=(0, 5, 10), + ys=(0, 4, 8), + zs=(0, 1), + nxs=(5, 2), + nys=(1, 1), + nzs=(1, ), + fields=(FieldInfo(name="test", + dimension=2, + support="CELLS"), )) output = __build(options) assert output.GetNumberOfCells() == 14 assert output.GetNumberOfPoints() == 48 diff --git a/geosx_mesh_doctor/tests/test_generate_fractures.py b/geosx_mesh_doctor/tests/test_generate_fractures.py index f197731..fb19e06 100644 --- a/geosx_mesh_doctor/tests/test_generate_fractures.py +++ b/geosx_mesh_doctor/tests/test_generate_fractures.py @@ -18,12 +18,10 @@ VTK_QUAD, ) from vtkmodules.util.numpy_support import ( - numpy_to_vtk, -) + numpy_to_vtk, ) from checks.vtk_utils import ( - to_vtk_id_list, -) + to_vtk_id_list, ) from checks.check_fractures import format_collocated_nodes from checks.generate_cube import build_rectilinear_blocks_mesh, XYZ @@ -78,6 +76,7 @@ def __build_test_case(xs: Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray], # Utility class to generate the new indices of the newly created collocated nodes. class Incrementor: + def __init__(self, start): self.__val = start @@ -92,9 +91,12 @@ def __generate_test_data() -> Iterator[TestCase]: four_nodes = numpy.arange(4, dtype=float) # Split in 2 - mesh, options = __build_test_case((three_nodes, three_nodes, three_nodes), (0, 1, 0, 1, 0, 1, 0, 1)) - yield TestCase(input_mesh=mesh, options=options, - collocated_nodes=tuple(map(lambda i: (1 + 3 * i, 27 + i), range(9))), + mesh, options = __build_test_case((three_nodes, three_nodes, three_nodes), + (0, 1, 0, 1, 0, 1, 0, 1)) + yield TestCase(input_mesh=mesh, + options=options, + collocated_nodes=tuple( + map(lambda i: (1 + 3 * i, 27 + i), range(9))), result=TestResult(9 * 4, 8, 9, 4)) # Split in 3 @@ -113,8 +115,11 @@ def __generate_test_data() -> Iterator[TestCase]: (4 + 18, *inc.next(2)), (7 + 18, *inc.next(1)), ) - mesh, options = __build_test_case((three_nodes, three_nodes, three_nodes), (0, 1, 2, 1, 0, 1, 2, 1)) - yield TestCase(input_mesh=mesh, options=options, collocated_nodes=collocated_nodes, + mesh, options = __build_test_case((three_nodes, three_nodes, three_nodes), + (0, 1, 2, 1, 0, 1, 2, 1)) + yield TestCase(input_mesh=mesh, + options=options, + collocated_nodes=collocated_nodes, result=TestResult(9 * 4 + 6, 8, 12, 6)) # Split in 8 @@ -140,22 +145,29 @@ def __generate_test_data() -> Iterator[TestCase]: (5 + 18, *inc.next(1)), (7 + 18, *inc.next(1)), ) - mesh, options = __build_test_case((three_nodes, three_nodes, three_nodes), range(8)) - yield TestCase(input_mesh=mesh, options=options, collocated_nodes=collocated_nodes, + mesh, options = __build_test_case((three_nodes, three_nodes, three_nodes), + range(8)) + yield TestCase(input_mesh=mesh, + options=options, + collocated_nodes=collocated_nodes, result=TestResult(8 * 8, 8, 3 * 3 * 3 - 8, 12)) # Straight notch inc 
= Incrementor(27) collocated_nodes: Sequence[Sequence[int]] = ( (1, *inc.next(1)), - (4,), + (4, ), (1 + 9, *inc.next(1)), - (4 + 9,), + (4 + 9, ), (1 + 18, *inc.next(1)), - (4 + 18,), + (4 + 18, ), ) - mesh, options = __build_test_case((three_nodes, three_nodes, three_nodes), (0, 1, 2, 2, 0, 1, 2, 2), field_values=(0, 1)) - yield TestCase(input_mesh=mesh, options=options, collocated_nodes=collocated_nodes, + mesh, options = __build_test_case((three_nodes, three_nodes, three_nodes), + (0, 1, 2, 2, 0, 1, 2, 2), + field_values=(0, 1)) + yield TestCase(input_mesh=mesh, + options=options, + collocated_nodes=collocated_nodes, result=TestResult(3 * 3 * 3 + 3, 8, 6, 2)) # L-shaped notch @@ -165,13 +177,17 @@ def __generate_test_data() -> Iterator[TestCase]: (4, *inc.next(1)), (7, *inc.next(1)), (1 + 9, *inc.next(1)), - (4 + 9,), - (7 + 9,), + (4 + 9, ), + (7 + 9, ), (1 + 18, *inc.next(1)), - (4 + 18,), + (4 + 18, ), ) - mesh, options = __build_test_case((three_nodes, three_nodes, three_nodes), (0, 1, 0, 1, 0, 1, 2, 2), field_values=(0, 1)) - yield TestCase(input_mesh=mesh, options=options, collocated_nodes=collocated_nodes, + mesh, options = __build_test_case((three_nodes, three_nodes, three_nodes), + (0, 1, 0, 1, 0, 1, 2, 2), + field_values=(0, 1)) + yield TestCase(input_mesh=mesh, + options=options, + collocated_nodes=collocated_nodes, result=TestResult(3 * 3 * 3 + 5, 8, 8, 3)) # 3x1x1 split @@ -186,31 +202,46 @@ def __generate_test_data() -> Iterator[TestCase]: (5 + 8, *inc.next(1)), (6 + 8, *inc.next(1)), ) - mesh, options = __build_test_case((four_nodes, two_nodes, two_nodes), (0, 1, 2)) - yield TestCase(input_mesh=mesh, options=options, collocated_nodes=collocated_nodes, + mesh, options = __build_test_case((four_nodes, two_nodes, two_nodes), + (0, 1, 2)) + yield TestCase(input_mesh=mesh, + options=options, + collocated_nodes=collocated_nodes, result=TestResult(6 * 4, 3, 2 * 4, 2)) # Discarded fracture element if no node duplication. collocated_nodes: Sequence[Sequence[int]] = () - mesh, options = __build_test_case((three_nodes, four_nodes, four_nodes), [0, ] * 8 + [1, 2] + [0, ] * 8, field_values=(1, 2)) - yield TestCase(input_mesh=mesh, options=options, collocated_nodes=collocated_nodes, + mesh, options = __build_test_case((three_nodes, four_nodes, four_nodes), [ + 0, + ] * 8 + [1, 2] + [ + 0, + ] * 8, + field_values=(1, 2)) + yield TestCase(input_mesh=mesh, + options=options, + collocated_nodes=collocated_nodes, result=TestResult(3 * 4 * 4, 2 * 3 * 3, 0, 0)) # Fracture on a corner inc = Incrementor(3 * 4 * 4) collocated_nodes: Sequence[Sequence[int]] = ( - (1 + 12,), - (4 + 12,), - (7 + 12,), + (1 + 12, ), + (4 + 12, ), + (7 + 12, ), (1 + 12 * 2, *inc.next(1)), (4 + 12 * 2, *inc.next(1)), - (7 + 12 * 2,), + (7 + 12 * 2, ), (1 + 12 * 3, *inc.next(1)), (4 + 12 * 3, *inc.next(1)), - (7 + 12 * 3,), + (7 + 12 * 3, ), ) - mesh, options = __build_test_case((three_nodes, four_nodes, four_nodes), [0, ] * 6 + [1, 2, 1, 2, 0, 0, 1, 2, 1, 2, 0, 0], field_values=(1, 2)) - yield TestCase(input_mesh=mesh, options=options, collocated_nodes=collocated_nodes, + mesh, options = __build_test_case((three_nodes, four_nodes, four_nodes), [ + 0, + ] * 6 + [1, 2, 1, 2, 0, 0, 1, 2, 1, 2, 0, 0], + field_values=(1, 2)) + yield TestCase(input_mesh=mesh, + options=options, + collocated_nodes=collocated_nodes, result=TestResult(3 * 4 * 4 + 4, 2 * 3 * 3, 9, 4)) # Generate mesh with 2 hexs, one being a standard hex, the other a 42 hex. 
@@ -221,16 +252,23 @@ def __generate_test_data() -> Iterator[TestCase]: (1 + 6, *inc.next(1)), (1 + 9, *inc.next(1)), ) - mesh, options = __build_test_case((three_nodes, two_nodes, two_nodes), (0, 1)) + mesh, options = __build_test_case((three_nodes, two_nodes, two_nodes), + (0, 1)) polyhedron_mesh = vtkUnstructuredGrid() polyhedron_mesh.SetPoints(mesh.GetPoints()) polyhedron_mesh.Allocate(2) - polyhedron_mesh.InsertNextCell(VTK_HEXAHEDRON, to_vtk_id_list((1, 2, 5, 4, 7, 8, 10, 11))) - poly = to_vtk_id_list([6] + [4, 0, 1, 7, 6] + [4, 1, 4, 10, 7] + [4, 4, 3, 9, 10] + [4, 3, 0, 6, 9] + [4, 6, 7, 10, 9] + [4, 1, 0, 3, 4]) + polyhedron_mesh.InsertNextCell(VTK_HEXAHEDRON, + to_vtk_id_list((1, 2, 5, 4, 7, 8, 10, 11))) + poly = to_vtk_id_list([6] + [4, 0, 1, 7, 6] + [4, 1, 4, 10, 7] + + [4, 4, 3, 9, 10] + [4, 3, 0, 6, 9] + + [4, 6, 7, 10, 9] + [4, 1, 0, 3, 4]) polyhedron_mesh.InsertNextCell(VTK_POLYHEDRON, poly) - polyhedron_mesh.GetCellData().AddArray(mesh.GetCellData().GetArray("attribute")) + polyhedron_mesh.GetCellData().AddArray( + mesh.GetCellData().GetArray("attribute")) - yield TestCase(input_mesh=polyhedron_mesh, options=options, collocated_nodes=collocated_nodes, + yield TestCase(input_mesh=polyhedron_mesh, + options=options, + collocated_nodes=collocated_nodes, result=TestResult(4 * 4, 2, 4, 1)) # Split in 2 using the internal fracture description @@ -241,21 +279,29 @@ def __generate_test_data() -> Iterator[TestCase]: (1 + 6, *inc.next(1)), (1 + 9, *inc.next(1)), ) - mesh, options = __build_test_case((three_nodes, two_nodes, two_nodes), attribute=(0, 0, 0), field_values=(0,), + mesh, options = __build_test_case((three_nodes, two_nodes, two_nodes), + attribute=(0, 0, 0), + field_values=(0, ), policy=FracturePolicy.INTERNAL_SURFACES) - mesh.InsertNextCell(VTK_QUAD, to_vtk_id_list((1, 4, 7, 10))) # Add a fracture on the fly - yield TestCase(input_mesh=mesh, options=options, + mesh.InsertNextCell(VTK_QUAD, to_vtk_id_list( + (1, 4, 7, 10))) # Add a fracture on the fly + yield TestCase(input_mesh=mesh, + options=options, collocated_nodes=collocated_nodes, result=TestResult(4 * 4, 3, 4, 1)) @pytest.mark.parametrize("test_case", __generate_test_data()) def test_generate_fracture(test_case: TestCase): - main_mesh, fracture_mesh = __split_mesh_on_fracture(test_case.input_mesh, test_case.options) - assert main_mesh.GetNumberOfPoints() == test_case.result.main_mesh_num_points + main_mesh, fracture_mesh = __split_mesh_on_fracture( + test_case.input_mesh, test_case.options) + assert main_mesh.GetNumberOfPoints( + ) == test_case.result.main_mesh_num_points assert main_mesh.GetNumberOfCells() == test_case.result.main_mesh_num_cells - assert fracture_mesh.GetNumberOfPoints() == test_case.result.fracture_mesh_num_points - assert fracture_mesh.GetNumberOfCells() == test_case.result.fracture_mesh_num_cells + assert fracture_mesh.GetNumberOfPoints( + ) == test_case.result.fracture_mesh_num_points + assert fracture_mesh.GetNumberOfCells( + ) == test_case.result.fracture_mesh_num_cells res = format_collocated_nodes(fracture_mesh) assert res == test_case.collocated_nodes diff --git a/geosx_mesh_doctor/tests/test_non_conformal.py b/geosx_mesh_doctor/tests/test_non_conformal.py index bcf60fe..cf676c7 100644 --- a/geosx_mesh_doctor/tests/test_non_conformal.py +++ b/geosx_mesh_doctor/tests/test_non_conformal.py @@ -15,13 +15,17 @@ def test_two_close_hexs(): mesh = build_rectilinear_blocks_mesh((xyz0, xyz1)) # Close enough, but points tolerance is too strict to consider the faces matching. 
- options = Options(angle_tolerance=1., point_tolerance=delta / 2, face_tolerance=delta * 2) + options = Options(angle_tolerance=1., + point_tolerance=delta / 2, + face_tolerance=delta * 2) results = __check(mesh, options) assert len(results.non_conformal_cells) == 1 assert set(results.non_conformal_cells[0]) == {0, 1} # Close enough, and points tolerance is loose enough to consider the faces matching. - options = Options(angle_tolerance=1., point_tolerance=delta * 2, face_tolerance=delta * 2) + options = Options(angle_tolerance=1., + point_tolerance=delta * 2, + face_tolerance=delta * 2) results = __check(mesh, options) assert len(results.non_conformal_cells) == 0 @@ -33,7 +37,9 @@ def test_two_distant_hexs(): xyz1 = XYZ(tmp + 1 + delta, tmp, tmp) mesh = build_rectilinear_blocks_mesh((xyz0, xyz1)) - options = Options(angle_tolerance=1., point_tolerance=delta / 2., face_tolerance=delta / 2.) + options = Options(angle_tolerance=1., + point_tolerance=delta / 2., + face_tolerance=delta / 2.) results = __check(mesh, options) assert len(results.non_conformal_cells) == 0 @@ -46,7 +52,9 @@ def test_two_close_shifted_hexs(): xyz1 = XYZ(tmp + 1 + delta_x, tmp + delta_y, tmp + delta_y) mesh = build_rectilinear_blocks_mesh((xyz0, xyz1)) - options = Options(angle_tolerance=1., point_tolerance=delta_x * 2, face_tolerance=delta_x * 2) + options = Options(angle_tolerance=1., + point_tolerance=delta_x * 2, + face_tolerance=delta_x * 2) results = __check(mesh, options) assert len(results.non_conformal_cells) == 1 @@ -60,7 +68,9 @@ def test_big_elem_next_to_small_elem(): xyz1 = XYZ(3 * tmp + 1 + delta, 3 * tmp, 3 * tmp) mesh = build_rectilinear_blocks_mesh((xyz0, xyz1)) - options = Options(angle_tolerance=1., point_tolerance=delta * 2, face_tolerance=delta * 2) + options = Options(angle_tolerance=1., + point_tolerance=delta * 2, + face_tolerance=delta * 2) results = __check(mesh, options) assert len(results.non_conformal_cells) == 1 diff --git a/geosx_mesh_doctor/tests/test_reorient_mesh.py b/geosx_mesh_doctor/tests/test_reorient_mesh.py index 1136bbb..ad45cb9 100644 --- a/geosx_mesh_doctor/tests/test_reorient_mesh.py +++ b/geosx_mesh_doctor/tests/test_reorient_mesh.py @@ -40,7 +40,8 @@ def __build_test_meshes() -> Generator[Expected, None, None]: (3, 2, 0), (3, 3, 0), (0, 3, 0), - ), dtype=float) + ), + dtype=float) front_nodes = numpy.array(front_nodes, dtype=float) back_nodes = front_nodes - (0., 0., 1.) @@ -57,9 +58,7 @@ def __build_test_meshes() -> Generator[Expected, None, None]: faces = [] # Creating the side faces for i in range(n): - faces.append( - (i % n + n, (i + 1) % n + n, (i + 1) % n, i % n) - ) + faces.append((i % n + n, (i + 1) % n + n, (i + 1) % n, i % n)) # Creating the front faces faces.append(tuple(range(n))) faces.append(tuple(reversed(range(n, 2 * n)))) @@ -71,33 +70,31 @@ def __build_test_meshes() -> Generator[Expected, None, None]: mesh = vtkUnstructuredGrid() mesh.Allocate(1) mesh.SetPoints(points) - mesh.InsertNextCell(VTK_POLYHEDRON, to_vtk_id_list( - face_stream.dump() - )) + mesh.InsertNextCell(VTK_POLYHEDRON, to_vtk_id_list(face_stream.dump())) yield Expected(mesh=mesh, face_stream=face_stream) # Here, two faces are flipped. mesh = vtkUnstructuredGrid() mesh.Allocate(1) mesh.SetPoints(points) - mesh.InsertNextCell(VTK_POLYHEDRON, to_vtk_id_list( - face_stream.flip_faces((1, 2)).dump() - )) + mesh.InsertNextCell(VTK_POLYHEDRON, + to_vtk_id_list(face_stream.flip_faces((1, 2)).dump())) yield Expected(mesh=mesh, face_stream=face_stream) # Last, all faces are flipped. 
mesh = vtkUnstructuredGrid() mesh.Allocate(1) mesh.SetPoints(points) - mesh.InsertNextCell(VTK_POLYHEDRON, to_vtk_id_list( - face_stream.flip_faces(range(len(faces))).dump() - )) + mesh.InsertNextCell( + VTK_POLYHEDRON, + to_vtk_id_list(face_stream.flip_faces(range(len(faces))).dump())) yield Expected(mesh=mesh, face_stream=face_stream) @pytest.mark.parametrize("expected", __build_test_meshes()) def test_reorient_polyhedron(expected: Expected): - output_mesh = reorient_mesh(expected.mesh, range(expected.mesh.GetNumberOfCells())) + output_mesh = reorient_mesh(expected.mesh, + range(expected.mesh.GetNumberOfCells())) assert output_mesh.GetNumberOfCells() == 1 assert output_mesh.GetCell(0).GetCellType() == VTK_POLYHEDRON face_stream_ids = vtkIdList() diff --git a/geosx_mesh_doctor/tests/test_self_intersecting_elements.py b/geosx_mesh_doctor/tests/test_self_intersecting_elements.py index 8993e68..053d510 100644 --- a/geosx_mesh_doctor/tests/test_self_intersecting_elements.py +++ b/geosx_mesh_doctor/tests/test_self_intersecting_elements.py @@ -1,6 +1,5 @@ from vtkmodules.vtkCommonCore import ( - vtkPoints, -) + vtkPoints, ) from vtkmodules.vtkCommonDataModel import ( VTK_HEXAHEDRON, vtkCellArray, @@ -8,7 +7,6 @@ vtkUnstructuredGrid, ) - from checks.self_intersecting_elements import Options, __check diff --git a/geosx_mesh_doctor/tests/test_supported_elements.py b/geosx_mesh_doctor/tests/test_supported_elements.py index 639d904..8a80647 100644 --- a/geosx_mesh_doctor/tests/test_supported_elements.py +++ b/geosx_mesh_doctor/tests/test_supported_elements.py @@ -15,19 +15,20 @@ from checks.supported_elements import Options, check, __check from checks.vtk_polyhedron import parse_face_stream, build_face_to_face_connectivity_through_edges, FaceStream from checks.vtk_utils import ( - to_vtk_id_list, -) + to_vtk_id_list, ) -@pytest.mark.parametrize("base_name", - ("supportedElements.vtk", "supportedElementsAsVTKPolyhedra.vtk")) +@pytest.mark.parametrize( + "base_name", + ("supportedElements.vtk", "supportedElementsAsVTKPolyhedra.vtk")) def test_supported_elements(base_name) -> None: """ Testing that the supported elements are properly detected as supported! :param base_name: Supported elements are provided as standard elements or polyhedron elements. """ directory = os.path.dirname(os.path.realpath(__file__)) - supported_elements_file_name = os.path.join(directory, "../../../../unitTests/meshTests", base_name) + supported_elements_file_name = os.path.join( + directory, "../../../../unitTests/meshTests", base_name) options = Options(chunk_size=1, num_proc=4) result = check(supported_elements_file_name, options) assert not result.unsupported_std_elements_types @@ -40,42 +41,92 @@ def make_dodecahedron() -> Tuple[vtkPoints, vtkIdList]: This code was adapted from an official vtk example. :return: The tuple of points and faces (as vtk instances). 
""" - points = ( - (1.21412, 0, 1.58931), - (0.375185, 1.1547, 1.58931), - (-0.982247, 0.713644, 1.58931), - (-0.982247, -0.713644, 1.58931), - (0.375185, -1.1547, 1.58931), - (1.96449, 0, 0.375185), - (0.607062, 1.86835, 0.375185), - (-1.58931, 1.1547, 0.375185), - (-1.58931, -1.1547, 0.375185), - (0.607062, -1.86835, 0.375185), - (1.58931, 1.1547, -0.375185), - (-0.607062, 1.86835, -0.375185), - (-1.96449, 0, -0.375185), - (-0.607062, -1.86835, -0.375185), - (1.58931, -1.1547, -0.375185), - (0.982247, 0.713644, -1.58931), - (-0.375185, 1.1547, -1.58931), - (-1.21412, 0, -1.58931), - (-0.375185, -1.1547, -1.58931), - (0.982247, -0.713644, -1.58931) - ) - - faces = (12, # number of faces - 5, 0, 1, 2, 3, 4, # number of ids on face, ids - 5, 0, 5, 10, 6, 1, - 5, 1, 6, 11, 7, 2, - 5, 2, 7, 12, 8, 3, - 5, 3, 8, 13, 9, 4, - 5, 4, 9, 14, 5, 0, - 5, 15, 10, 5, 14, 19, - 5, 16, 11, 6, 10, 15, - 5, 17, 12, 7, 11, 16, - 5, 18, 13, 8, 12, 17, - 5, 19, 14, 9, 13, 18, - 5, 19, 18, 17, 16, 15) + points = ((1.21412, 0, 1.58931), (0.375185, 1.1547, + 1.58931), (-0.982247, 0.713644, 1.58931), + (-0.982247, -0.713644, 1.58931), (0.375185, -1.1547, 1.58931), + (1.96449, 0, 0.375185), (0.607062, 1.86835, + 0.375185), (-1.58931, 1.1547, 0.375185), + (-1.58931, -1.1547, 0.375185), (0.607062, -1.86835, 0.375185), + (1.58931, 1.1547, -0.375185), (-0.607062, 1.86835, -0.375185), + (-1.96449, 0, -0.375185), (-0.607062, -1.86835, -0.375185), + (1.58931, -1.1547, -0.375185), (0.982247, 0.713644, -1.58931), + (-0.375185, 1.1547, -1.58931), (-1.21412, 0, -1.58931), + (-0.375185, -1.1547, -1.58931), (0.982247, -0.713644, -1.58931)) + + faces = ( + 12, # number of faces + 5, + 0, + 1, + 2, + 3, + 4, # number of ids on face, ids + 5, + 0, + 5, + 10, + 6, + 1, + 5, + 1, + 6, + 11, + 7, + 2, + 5, + 2, + 7, + 12, + 8, + 3, + 5, + 3, + 8, + 13, + 9, + 4, + 5, + 4, + 9, + 14, + 5, + 0, + 5, + 15, + 10, + 5, + 14, + 19, + 5, + 16, + 11, + 6, + 10, + 15, + 5, + 17, + 12, + 7, + 11, + 16, + 5, + 18, + 13, + 8, + 12, + 17, + 5, + 19, + 14, + 9, + 13, + 18, + 5, + 19, + 18, + 17, + 16, + 15) p = vtkPoints() p.Allocate(len(points)) @@ -105,20 +156,10 @@ def test_dodecahedron() -> None: def test_parse_face_stream() -> None: _, faces = make_dodecahedron() result = parse_face_stream(faces) - expected = ( - (0, 1, 2, 3, 4), - (0, 5, 10, 6, 1), - (1, 6, 11, 7, 2), - (2, 7, 12, 8, 3), - (3, 8, 13, 9, 4), - (4, 9, 14, 5, 0), - (15, 10, 5, 14, 19), - (16, 11, 6, 10, 15), - (17, 12, 7, 11, 16), - (18, 13, 8, 12, 17), - (19, 14, 9, 13, 18), - (19, 18, 17, 16, 15) - ) + expected = ((0, 1, 2, 3, 4), (0, 5, 10, 6, 1), (1, 6, 11, 7, 2), + (2, 7, 12, 8, 3), (3, 8, 13, 9, 4), (4, 9, 14, 5, 0), + (15, 10, 5, 14, 19), (16, 11, 6, 10, 15), (17, 12, 7, 11, 16), + (18, 13, 8, 12, 17), (19, 14, 9, 13, 18), (19, 18, 17, 16, 15)) assert result == expected face_stream = FaceStream.build_from_vtk_id_list(faces) assert face_stream.num_faces == 12 diff --git a/geosx_mesh_doctor/tests/test_triangle_distance.py b/geosx_mesh_doctor/tests/test_triangle_distance.py index 605169b..7313ff3 100644 --- a/geosx_mesh_doctor/tests/test_triangle_distance.py +++ b/geosx_mesh_doctor/tests/test_triangle_distance.py @@ -18,14 +18,8 @@ class ExpectedSeg: @classmethod def from_tuples(cls, p0, u0, p1, u1, x, y): - return cls( - numpy.array(p0), - numpy.array(u0), - numpy.array(p1), - numpy.array(u1), - numpy.array(x), - numpy.array(y) - ) + return cls(numpy.array(p0), numpy.array(u0), numpy.array(p1), + numpy.array(u1), numpy.array(x), numpy.array(y)) def 
__get_segments_references(): @@ -64,7 +58,7 @@ def __get_segments_references(): u0=(1., 2., 1.), p1=(1., 0., 0.), u1=(1., 1., 0.), - x=(1./6., 2./6., 1./6.), + x=(1. / 6., 2. / 6., 1. / 6.), y=(1., 0., 0.), ) # Overlapping edges. @@ -90,7 +84,8 @@ def __get_segments_references(): @pytest.mark.parametrize("expected", __get_segments_references()) def test_segments(expected: ExpectedSeg): eps = numpy.finfo(float).eps - x, y = distance_between_two_segments(expected.p0, expected.u0, expected.p1, expected.u1) + x, y = distance_between_two_segments(expected.p0, expected.u0, expected.p1, + expected.u1) if norm(expected.x - expected.y) == 0: assert norm(x - y) == 0. else: @@ -108,64 +103,59 @@ class ExpectedTri: @classmethod def from_tuples(cls, t0, t1, d, p0, p1): - return cls( - numpy.array(t0), - numpy.array(t1), - float(d), - numpy.array(p0), - numpy.array(p1) - ) + return cls(numpy.array(t0), numpy.array(t1), float(d), numpy.array(p0), + numpy.array(p1)) def __get_triangles_references(): # Node to node configuration. - yield ExpectedTri.from_tuples( - t0=((0., 0., 0.), (1., 0., 0.), (0., 1., 1.)), - t1=((2., 0., 0.), (3., 0., 0.), (2., 1., 1.)), - d=1., - p0=(1., 0., 0.), - p1=(2., 0., 0.) - ) + yield ExpectedTri.from_tuples(t0=((0., 0., 0.), (1., 0., 0.), (0., 1., + 1.)), + t1=((2., 0., 0.), (3., 0., 0.), (2., 1., + 1.)), + d=1., + p0=(1., 0., 0.), + p1=(2., 0., 0.)) # Node to edge configuration. - yield ExpectedTri.from_tuples( - t0=((0., 0., 0.), (1., 0., 0.), (0., 1., 1.)), - t1=((2., -1., 0.), (3., 0., 0.), (2., 1., 0.)), - d=1., - p0=(1., 0., 0.), - p1=(2., 0., 0.) - ) + yield ExpectedTri.from_tuples(t0=((0., 0., 0.), (1., 0., 0.), (0., 1., + 1.)), + t1=((2., -1., 0.), (3., 0., 0.), (2., 1., + 0.)), + d=1., + p0=(1., 0., 0.), + p1=(2., 0., 0.)) # Edge to edge configuration. - yield ExpectedTri.from_tuples( - t0=((0., 0., 0.), (1., 1., 1.), (1., -1., -1.)), - t1=((2., -1., 0.), (2., 1., 0.), (3., 0., 0.)), - d=1., - p0=(1., 0., 0.), - p1=(2., 0., 0.) - ) + yield ExpectedTri.from_tuples(t0=((0., 0., 0.), (1., 1., 1.), (1., -1., + -1.)), + t1=((2., -1., 0.), (2., 1., 0.), (3., 0., + 0.)), + d=1., + p0=(1., 0., 0.), + p1=(2., 0., 0.)) # Point to face configuration. - yield ExpectedTri.from_tuples( - t0=((0., 0., 0.), (1., 0., 0.), (0., 1., 1.)), - t1=((2., -1., 0.), (2., 1., -1.), (2, 1., 1.)), - d=1., - p0=(1., 0., 0.), - p1=(2., 0., 0.) - ) + yield ExpectedTri.from_tuples(t0=((0., 0., 0.), (1., 0., 0.), (0., 1., + 1.)), + t1=((2., -1., 0.), (2., 1., -1.), (2, 1., + 1.)), + d=1., + p0=(1., 0., 0.), + p1=(2., 0., 0.)) # Same triangles configuration. - yield ExpectedTri.from_tuples( - t0=((0., 0., 0.), (1., 0., 0.), (0., 1., 1.)), - t1=((0., 0., 0.), (1., 0., 0.), (0., 1., 1.)), - d=0., - p0=(0., 0., 0.), - p1=(0., 0., 0.) - ) + yield ExpectedTri.from_tuples(t0=((0., 0., 0.), (1., 0., 0.), (0., 1., + 1.)), + t1=((0., 0., 0.), (1., 0., 0.), (0., 1., + 1.)), + d=0., + p0=(0., 0., 0.), + p1=(0., 0., 0.)) # Crossing triangles configuration. - yield ExpectedTri.from_tuples( - t0=((0., 0., 0.), (2., 0., 0.), (2., 0., 1.)), - t1=((1., -1., 0.), (1., 1., 0.), (1., 1., 1.)), - d=0., - p0=(0., 0., 0.), - p1=(0., 0., 0.) 
- ) + yield ExpectedTri.from_tuples(t0=((0., 0., 0.), (2., 0., 0.), (2., 0., + 1.)), + t1=((1., -1., 0.), (1., 1., 0.), (1., 1., + 1.)), + d=0., + p0=(0., 0., 0.), + p1=(0., 0., 0.)) @pytest.mark.parametrize("expected", __get_triangles_references()) diff --git a/geosx_mesh_tools_package/geosx_mesh_tools/abaqus_converter.py b/geosx_mesh_tools_package/geosx_mesh_tools/abaqus_converter.py index 14f6299..adf9760 100644 --- a/geosx_mesh_tools_package/geosx_mesh_tools/abaqus_converter.py +++ b/geosx_mesh_tools_package/geosx_mesh_tools/abaqus_converter.py @@ -1,10 +1,12 @@ -import meshio # type: ignore[import] -from meshio._mesh import CellBlock # type: ignore[import] +import meshio # type: ignore[import] +from meshio._mesh import CellBlock # type: ignore[import] import numpy as np import logging -def convert_abaqus_to_gmsh(input_mesh: str, output_mesh: str, logger: logging.Logger = None) -> int: +def convert_abaqus_to_gmsh(input_mesh: str, + output_mesh: str, + logger: logging.Logger = None) -> int: """ Convert an abaqus mesh to gmsh 2 format, preserving nodeset information. @@ -41,12 +43,16 @@ def convert_abaqus_to_gmsh(input_mesh: str, output_mesh: str, logger: logging.Lo cell_ids.append(np.zeros(len(block[1]), dtype=int) - 1) for region_id, region in enumerate(region_list): mesh.field_data[region] = [region_id + 1, 3] - cell_ids[block_id][mesh.cell_sets[region][block_id]] = region_id + 1 + cell_ids[block_id][mesh.cell_sets[region] + [block_id]] = region_id + 1 # Check for bad element region conversions if (-1 in cell_ids[-1]): - logger.warning('Some element regions in block %i did not convert correctly to tags!' % (block_id)) - logger.warning('Note: These will be indicated by a -1 in the output file.') + logger.warning( + 'Some element regions in block %i did not convert correctly to tags!' + % (block_id)) + logger.warning( + 'Note: These will be indicated by a -1 in the output file.') n_warnings += 1 # Add to the meshio datastructure @@ -93,8 +99,11 @@ def convert_abaqus_to_gmsh(input_mesh: str, output_mesh: str, logger: logging.Lo quad_region.append(region_id) else: - logger.warning(' Discarding an element with an unexpected number of nodes') - logger.warning(' n_nodes=%i, element=%i, set=%s' % (n_matching, element_id, nodeset_name)) + logger.warning( + ' Discarding an element with an unexpected number of nodes' + ) + logger.warning(' n_nodes=%i, element=%i, set=%s' % + (n_matching, element_id, nodeset_name)) n_warnings += 1 # Add new tris @@ -102,7 +111,8 @@ def convert_abaqus_to_gmsh(input_mesh: str, output_mesh: str, logger: logging.Lo logger.info(' Adding %i new triangles...' % (len(new_tris))) if (-1 in tri_region): logger.warning('Triangles with empty region information found!') - logger.warning('Note: These will be indicated by a -1 in the output file.') + logger.warning( + 'Note: These will be indicated by a -1 in the output file.') n_warnings += 1 mesh.cells.append(CellBlock('triangle', np.array(new_tris))) mesh.cell_data['gmsh:geometrical'].append(np.array(tri_region)) @@ -113,7 +123,8 @@ def convert_abaqus_to_gmsh(input_mesh: str, output_mesh: str, logger: logging.Lo logger.info(' Adding %i new quads...' 
% (len(new_quads))) if (-1 in quad_region): logger.warning('Quads with empty region information found!') - logger.warning('Note: These will be indicated by a -1 in the output file.') + logger.warning( + 'Note: These will be indicated by a -1 in the output file.') n_warnings += 1 mesh.cells.append(CellBlock('quad', np.array(new_quads))) mesh.cell_data['gmsh:geometrical'].append(np.array(quad_region)) @@ -127,7 +138,9 @@ def convert_abaqus_to_gmsh(input_mesh: str, output_mesh: str, logger: logging.Lo return (n_warnings > 0) -def convert_abaqus_to_vtu(input_mesh: str, output_mesh: str, logger: logging.Logger = None) -> int: +def convert_abaqus_to_vtu(input_mesh: str, + output_mesh: str, + logger: logging.Logger = None) -> int: """ Convert an abaqus mesh to vtu format, preserving nodeset information. diff --git a/geosx_mesh_tools_package/geosx_mesh_tools/main.py b/geosx_mesh_tools_package/geosx_mesh_tools/main.py index 1637d07..b1608ed 100644 --- a/geosx_mesh_tools_package/geosx_mesh_tools/main.py +++ b/geosx_mesh_tools_package/geosx_mesh_tools/main.py @@ -11,8 +11,13 @@ def build_abaqus_converter_input_parser() -> argparse.ArgumentParser: """ parser = argparse.ArgumentParser() parser.add_argument('input', type=str, help='Input abaqus mesh file name') - parser.add_argument('output', type=str, help='Output gmsh/vtu mesh file name') - parser.add_argument('-v', '--verbose', help='Increase verbosity level', action="store_true") + parser.add_argument('output', + type=str, + help='Output gmsh/vtu mesh file name') + parser.add_argument('-v', + '--verbose', + help='Increase verbosity level', + action="store_true") return parser @@ -40,11 +45,14 @@ def main() -> None: # Call the converter err = 0 if ('.msh' in args.output): - err = abaqus_converter.convert_abaqus_to_gmsh(args.input, args.output, logger) + err = abaqus_converter.convert_abaqus_to_gmsh(args.input, args.output, + logger) else: - err = abaqus_converter.convert_abaqus_to_vtu(args.input, args.output, logger) + err = abaqus_converter.convert_abaqus_to_vtu(args.input, args.output, + logger) if err: - sys.exit('Warnings detected: check the output file for potential errors!') + sys.exit( + 'Warnings detected: check the output file for potential errors!') if __name__ == '__main__': diff --git a/geosx_xml_tools_package/geosx_xml_tools/attribute_coverage.py b/geosx_xml_tools_package/geosx_xml_tools/attribute_coverage.py index 2b64df8..066bbcd 100644 --- a/geosx_xml_tools_package/geosx_xml_tools/attribute_coverage.py +++ b/geosx_xml_tools_package/geosx_xml_tools/attribute_coverage.py @@ -1,4 +1,4 @@ -from lxml import etree as ElementTree # type: ignore[import] +from lxml import etree as ElementTree # type: ignore[import] import os from pathlib import Path from typing import Any, Iterable, Dict @@ -7,11 +7,14 @@ record_type = Dict[str, Dict[str, Any]] -def parse_schema_element(root: ElementTree.Element, - node: ElementTree.Element, - xsd: str = '{http://www.w3.org/2001/XMLSchema}', - recursive_types: Iterable[str] = ['PeriodicEvent', 'SoloEvent', 'HaltEvent'], - folders: Iterable[str] = ['src', 'examples']) -> record_type: +def parse_schema_element( + root: ElementTree.Element, + node: ElementTree.Element, + xsd: str = '{http://www.w3.org/2001/XMLSchema}', + recursive_types: Iterable[str] = [ + 'PeriodicEvent', 'SoloEvent', 'HaltEvent' + ], + folders: Iterable[str] = ['src', 'examples']) -> record_type: """Parse the xml schema at the current level Args: @@ -35,15 +38,18 @@ def parse_schema_element(root: ElementTree.Element, attribute_name = 
attribute.get('name') local_types['attributes'][attribute_name] = {ka: [] for ka in folders} if ('default' in attribute.attrib): - local_types['attributes'][attribute_name]['default'] = attribute.get('default') + local_types['attributes'][attribute_name][ + 'default'] = attribute.get('default') # Parse children choice_node = element_def.findall('%schoice' % (xsd)) if choice_node: for child in choice_node[0].findall('%selement' % (xsd)): child_name = child.get('name') - if not ((child_name in recursive_types) and (element_name in recursive_types)): - local_types['children'][child_name] = parse_schema_element(root, child) + if not ((child_name in recursive_types) and + (element_name in recursive_types)): + local_types['children'][child_name] = parse_schema_element( + root, child) return local_types @@ -63,7 +69,9 @@ def parse_schema(fname: str) -> record_type: return {'Problem': parse_schema_element(xml_root, problem_node)} -def collect_xml_attributes_level(local_types: record_type, node: ElementTree.Element, folder: str) -> None: +def collect_xml_attributes_level(local_types: record_type, + node: ElementTree.Element, + folder: str) -> None: """Collect xml attribute usage at the current level Args: @@ -76,10 +84,12 @@ def collect_xml_attributes_level(local_types: record_type, node: ElementTree.Ele for child in node: if child.tag in local_types['children']: - collect_xml_attributes_level(local_types['children'][child.tag], child, folder) + collect_xml_attributes_level(local_types['children'][child.tag], + child, folder) -def collect_xml_attributes(xml_types: record_type, fname: str, folder: str) -> None: +def collect_xml_attributes(xml_types: record_type, fname: str, + folder: str) -> None: """Collect xml attribute usage in a file Args: @@ -87,16 +97,18 @@ def collect_xml_attributes(xml_types: record_type, fname: str, folder: str) -> N fname (str): name of the target file folder (str): the source folder for the current file """ - parser = ElementTree.XMLParser(remove_comments=True, remove_blank_text=True) + parser = ElementTree.XMLParser(remove_comments=True, + remove_blank_text=True) xml_tree = ElementTree.parse(fname, parser=parser) xml_root = xml_tree.getroot() collect_xml_attributes_level(xml_types['Problem'], xml_root, folder) -def write_attribute_usage_xml_level(local_types: record_type, - node: ElementTree.Element, - folders: Iterable[str] = ['src', 'examples']) -> None: +def write_attribute_usage_xml_level( + local_types: record_type, + node: ElementTree.Element, + folders: Iterable[str] = ['src', 'examples']) -> None: """Write xml attribute usage file at a given level Args: @@ -110,7 +122,8 @@ def write_attribute_usage_xml_level(local_types: record_type, node.append(attribute_node) if ('default' in local_types['attributes'][ka]): - attribute_node.set('default', local_types['attributes'][ka]['default']) + attribute_node.set('default', + local_types['attributes'][ka]['default']) unique_values = [] for f in folders: diff --git a/geosx_xml_tools_package/geosx_xml_tools/command_line_parsers.py b/geosx_xml_tools_package/geosx_xml_tools/command_line_parsers.py index 4c07d11..51f33f8 100644 --- a/geosx_xml_tools_package/geosx_xml_tools/command_line_parsers.py +++ b/geosx_xml_tools_package/geosx_xml_tools/command_line_parsers.py @@ -10,25 +10,40 @@ def build_preprocessor_input_parser() -> argparse.ArgumentParser: """ # Parse the user arguments parser = argparse.ArgumentParser() - parser.add_argument('-i', '--input', type=str, action='append', help='Input file name (multiple allowed)') - 
parser.add_argument('-c', - '--compiled-name', + parser.add_argument('-i', + '--input', type=str, - help='Compiled xml file name (otherwise, it is randomly genrated)', - default='') - parser.add_argument('-s', '--schema', type=str, help='GEOSX schema to use for validation', default='') - parser.add_argument('-v', '--verbose', type=int, help='Verbosity of outputs', default=0) - parser.add_argument('-p', - '--parameters', - nargs='+', action='append', - help='Parameter overrides (name value, multiple allowed)', - default=[]) + help='Input file name (multiple allowed)') + parser.add_argument( + '-c', + '--compiled-name', + type=str, + help='Compiled xml file name (otherwise, it is randomly genrated)', + default='') + parser.add_argument('-s', + '--schema', + type=str, + help='GEOSX schema to use for validation', + default='') + parser.add_argument('-v', + '--verbose', + type=int, + help='Verbosity of outputs', + default=0) + parser.add_argument( + '-p', + '--parameters', + nargs='+', + action='append', + help='Parameter overrides (name value, multiple allowed)', + default=[]) return parser -def parse_xml_preprocessor_arguments() -> Tuple[argparse.Namespace, Iterable[str]]: +def parse_xml_preprocessor_arguments( +) -> Tuple[argparse.Namespace, Iterable[str]]: """Parse user arguments Args: @@ -54,12 +69,36 @@ def build_xml_formatter_input_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() parser.add_argument('input', type=str, help='Input file name') - parser.add_argument('-i', '--indent', type=int, help='Indent size', default=2) - parser.add_argument('-s', '--style', type=int, help='Indent style', default=0) - parser.add_argument('-d', '--depth', type=int, help='Block separation depth', default=2) - parser.add_argument('-a', '--alphebitize', type=int, help='Alphebetize attributes', default=0) - parser.add_argument('-c', '--close', type=int, help='Close tag style', default=0) - parser.add_argument('-n', '--namespace', type=int, help='Include namespace', default=0) + parser.add_argument('-i', + '--indent', + type=int, + help='Indent size', + default=2) + parser.add_argument('-s', + '--style', + type=int, + help='Indent style', + default=0) + parser.add_argument('-d', + '--depth', + type=int, + help='Block separation depth', + default=2) + parser.add_argument('-a', + '--alphebitize', + type=int, + help='Alphebetize attributes', + default=0) + parser.add_argument('-c', + '--close', + type=int, + help='Close tag style', + default=0) + parser.add_argument('-n', + '--namespace', + type=int, + help='Include namespace', + default=0) return parser @@ -71,8 +110,16 @@ def build_attribute_coverage_input_parser() -> argparse.ArgumentParser: """ parser = argparse.ArgumentParser() - parser.add_argument('-r', '--root', type=str, help='GEOSX root', default='') - parser.add_argument('-o', '--output', type=str, help='Output file name', default='attribute_test.xml') + parser.add_argument('-r', + '--root', + type=str, + help='GEOSX root', + default='') + parser.add_argument('-o', + '--output', + type=str, + help='Output file name', + default='attribute_test.xml') return parser @@ -84,5 +131,9 @@ def build_xml_redundancy_input_parser() -> argparse.ArgumentParser: """ parser = argparse.ArgumentParser() - parser.add_argument('-r', '--root', type=str, help='GEOSX root', default='') + parser.add_argument('-r', + '--root', + type=str, + help='GEOSX root', + default='') return parser diff --git a/geosx_xml_tools_package/geosx_xml_tools/main.py b/geosx_xml_tools_package/geosx_xml_tools/main.py index 
b511028..3ab0bf8 100644 --- a/geosx_xml_tools_package/geosx_xml_tools/main.py +++ b/geosx_xml_tools_package/geosx_xml_tools/main.py @@ -25,9 +25,10 @@ def check_mpi_rank() -> int: TFunc = Callable[..., Any] -def wait_for_file_write_rank_0(target_file_argument: Union[int, str] = 0, - max_wait_time: float = 100, - max_startup_delay: float = 1) -> Callable[[TFunc], TFunc]: +def wait_for_file_write_rank_0( + target_file_argument: Union[int, str] = 0, + max_wait_time: float = 100, + max_startup_delay: float = 1) -> Callable[[TFunc], TFunc]: """Constructor for a function decorator that waits for a target file to be written on rank 0 Args: @@ -64,7 +65,8 @@ def wait_for_file_write_rank_0_decorator(*args, **kwargs) -> Any: # Variations in thread startup times may mean the file has already been processed # If the last edit was done within the specified time, then allow the thread to proceed - if (abs(target_file_edit_time - time.time()) < max_startup_delay): + if (abs(target_file_edit_time - time.time()) + < max_startup_delay): target_file_edit_time = 0.0 # Go into the target process or wait for the expected file update @@ -91,7 +93,8 @@ def preprocess_serial() -> None: Entry point for the geosx_xml_tools console script """ # Process the xml file - args, unknown_args = command_line_parsers.parse_xml_preprocessor_arguments() + args, unknown_args = command_line_parsers.parse_xml_preprocessor_arguments( + ) # Attempt to only process the file on rank 0 # Note: The rank here is determined by inspecting the system environment variables @@ -99,7 +102,9 @@ def preprocess_serial() -> None: # If the rank detection fails, then it will preprocess the file on all ranks, which # sometimes cause a (seemingly harmless) file write conflict. # processor = xml_processor.process - processor = wait_for_file_write_rank_0(target_file_argument='outputFile', max_wait_time=100)(xml_processor.process) + processor = wait_for_file_write_rank_0(target_file_argument='outputFile', + max_wait_time=100)( + xml_processor.process) compiled_name = processor(args.input, outputFile=args.compiled_name, @@ -124,23 +129,26 @@ def preprocess_parallel() -> Iterable[str]: MPI aware xml preprocesing """ # Process the xml file - from mpi4py import MPI # type: ignore[import] + from mpi4py import MPI # type: ignore[import] comm = MPI.COMM_WORLD rank = comm.Get_rank() - args, unknown_args = command_line_parsers.parse_xml_preprocessor_arguments() + args, unknown_args = command_line_parsers.parse_xml_preprocessor_arguments( + ) compiled_name = '' if (rank == 0): - compiled_name = xml_processor.process(args.input, - outputFile=args.compiled_name, - schema=args.schema, - verbose=args.verbose, - parameter_override=args.parameters) + compiled_name = xml_processor.process( + args.input, + outputFile=args.compiled_name, + schema=args.schema, + verbose=args.verbose, + parameter_override=args.parameters) compiled_name = comm.bcast(compiled_name, root=0) return format_geosx_arguments(compiled_name, unknown_args) -def format_geosx_arguments(compiled_name: str, unknown_args: Iterable[str]) -> Iterable[str]: +def format_geosx_arguments(compiled_name: str, + unknown_args: Iterable[str]) -> Iterable[str]: """Format GEOSX arguments Args: diff --git a/geosx_xml_tools_package/geosx_xml_tools/regex_tools.py b/geosx_xml_tools_package/geosx_xml_tools/regex_tools.py index ded5c1a..56ed5ab 100644 --- a/geosx_xml_tools_package/geosx_xml_tools/regex_tools.py +++ b/geosx_xml_tools_package/geosx_xml_tools/regex_tools.py @@ -19,7 +19,8 @@ patterns: Dict[str, str] = { 
'parameters': r"\$:?([a-zA-Z_0-9]*)\$?", - 'units': r"([0-9]*?\.?[0-9]+(?:[eE][-+]?[0-9]*?)?)\ *?\[([-+.*/()a-zA-Z0-9]*)\]", + 'units': + r"([0-9]*?\.?[0-9]+(?:[eE][-+]?[0-9]*?)?)\ *?\[([-+.*/()a-zA-Z0-9]*)\]", 'units_b': r"([a-zA-Z]*)", 'symbolic': r"\`([-+.*/() 0-9eE]*)\`", 'sanitize': r"[a-z-[e]A-Z-[E]]", @@ -44,7 +45,8 @@ def SymbolicMathRegexHandler(match: re.Match) -> str: value = eval(sanitized, {'__builtins__': None}) # Format the string, removing any trailing zeros, decimals, etc. - str_value = re.sub(patterns['strip_trailing'], '', symbolic_format % (value)) + str_value = re.sub(patterns['strip_trailing'], '', + symbolic_format % (value)) str_value = re.sub(patterns['strip_trailing_b'], '', str_value) return str_value else: @@ -71,7 +73,9 @@ def __call__(self, match: re.Match) -> str: k = match.group(1) if k: if (k not in self.target.keys()): - raise Exception('Error: Target (%s) is not defined in the regex handler' % k) + raise Exception( + 'Error: Target (%s) is not defined in the regex handler' % + k) value = self.target[k] return str(value) else: diff --git a/geosx_xml_tools_package/geosx_xml_tools/table_generator.py b/geosx_xml_tools_package/geosx_xml_tools/table_generator.py index bb4a63c..d2ebedf 100644 --- a/geosx_xml_tools_package/geosx_xml_tools/table_generator.py +++ b/geosx_xml_tools_package/geosx_xml_tools/table_generator.py @@ -21,7 +21,8 @@ def write_GEOS_table(axes_values: Iterable[np.ndarray], axes_shape = tuple([len(x) for x in axes_values]) for k in properties.keys(): if (np.shape(properties[k]) != axes_shape): - raise Exception("Shape of parameter %s is incompatible with given axes" % (k)) + raise Exception( + "Shape of parameter %s is incompatible with given axes" % (k)) # Write axes files for ka, x in zip(axes_names, axes_values): @@ -33,8 +34,9 @@ def write_GEOS_table(axes_values: Iterable[np.ndarray], np.savetxt('%s.geos' % (k), tmp, fmt=string_format, delimiter=',') -def read_GEOS_table(axes_files: Iterable[str], - property_files: Iterable[str]) -> Tuple[Iterable[np.ndarray], Dict[str, np.ndarray]]: +def read_GEOS_table( + axes_files: Iterable[str], property_files: Iterable[str] +) -> Tuple[Iterable[np.ndarray], Dict[str, np.ndarray]]: """Read an GEOS-compatible ascii table. 
Args: @@ -46,7 +48,8 @@ def read_GEOS_table(axes_files: Iterable[str], """ axes_values = [] for f in axes_files: - axes_values.append(np.loadtxt('%s.geos' % (f), unpack=True, delimiter=',')) + axes_values.append( + np.loadtxt('%s.geos' % (f), unpack=True, delimiter=',')) axes_shape = tuple([len(x) for x in axes_values]) # Open property files diff --git a/geosx_xml_tools_package/geosx_xml_tools/tests/generate_test_xml.py b/geosx_xml_tools_package/geosx_xml_tools/tests/generate_test_xml.py index 3c1b771..d9536e8 100644 --- a/geosx_xml_tools_package/geosx_xml_tools/tests/generate_test_xml.py +++ b/geosx_xml_tools_package/geosx_xml_tools/tests/generate_test_xml.py @@ -340,23 +340,30 @@ def generate_test_xml_files(root_dir): # Write the files, and apply pretty_print to targets for easy matches # No advanced features case with open('%s/no_advanced_features_input.xml' % (root_dir), 'w') as f: - f.write(xml_header + xml_base_a + xml_base_b + field_string_base + xml_footer) + f.write(xml_header + xml_base_a + xml_base_b + field_string_base + + xml_footer) with open('%s/no_advanced_features_target.xml' % (root_dir), 'w') as f: - f.write(xml_header + xml_base_a + xml_base_b + field_string_base + xml_footer) - xml_formatter.format_file('%s/no_advanced_features_target.xml' % (root_dir)) + f.write(xml_header + xml_base_a + xml_base_b + field_string_base + + xml_footer) + xml_formatter.format_file('%s/no_advanced_features_target.xml' % + (root_dir)) # Parameters case with open('%s/parameters_input.xml' % (root_dir), 'w') as f: - f.write(xml_header + xml_parameters + xml_base_a + xml_base_b + field_string_with_parameters + xml_footer) + f.write(xml_header + xml_parameters + xml_base_a + xml_base_b + + field_string_with_parameters + xml_footer) with open('%s/parameters_target.xml' % (root_dir), 'w') as f: - f.write(xml_header + xml_base_a + xml_base_b + field_string_base + xml_footer) + f.write(xml_header + xml_base_a + xml_base_b + field_string_base + + xml_footer) xml_formatter.format_file('%s/parameters_target.xml' % (root_dir)) # Symbolic + parameters case with open('%s/symbolic_parameters_input.xml' % (root_dir), 'w') as f: - f.write(xml_header + xml_parameters + xml_base_a + xml_base_b + field_string_with_symbolic + xml_footer) + f.write(xml_header + xml_parameters + xml_base_a + xml_base_b + + field_string_with_symbolic + xml_footer) with open('%s/symbolic_parameters_target.xml' % (root_dir), 'w') as f: - f.write(xml_header + xml_base_a + xml_base_b + field_string_alt + xml_footer) + f.write(xml_header + xml_base_a + xml_base_b + field_string_alt + + xml_footer) xml_formatter.format_file('%s/symbolic_parameters_target.xml' % (root_dir)) # Included case @@ -370,5 +377,6 @@ def generate_test_xml_files(root_dir): with open('%s/included/included_c.xml' % (root_dir), 'w') as f: f.write(xml_header + field_string_base + xml_footer) with open('%s/included_target.xml' % (root_dir), 'w') as f: - f.write(xml_header + xml_base_a + xml_base_b + field_string_base + xml_footer) + f.write(xml_header + xml_base_a + xml_base_b + field_string_base + + xml_footer) xml_formatter.format_file('%s/included_target.xml' % (root_dir)) diff --git a/geosx_xml_tools_package/geosx_xml_tools/tests/test_manager.py b/geosx_xml_tools_package/geosx_xml_tools/tests/test_manager.py index 54e60a7..4e8e16b 100644 --- a/geosx_xml_tools_package/geosx_xml_tools/tests/test_manager.py +++ b/geosx_xml_tools_package/geosx_xml_tools/tests/test_manager.py @@ -23,18 +23,29 @@ def test_unit_dict(self): self.assertTrue(bool(unitManager.units)) # 
Scale value tests - @parameterized.expand([['meter', '2', 2.0], ['meter', '1.234', 1.234], ['meter', '1.234e1', 12.34], - ['meter', '1.234E1', 12.34], ['meter', '1.234e+1', 12.34], ['meter', '1.234e-1', 0.1234], - ['mumeter', '1', 1.0e-6], ['micrometer', '1', 1.0e-6], ['kilometer', '1', 1.0e3], - ['ms', '1', 1.0e-3], ['millisecond', '1', 1.0e-3], ['Ms', '1', 1.0e6], ['m/s', '1', 1.0], - ['micrometer/s', '1', 1.0e-6], ['micrometer/ms', '1', 1.0e-3], - ['micrometer/microsecond', '1', 1.0], ['m**2', '1', 1.0], ['km**2', '1', 1.0e6], - ['kilometer**2', '1', 1.0e6], ['(km*mm)', '1', 1.0], ['(km*mm)**2', '1', 1.0], - ['km^2', '1', 1.0e6, True], ['bbl/day', '1', 0.000001840130728333], ['cP', '1', 0.001]]) + @parameterized.expand([['meter', '2', 2.0], ['meter', '1.234', 1.234], + ['meter', '1.234e1', 12.34], + ['meter', '1.234E1', 12.34], + ['meter', '1.234e+1', 12.34], + ['meter', '1.234e-1', 0.1234], + ['mumeter', '1', + 1.0e-6], ['micrometer', '1', 1.0e-6], + ['kilometer', '1', 1.0e3], ['ms', '1', 1.0e-3], + ['millisecond', '1', 1.0e-3], ['Ms', '1', 1.0e6], + ['m/s', '1', 1.0], ['micrometer/s', '1', 1.0e-6], + ['micrometer/ms', '1', 1.0e-3], + ['micrometer/microsecond', '1', 1.0], + ['m**2', '1', 1.0], ['km**2', '1', 1.0e6], + ['kilometer**2', '1', 1.0e6], ['(km*mm)', '1', 1.0], + ['(km*mm)**2', '1', 1.0], + ['km^2', '1', 1.0e6, True], + ['bbl/day', '1', 0.000001840130728333], + ['cP', '1', 0.001]]) def test_units(self, unit, scale, expected_value, expect_fail=False): try: val = float(unitManager([scale, unit])) - self.assertTrue((abs(val - expected_value) < self.tol) != expect_fail) + self.assertTrue((abs(val - + expected_value) < self.tol) != expect_fail) except TypeError: self.assertTrue(expect_fail) @@ -48,15 +59,22 @@ def setUpClass(cls): cls.regexHandler.target['foo'] = '1.23' cls.regexHandler.target['bar'] = '4.56e7' - @parameterized.expand([['$:foo*1.234', '1.23*1.234'], ['$:foo*1.234/$:bar', '1.23*1.234/4.56e7'], - ['$:foo*1.234/($:bar*$:foo)', '1.23*1.234/(4.56e7*1.23)'], - ['$foo*1.234/$bar', '1.23*1.234/4.56e7'], ['$foo$*1.234/$bar', '1.23*1.234/4.56e7'], - ['$foo$*1.234/$bar$', '1.23*1.234/4.56e7'], - ['$blah$*1.234/$bar$', '1.23*1.234/4.56e7', True], - ['$foo$*1.234/$bar$', '4.56e7*1.234/4.56e7', True]]) - def test_parameter_regex(self, parameterInput, expectedValue, expect_fail=False): + @parameterized.expand( + [['$:foo*1.234', '1.23*1.234'], + ['$:foo*1.234/$:bar', '1.23*1.234/4.56e7'], + ['$:foo*1.234/($:bar*$:foo)', '1.23*1.234/(4.56e7*1.23)'], + ['$foo*1.234/$bar', '1.23*1.234/4.56e7'], + ['$foo$*1.234/$bar', '1.23*1.234/4.56e7'], + ['$foo$*1.234/$bar$', '1.23*1.234/4.56e7'], + ['$blah$*1.234/$bar$', '1.23*1.234/4.56e7', True], + ['$foo$*1.234/$bar$', '4.56e7*1.234/4.56e7', True]]) + def test_parameter_regex(self, + parameterInput, + expectedValue, + expect_fail=False): try: - result = re.sub(regex_tools.patterns['parameters'], self.regexHandler, parameterInput) + result = re.sub(regex_tools.patterns['parameters'], + self.regexHandler, parameterInput) self.assertTrue((result == expectedValue) != expect_fail) except Exception: self.assertTrue(expect_fail) @@ -69,13 +87,16 @@ class TestUnitsRegex(unittest.TestCase): def setUpClass(cls): cls.tol = 1e-6 - @parameterized.expand([['1.234[m**2/s]', '1.234'], ['1.234 [m**2/s]', '1.234'], ['1.234[m**2/s]*3.4', '1.234*3.4'], - ['1.234[m**2/s] + 5.678[mm/s]', '1.234 + 5.678e-3'], - ['1.234 [m**2/s] + 5.678 [mm/s]', '1.234 + 5.678e-3'], - ['(1.234[m**2/s])*5.678', '(1.234)*5.678']]) + @parameterized.expand( + [['1.234[m**2/s]', 
'1.234'], ['1.234 [m**2/s]', '1.234'], + ['1.234[m**2/s]*3.4', '1.234*3.4'], + ['1.234[m**2/s] + 5.678[mm/s]', '1.234 + 5.678e-3'], + ['1.234 [m**2/s] + 5.678 [mm/s]', '1.234 + 5.678e-3'], + ['(1.234[m**2/s])*5.678', '(1.234)*5.678']]) def test_units_regex(self, unitInput, expectedValue, expect_fail=False): try: - result = re.sub(regex_tools.patterns['units'], unitManager.regexHandler, unitInput) + result = re.sub(regex_tools.patterns['units'], + unitManager.regexHandler, unitInput) self.assertTrue((result == expectedValue) != expect_fail) except Exception: self.assertTrue(expect_fail) @@ -88,13 +109,21 @@ class TestSymbolicRegex(unittest.TestCase): def setUpClass(cls): cls.tol = 1e-6 - @parameterized.expand([['`1.234`', '1.234'], ['`1.234*2.0`', '2.468'], ['`10`', '1e1'], ['`10*2`', '2e1'], - ['`1.0/2.0`', '5e-1'], ['`2.0**2`', '4'], ['`1.0 + 2.0**2`', '5'], ['`(1.0 + 2.0)**2`', '9'], - ['`((1.0 + 2.0)**2)**(0.5)`', '3'], ['`(1.2e3)*2`', '2.4e3'], ['`1.2e3*2`', '2.4e3'], + @parameterized.expand([['`1.234`', '1.234'], ['`1.234*2.0`', '2.468'], + ['`10`', '1e1'], ['`10*2`', '2e1'], + ['`1.0/2.0`', '5e-1'], ['`2.0**2`', '4'], + ['`1.0 + 2.0**2`', '5'], ['`(1.0 + 2.0)**2`', '9'], + ['`((1.0 + 2.0)**2)**(0.5)`', '3'], + ['`(1.2e3)*2`', '2.4e3'], ['`1.2e3*2`', '2.4e3'], ['`2.0^2`', '4', True], ['`sqrt(4.0)`', '2', True]]) - def test_symbolic_regex(self, symbolicInput, expectedValue, expect_fail=False): + def test_symbolic_regex(self, + symbolicInput, + expectedValue, + expect_fail=False): try: - result = re.sub(regex_tools.patterns['symbolic'], regex_tools.SymbolicMathRegexHandler, symbolicInput) + result = re.sub(regex_tools.patterns['symbolic'], + regex_tools.SymbolicMathRegexHandler, + symbolicInput) self.assertTrue((result == expectedValue) != expect_fail) except Exception: self.assertTrue(expect_fail) @@ -107,10 +136,11 @@ class TestXMLProcessor(unittest.TestCase): def setUpClass(cls): generate_test_xml.generate_test_xml_files('.') - @parameterized.expand([['no_advanced_features_input.xml', 'no_advanced_features_target.xml'], - ['parameters_input.xml', 'parameters_target.xml'], - ['included_input.xml', 'included_target.xml'], - ['symbolic_parameters_input.xml', 'symbolic_parameters_target.xml']]) + @parameterized.expand( + [['no_advanced_features_input.xml', 'no_advanced_features_target.xml'], + ['parameters_input.xml', 'parameters_target.xml'], + ['included_input.xml', 'included_target.xml'], + ['symbolic_parameters_input.xml', 'symbolic_parameters_target.xml']]) def test_xml_processor(self, input_file, target_file, expect_fail=False): try: tmp = xml_processor.process(input_file, @@ -161,8 +191,16 @@ def main(): # Parse the user arguments parser = argparse.ArgumentParser() - parser.add_argument('-t', '--test_dir', type=str, help='Test output directory', default='./test_results') - parser.add_argument('-v', '--verbose', type=int, help='Verbosity level', default=2) + parser.add_argument('-t', + '--test_dir', + type=str, + help='Test output directory', + default='./test_results') + parser.add_argument('-v', + '--verbose', + type=int, + help='Verbosity level', + default=2) args = parser.parse_args() # Process the xml file diff --git a/geosx_xml_tools_package/geosx_xml_tools/unit_manager.py b/geosx_xml_tools_package/geosx_xml_tools/unit_manager.py index 44360b0..c8687d7 100644 --- a/geosx_xml_tools_package/geosx_xml_tools/unit_manager.py +++ b/geosx_xml_tools_package/geosx_xml_tools/unit_manager.py @@ -25,17 +25,22 @@ def __call__(self, unitStruct: List[Any]) -> str: """ # Replace all 
instances of units in the string with their scale defined in self.units - symbolicUnits = re.sub(regex_tools.patterns['units_b'], self.unitMatcher, unitStruct[1]) + symbolicUnits = re.sub(regex_tools.patterns['units_b'], + self.unitMatcher, unitStruct[1]) # Strip out any undesired characters and evaluate # Note: the only allowed alpha characters are e and E. This could be relaxed to allow # functions such as sin, cos, etc. - symbolicUnits_sanitized = re.sub(regex_tools.patterns['sanitize'], '', symbolicUnits).strip() - value = float(unitStruct[0]) * eval(symbolicUnits_sanitized, {'__builtins__': None}) + symbolicUnits_sanitized = re.sub(regex_tools.patterns['sanitize'], '', + symbolicUnits).strip() + value = float(unitStruct[0]) * eval(symbolicUnits_sanitized, + {'__builtins__': None}) # Format the string, removing any trailing zeros, decimals, extraneous exponential formats - str_value = re.sub(regex_tools.patterns['strip_trailing'], '', regex_tools.symbolic_format % (value)) - str_value = re.sub(regex_tools.patterns['strip_trailing_b'], '', str_value) + str_value = re.sub(regex_tools.patterns['strip_trailing'], '', + regex_tools.symbolic_format % (value)) + str_value = re.sub(regex_tools.patterns['strip_trailing_b'], '', + str_value) return str_value def regexHandler(self, match: re.Match) -> str: @@ -130,13 +135,17 @@ def buildUnits(self) -> None: prefixes[prefixes[p]['alt']] = {'value': prefixes[p]['value']} for u in list(unit_defs.keys()): for alt in unit_defs[u]['alt']: - unit_defs[alt] = {'value': unit_defs[u]['value'], 'usePrefix': unit_defs[u]['usePrefix']} + unit_defs[alt] = { + 'value': unit_defs[u]['value'], + 'usePrefix': unit_defs[u]['usePrefix'] + } # Combine the results into the final dictionary for u in unit_defs.keys(): if (unit_defs[u]['usePrefix']): for p in prefixes.keys(): - self.units[p + u] = prefixes[p]['value'] * unit_defs[u]['value'] + self.units[ + p + u] = prefixes[p]['value'] * unit_defs[u]['value'] else: self.units[u] = unit_defs[u]['value'] @@ -146,6 +155,8 @@ def buildUnits(self) -> None: duplicates = [k for k, v in Counter(tmp).items() if v > 1] if (duplicates): print(duplicates) - raise Exception('Error: There are overlapping unit definitions in the UnitManager') + raise Exception( + 'Error: There are overlapping unit definitions in the UnitManager' + ) self.unitMatcher.target = self.units diff --git a/geosx_xml_tools_package/geosx_xml_tools/xml_formatter.py b/geosx_xml_tools_package/geosx_xml_tools/xml_formatter.py index eb745ab..8ef929e 100644 --- a/geosx_xml_tools_package/geosx_xml_tools/xml_formatter.py +++ b/geosx_xml_tools_package/geosx_xml_tools/xml_formatter.py @@ -1,11 +1,12 @@ import os -from lxml import etree as ElementTree # type: ignore[import] +from lxml import etree as ElementTree # type: ignore[import] import re from typing import List, Any, TextIO from geosx_xml_tools import command_line_parsers -def format_attribute(attribute_indent: str, ka: str, attribute_value: str) -> str: +def format_attribute(attribute_indent: str, ka: str, + attribute_value: str) -> str: """Format xml attribute strings Args: @@ -28,7 +29,9 @@ def format_attribute(attribute_indent: str, ka: str, attribute_value: str) -> st # Identify and split multi-line attributes if re.match(r"\s*{\s*({[-+.,0-9a-zA-Z\s]*},?\s*)*\s*}", attribute_value): - split_positions: List[Any] = [match.end() for match in re.finditer(r"}\s*,", attribute_value)] + split_positions: List[Any] = [ + match.end() for match in re.finditer(r"}\s*,", attribute_value) + ] newline_indent = '\n%s' % (' ' 
* (len(attribute_indent) + len(ka) + 4)) new_values = [] for a, b in zip([None] + split_positions, split_positions + [None]): @@ -83,8 +86,10 @@ def format_xml_level(output: TextIO, if ((level == 0) & include_namespace): # Handle the optional namespace information at the root level # Note: preferably, this would point to a schema we host online - attribute_dict['xmlns:xsi'] = 'http://www.w3.org/2001/XMLSchema-instance' - attribute_dict['xsi:noNamespaceSchemaLocation'] = '/usr/gapps/GEOS/schema/schema.xsd' + attribute_dict[ 'xmlns:xsi'] = 'http://www.w3.org/2001/XMLSchema-instance' + attribute_dict[ 'xsi:noNamespaceSchemaLocation'] = '/usr/gapps/GEOS/schema/schema.xsd' elif (level > 0): attribute_dict = node.attrib @@ -96,26 +101,32 @@ # Format attributes for ka in akeys: # Avoid formatting mathpresso expressions - if not (node.tag in ["SymbolicFunction", "CompositeFunction"] and ka == "expression"): - attribute_dict[ka] = format_attribute(attribute_indent, ka, attribute_dict[ka]) + if not (node.tag in ["SymbolicFunction", "CompositeFunction"] + and ka == "expression"): + attribute_dict[ka] = format_attribute( attribute_indent, ka, attribute_dict[ka]) for ii in range(0, len(akeys)): k = akeys[ii] if ((ii == 0) & modify_attribute_indent): output.write(' %s=\"%s\"' % (k, attribute_dict[k])) else: - output.write('\n%s%s=\"%s\"' % (attribute_indent, k, attribute_dict[k])) + output.write('\n%s%s=\"%s\"' % + (attribute_indent, k, attribute_dict[k])) # Write children if len(node): output.write('>') Nc = len(node) for ii, child in zip(range(Nc), node): - format_xml_level(output, child, level + 1, indent, block_separation_max_depth, modify_attribute_indent, - sort_attributes, close_tag_newline, include_namespace) + format_xml_level(output, child, level + 1, indent, + block_separation_max_depth, + modify_attribute_indent, sort_attributes, + close_tag_newline, include_namespace) # Add space between blocks - if ((level < block_separation_max_depth) & (ii < Nc - 1) & (child.tag is not ElementTree.Comment)): + if ((level < block_separation_max_depth) & (ii < Nc - 1) & + (child.tag is not ElementTree.Comment)): output.write('\n') # Write the end tag @@ -149,7 +160,9 @@ def format_file(input_fname: str, try: tree = ElementTree.parse(fname) root = tree.getroot() - prologue_comments = [tmp.text for tmp in root.itersiblings(preceding=True)] + prologue_comments = [ + tmp.text for tmp in root.itersiblings(preceding=True) + ] epilog_comments = [tmp.text for tmp in root.itersiblings()] with open(fname, 'w') as f: @@ -158,15 +171,16 @@ for comment in reversed(prologue_comments): f.write('<!--%s-->\n' % (comment)) - format_xml_level(f, - root, - 0, - indent=' ' * indent_size, - block_separation_max_depth=block_separation_max_depth, - modify_attribute_indent=indent_style, - sort_attributes=alphebitize_attributes, - close_tag_newline=close_style, - include_namespace=namespace) + format_xml_level( + f, + root, + 0, + indent=' ' * indent_size, + block_separation_max_depth=block_separation_max_depth, + modify_attribute_indent=indent_style, + sort_attributes=alphebitize_attributes, + close_tag_newline=close_style, + include_namespace=namespace) for comment in epilog_comments: f.write('<!--%s-->\n' % (comment)) diff --git a/geosx_xml_tools_package/geosx_xml_tools/xml_formatter.py b/geosx_xml_tools_package/geosx_xml_tools/xml_formatter.py diff --git a/geosx_xml_tools_package/geosx_xml_tools/xml_processor.py b/geosx_xml_tools_package/geosx_xml_tools/xml_processor.py index 757d025..18edd8c 100644 --- a/geosx_xml_tools_package/geosx_xml_tools/xml_processor.py +++
b/geosx_xml_tools_package/geosx_xml_tools/xml_processor.py @@ -1,7 +1,7 @@ """Tools for processing xml files in GEOSX""" -from lxml import etree as ElementTree # type: ignore[import] -from lxml.etree import XMLSyntaxError # type: ignore[import] +from lxml import etree as ElementTree # type: ignore[import] +from lxml.etree import XMLSyntaxError # type: ignore[import] import re import os from geosx_xml_tools import regex_tools, unit_manager @@ -13,7 +13,8 @@ parameterHandler = regex_tools.DictRegexHandler() -def merge_xml_nodes(existingNode: ElementTree.Element, targetNode: ElementTree.Element, level: int) -> None: +def merge_xml_nodes(existingNode: ElementTree.Element, + targetNode: ElementTree.Element, level: int) -> None: """Merge nodes in an included file into the current structure level by level. Args: @@ -60,7 +61,10 @@ def merge_xml_nodes(existingNode: ElementTree.Element, targetNode: ElementTree.E existingNode.insert(-1, target) -def merge_included_xml_files(root: ElementTree.Element, fname: str, includeCount: int, maxInclude: int = 100) -> None: +def merge_included_xml_files(root: ElementTree.Element, + fname: str, + includeCount: int, + maxInclude: int = 100) -> None: """Recursively merge included files into the current structure. Args: @@ -72,13 +76,15 @@ def merge_included_xml_files(root: ElementTree.Element, fname: str, includeCount # Expand the input path pwd = os.getcwd() - includePath, fname = os.path.split(os.path.abspath(os.path.expanduser(fname))) + includePath, fname = os.path.split( + os.path.abspath(os.path.expanduser(fname))) os.chdir(includePath) # Check to see if the code has fallen into a loop includeCount += 1 if (includeCount > maxInclude): - raise Exception('Reached maximum recursive includes... Is there an include loop?') + raise Exception( + 'Reached maximum recursive includes... 
Is there an include loop?') # Check to make sure the file exists if (not os.path.isfile(fname)): @@ -87,7 +93,8 @@ def merge_included_xml_files(root: ElementTree.Element, fname: str, includeCount # Load target xml try: - parser = ElementTree.XMLParser(remove_comments=True, remove_blank_text=True) + parser = ElementTree.XMLParser(remove_comments=True, + remove_blank_text=True) includeTree = ElementTree.parse(fname, parser) includeRoot = includeTree.getroot() except XMLSyntaxError as err: @@ -119,22 +126,29 @@ def apply_regex_to_node(node: ElementTree.Element) -> None: # Parameter format: $Parameter or $:Parameter ii = 0 while ('$' in value): - value = re.sub(regex_tools.patterns['parameters'], parameterHandler, value) + value = re.sub(regex_tools.patterns['parameters'], + parameterHandler, value) ii += 1 if (ii > 100): - raise Exception('Reached maximum parameter expands (Node=%s, value=%s)' % (node.tag, value)) + raise Exception( + 'Reached maximum parameter expands (Node=%s, value=%s)' % + (node.tag, value)) # Unit format: 9.81[m**2/s] or 1.0 [bbl/day] if ('[' in value): - value = re.sub(regex_tools.patterns['units'], unitManager.regexHandler, value) + value = re.sub(regex_tools.patterns['units'], + unitManager.regexHandler, value) # Symbolic format: `1 + 2.34e5*2 * ...` ii = 0 while ('`' in value): - value = re.sub(regex_tools.patterns['symbolic'], regex_tools.SymbolicMathRegexHandler, value) + value = re.sub(regex_tools.patterns['symbolic'], + regex_tools.SymbolicMathRegexHandler, value) ii += 1 if (ii > 100): - raise Exception('Reached maximum symbolic expands (Node=%s, value=%s)' % (node.tag, value)) + raise Exception( + 'Reached maximum symbolic expands (Node=%s, value=%s)' % + (node.tag, value)) node.set(k, value) @@ -190,7 +204,9 @@ def process(inputFiles: Iterable[str], # Expand the input path pwd = os.getcwd() - expanded_files = [os.path.abspath(os.path.expanduser(f)) for f in inputFiles] + expanded_files = [ + os.path.abspath(os.path.expanduser(f)) for f in inputFiles + ] single_path, single_input = os.path.split(expanded_files[0]) os.chdir(single_path) @@ -200,7 +216,8 @@ def process(inputFiles: Iterable[str], if (len(expanded_files) == 1): # Load single files directly try: - parser = ElementTree.XMLParser(remove_comments=True, remove_blank_text=True) + parser = ElementTree.XMLParser(remove_comments=True, + remove_blank_text=True) tree = ElementTree.parse(single_input, parser=parser) root = tree.getroot() except XMLSyntaxError as err: @@ -225,7 +242,9 @@ def process(inputFiles: Iterable[str], includeCount = 0 for includeNode in root.findall('Included'): for f in includeNode.findall('File'): - merge_included_xml_files(root, f.get('name'), includeCount) # type: ignore[attr-defined] + merge_included_xml_files( + root, f.get('name'), + includeCount) # type: ignore[attr-defined] os.chdir(pwd) # Build the parameter map @@ -257,15 +276,18 @@ def process(inputFiles: Iterable[str], # Comment out or remove the Parameter, Included nodes for includeNode in root.findall('Included'): if keep_includes: - root.insert(-1, ElementTree.Comment(ElementTree.tostring(includeNode))) + root.insert(-1, + ElementTree.Comment(ElementTree.tostring(includeNode))) root.remove(includeNode) for parameterNode in root.findall('Parameters'): if keep_parameters: - root.insert(-1, ElementTree.Comment(ElementTree.tostring(parameterNode))) + root.insert( + -1, ElementTree.Comment(ElementTree.tostring(parameterNode))) root.remove(parameterNode) for overrideNode in root.findall('CommandLineOverride'): if 
keep_parameters: - root.insert(-1, ElementTree.Comment(ElementTree.tostring(overrideNode))) + root.insert( + -1, ElementTree.Comment(ElementTree.tostring(overrideNode))) root.remove(overrideNode) # Generate a random output name if not specified @@ -307,11 +329,14 @@ def validate_xml(fname: str, schema: str, verbose: int) -> None: print('Validating the xml against the schema...') try: ofile = ElementTree.parse(fname) - sfile = ElementTree.XMLSchema(ElementTree.parse(os.path.expanduser(schema))) + sfile = ElementTree.XMLSchema( + ElementTree.parse(os.path.expanduser(schema))) sfile.assertValid(ofile) except ElementTree.DocumentInvalid as err: print(err) - print('\nWarning: input XML contains potentially invalid input parameters:') + print( + '\nWarning: input XML contains potentially invalid input parameters:' + ) print('-' * 20 + '\n') print(sfile.error_log) print('\n' + '-' * 20) diff --git a/geosx_xml_tools_package/geosx_xml_tools/xml_redundancy_check.py b/geosx_xml_tools_package/geosx_xml_tools/xml_redundancy_check.py index 251e1df..66b124e 100644 --- a/geosx_xml_tools_package/geosx_xml_tools/xml_redundancy_check.py +++ b/geosx_xml_tools_package/geosx_xml_tools/xml_redundancy_check.py @@ -1,6 +1,6 @@ from geosx_xml_tools.attribute_coverage import parse_schema from geosx_xml_tools.xml_formatter import format_file -from lxml import etree as ElementTree # type: ignore[import] +from lxml import etree as ElementTree # type: ignore[import] import os from pathlib import Path from geosx_xml_tools import command_line_parsers @@ -36,7 +36,8 @@ def check_redundancy_level(local_schema: Dict[str, Any], for child in node: # Comments will not appear in the schema if child.tag in local_schema['children']: - child_is_required = check_redundancy_level(local_schema['children'][child.tag], child) + child_is_required = check_redundancy_level( + local_schema['children'][child.tag], child) node_is_required += child_is_required if not child_is_required: node.remove(child) diff --git a/hdf5_wrapper_package/hdf5_wrapper/use_example.py b/hdf5_wrapper_package/hdf5_wrapper/use_example.py index 5bbdbe3..1205a4f 100644 --- a/hdf5_wrapper_package/hdf5_wrapper/use_example.py +++ b/hdf5_wrapper_package/hdf5_wrapper/use_example.py @@ -3,7 +3,8 @@ from typing import Union, Dict -def print_database_iterative(database: hdf5_wrapper.hdf5_wrapper, level: int = 0) -> None: +def print_database_iterative(database: hdf5_wrapper.hdf5_wrapper, + level: int = 0) -> None: """ Print the database targets iteratively by level diff --git a/hdf5_wrapper_package/hdf5_wrapper/wrapper.py b/hdf5_wrapper_package/hdf5_wrapper/wrapper.py index a8a11b8..064db8c 100644 --- a/hdf5_wrapper_package/hdf5_wrapper/wrapper.py +++ b/hdf5_wrapper_package/hdf5_wrapper/wrapper.py @@ -1,4 +1,4 @@ -import h5py # type: ignore[import] +import h5py # type: ignore[import] import numpy as np from numpy.core.defchararray import encode, decode from typing import Union, Dict, Any, Iterable, Optional, Tuple @@ -15,7 +15,10 @@ class hdf5_wrapper(): A class for reading/writing hdf5 files, which behaves similar to a native dict """ - def __init__(self, fname: str = '', target: Optional[h5py.File] = None, mode: str = 'r') -> None: + def __init__(self, + fname: str = '', + target: Optional[h5py.File] = None, + mode: str = 'r') -> None: """ Initialize the hdf5_wrapper class diff --git a/hdf5_wrapper_package/hdf5_wrapper/wrapper_tests.py b/hdf5_wrapper_package/hdf5_wrapper/wrapper_tests.py index 7e496ee..9666d57 100644 --- a/hdf5_wrapper_package/hdf5_wrapper/wrapper_tests.py 
+++ b/hdf5_wrapper_package/hdf5_wrapper/wrapper_tests.py @@ -8,7 +8,10 @@ def random_string(N): - return ''.join(random.choices(string.ascii_uppercase + string.ascii_lowercase + string.digits, k=N)) + return ''.join( + random.choices(string.ascii_uppercase + string.ascii_lowercase + + string.digits, + k=N)) def build_test_dict(depth=0, max_depth=3): @@ -52,8 +55,10 @@ def compare_wrapper_dict(self, x, y): vx, vy = x[k], y[k] tx, ty = type(vx), type(vy) - if ((tx != ty) and not (isinstance(vx, (dict, hdf5_wrapper.hdf5_wrapper)) - and isinstance(vy, (dict, hdf5_wrapper.hdf5_wrapper)))): + if ((tx != ty) and + not (isinstance(vx, (dict, hdf5_wrapper.hdf5_wrapper)) + and isinstance(vy, + (dict, hdf5_wrapper.hdf5_wrapper)))): self.assertTrue(np.issubdtype(tx, ty)) if isinstance(vx, (dict, hdf5_wrapper.hdf5_wrapper)): @@ -66,16 +71,22 @@ def compare_wrapper_dict(self, x, y): self.assertTrue(vx == vy) def test_a_insert_write(self): - data = hdf5_wrapper.hdf5_wrapper(os.path.join(self.test_dir, 'test_insert.hdf5'), mode='w') + data = hdf5_wrapper.hdf5_wrapper(os.path.join(self.test_dir, + 'test_insert.hdf5'), + mode='w') data.insert(self.test_dict) def test_b_manual_write(self): - data = hdf5_wrapper.hdf5_wrapper(os.path.join(self.test_dir, 'test_manual.hdf5'), mode='w') + data = hdf5_wrapper.hdf5_wrapper(os.path.join(self.test_dir, + 'test_manual.hdf5'), + mode='w') for k, v in self.test_dict.items(): data[k] = v def test_c_link_write(self): - data = hdf5_wrapper.hdf5_wrapper(os.path.join(self.test_dir, 'test_linked.hdf5'), mode='w') + data = hdf5_wrapper.hdf5_wrapper(os.path.join(self.test_dir, + 'test_linked.hdf5'), + mode='w') for k, v in self.test_dict.items(): if ('child' in k): child_path = os.path.join(self.test_dir, 'test_%s.hdf5' % (k)) @@ -86,20 +97,24 @@ def test_c_link_write(self): data[k] = v def test_d_compare_wrapper(self): - data = hdf5_wrapper.hdf5_wrapper(os.path.join(self.test_dir, 'test_insert.hdf5')) + data = hdf5_wrapper.hdf5_wrapper( + os.path.join(self.test_dir, 'test_insert.hdf5')) self.compare_wrapper_dict(self.test_dict, data) def test_e_compare_wrapper_copy(self): - data = hdf5_wrapper.hdf5_wrapper(os.path.join(self.test_dir, 'test_insert.hdf5')) + data = hdf5_wrapper.hdf5_wrapper( + os.path.join(self.test_dir, 'test_insert.hdf5')) tmp = data.copy() self.compare_wrapper_dict(self.test_dict, tmp) def test_f_compare_wrapper(self): - data = hdf5_wrapper.hdf5_wrapper(os.path.join(self.test_dir, 'test_manual.hdf5')) + data = hdf5_wrapper.hdf5_wrapper( + os.path.join(self.test_dir, 'test_manual.hdf5')) self.compare_wrapper_dict(self.test_dict, data) def test_g_compare_wrapper(self): - data = hdf5_wrapper.hdf5_wrapper(os.path.join(self.test_dir, 'test_linked.hdf5')) + data = hdf5_wrapper.hdf5_wrapper( + os.path.join(self.test_dir, 'test_linked.hdf5')) self.compare_wrapper_dict(self.test_dict, data) @@ -112,7 +127,11 @@ def main(): # Parse the user arguments parser = argparse.ArgumentParser() - parser.add_argument('-v', '--verbose', type=int, help='Verbosity level', default=2) + parser.add_argument('-v', + '--verbose', + type=int, + help='Verbosity level', + default=2) args = parser.parse_args() # Unit manager tests diff --git a/pygeosx_tools_package/pygeosx_tools/file_io.py b/pygeosx_tools_package/pygeosx_tools/file_io.py index 12093e8..1878087 100644 --- a/pygeosx_tools_package/pygeosx_tools/file_io.py +++ b/pygeosx_tools_package/pygeosx_tools/file_io.py @@ -51,16 +51,24 @@ def save_tables(axes: Iterable[np.ndarray], # Write the axes os.makedirs(table_root, exist_ok=True) 
for g, a in zip(axes, axes_names): - np.savetxt('%s/%s.csv' % (table_root, a), g, fmt='%1.5f', delimiter=',') + np.savetxt('%s/%s.csv' % (table_root, a), + g, + fmt='%1.5f', + delimiter=',') for k, p in properties.items(): - np.savetxt('%s/%s.csv' % (table_root, k), np.reshape(p, (-1), order='F'), fmt='%1.5e', delimiter=',') + np.savetxt('%s/%s.csv' % (table_root, k), + np.reshape(p, (-1), order='F'), + fmt='%1.5e', + delimiter=',') -def load_tables(axes_names: Iterable[str], - property_names: Iterable[str], - table_root: str = './tables', - extension: str = 'csv') -> Tuple[Iterable[np.ndarray], Dict[str, np.ndarray]]: +def load_tables( + axes_names: Iterable[str], + property_names: Iterable[str], + table_root: str = './tables', + extension: str = 'csv' +) -> Tuple[Iterable[np.ndarray], Dict[str, np.ndarray]]: """ Load a set of tables in GEOSX format @@ -74,12 +82,18 @@ def load_tables(axes_names: Iterable[str], tuple: List of axes values, and dictionary of table values """ # Load axes - axes = [np.loadtxt('%s/%s.%s' % (table_root, axis, extension), unpack=True) for axis in axes_names] + axes = [ + np.loadtxt('%s/%s.%s' % (table_root, axis, extension), unpack=True) + for axis in axes_names + ] N = tuple([len(x) for x in axes]) # Load properties properties = { - p: np.reshape(np.loadtxt('%s/%s.%s' % (table_root, p, extension)), N, order='F') + p: + np.reshape(np.loadtxt('%s/%s.%s' % (table_root, p, extension)), + N, + order='F') for p in property_names } diff --git a/pygeosx_tools_package/pygeosx_tools/mesh_interpolation.py b/pygeosx_tools_package/pygeosx_tools/mesh_interpolation.py index 49b2011..9bc8160 100644 --- a/pygeosx_tools_package/pygeosx_tools/mesh_interpolation.py +++ b/pygeosx_tools_package/pygeosx_tools/mesh_interpolation.py @@ -1,5 +1,5 @@ import numpy as np -from scipy import stats # type: ignore[import] +from scipy import stats # type: ignore[import] from typing import Dict, Iterable, List, Tuple, Callable, Union @@ -63,7 +63,12 @@ def extrapolate_nan_values(x, y, slope_scale=0.0): return y -def get_random_realization(x, bins, value, rand_fill=0, rand_scale=0, slope_scale=0): +def get_random_realization(x, + bins, + value, + rand_fill=0, + rand_scale=0, + slope_scale=0): """ Get a random realization for a noisy signal with a set of bins diff --git a/pygeosx_tools_package/pygeosx_tools/well_log.py b/pygeosx_tools_package/pygeosx_tools/well_log.py index 58453f0..5d5455e 100644 --- a/pygeosx_tools_package/pygeosx_tools/well_log.py +++ b/pygeosx_tools_package/pygeosx_tools/well_log.py @@ -49,7 +49,12 @@ def parse_las(fname, variable_start='~C', body_start='~A'): else: # As a fall-back use the full line variable_order.append(line[:-1]) - results[line[:-1]] = {'units': '', 'code': '', 'description': '', 'values': []} + results[line[:-1]] = { + 'units': '', + 'code': '', + 'description': '', + 'values': [] + } # Body else: diff --git a/pygeosx_tools_package/pygeosx_tools/wrapper.py b/pygeosx_tools_package/pygeosx_tools/wrapper.py index 7c46e74..2c084cd 100644 --- a/pygeosx_tools_package/pygeosx_tools/wrapper.py +++ b/pygeosx_tools_package/pygeosx_tools/wrapper.py @@ -89,7 +89,8 @@ def get_wrapper_par(problem, target_key, allgather=False, ghost_key=''): if allgather: comm.Allgather([send_buff, MPI.DOUBLE], [receive_buff, MPI.DOUBLE]) else: - comm.Gather([send_buff, MPI.DOUBLE], [receive_buff, MPI.DOUBLE], root=0) + comm.Gather([send_buff, MPI.DOUBLE], [receive_buff, MPI.DOUBLE], + root=0) # Unpack the buffers all_values = [] @@ -177,7 +178,11 @@ def get_global_value_range(problem, 
key): return global_min, global_max -def print_global_value_range(problem, key, header, scale=1.0, precision='%1.4f'): +def print_global_value_range(problem, + key, + header, + scale=1.0, + precision='%1.4f'): """ Print the range of a target value across all processes @@ -222,7 +227,11 @@ def set_wrapper_to_value(problem, key, value): local_values[...] = value -def set_wrapper_with_function(problem, target_key, input_keys, fn, target_index=-1): +def set_wrapper_with_function(problem, + target_key, + input_keys, + fn, + target_index=-1): """ Set the value of a wrapper using a function @@ -250,10 +259,13 @@ def set_wrapper_with_function(problem, target_key, input_keys, fn, target_index= elif (len(M) == 1): # Apply the function output across each of the target dims - local_target[...] = np.tile(np.expand_dims(fn_output, axis=1), (1, N[1])) + local_target[...] = np.tile(np.expand_dims(fn_output, axis=1), + (1, N[1])) else: - raise Exception('Shape of function output %s is not compatible with target %s' % (str(M), str(N))) + raise Exception( + 'Shape of function output %s is not compatible with target %s' + % (str(M), str(N))) elif (len(M) == 1): if (len(N) == 2): # 2D target, with 1D output applied to a given index @@ -263,14 +275,20 @@ def set_wrapper_with_function(problem, target_key, input_keys, fn, target_index= # ND target, with 1D output tiled across intermediate indices expand_axes = tuple([ii for ii in range(1, len(N) - 1)]) tile_axes = tuple([1] + [ii for ii in N[1:-1]]) - local_target[..., target_index] = np.tile(np.expand_dims(fn_output, axis=expand_axes), tile_axes) + local_target[..., target_index] = np.tile( + np.expand_dims(fn_output, axis=expand_axes), tile_axes) else: - raise Exception('Shape of function output %s is not compatible with target %s (target axis=%i)' % - (str(M), str(N), target_index)) + raise Exception( + 'Shape of function output %s is not compatible with target %s (target axis=%i)' + % (str(M), str(N), target_index)) -def search_datastructure_wrappers_recursive(group, filters, matching_paths, level=0, group_path=[]): +def search_datastructure_wrappers_recursive(group, + filters, + matching_paths, + level=0, + group_path=[]): """ Recursively search the group and its children for wrappers that match the filters @@ -281,7 +299,9 @@ def search_datastructure_wrappers_recursive(group, filters, matching_paths, leve """ for wrapper in group.wrappers(): wrapper_path = str(wrapper).split()[0] - wrapper_test = group_path + [wrapper_path[wrapper_path.rfind('/') + 1:]] + wrapper_test = group_path + [ + wrapper_path[wrapper_path.rfind('/') + 1:] + ] if all(f in wrapper_test for f in filters): matching_paths.append('/'.join(wrapper_test)) @@ -291,7 +311,8 @@ def search_datastructure_wrappers_recursive(group, filters, matching_paths, leve filters, matching_paths, level=level + 1, - group_path=group_path + [sub_group_name]) + group_path=group_path + + [sub_group_name]) def get_matching_wrapper_path(problem, filters): @@ -322,7 +343,8 @@ def get_matching_wrapper_path(problem, filters): print('Error occured while looking for wrappers:') print('Filters: [%s]' % (', '.join(filters))) print('Matching wrappers: [%s]' % (', '.join(matching_paths))) - raise Exception('Search resulted in 0 or >1 wrappers mathching filters') + raise Exception( + 'Search resulted in 0 or >1 wrappers mathching filters') def run_queries(problem, records): @@ -344,12 +366,18 @@ def run_queries(problem, records): current_time = get_wrapper(problem, "Events/time") records[k]['history'].append(current_time * 
records[k]['scale']) else: - tmp = print_global_value_range(problem, k, records[k]['label'], scale=records[k]['scale']) + tmp = print_global_value_range(problem, + k, + records[k]['label'], + scale=records[k]['scale']) records[k]['history'].append(tmp) sys.stdout.flush() -def plot_history(records, output_root='.', save_figures=True, show_figures=True): +def plot_history(records, + output_root='.', + save_figures=True, + show_figures=True): """ Plot the time-histories for the records structure. Note: If figures are shown, the GEOSX process will be blocked until they are closed @@ -369,7 +397,7 @@ def plot_history(records, output_root='.', save_figures=True, show_figures=True) # Assemble values to plot t = np.array(records['time']['history']) x = np.array(records[k]['history']) - N = np.shape(x) # (time, min/max, dimensions) + N = np.shape(x) # (time, min/max, dimensions) # Add plots if (len(N) == 2): @@ -387,7 +415,10 @@ def plot_history(records, output_root='.', save_figures=True, show_figures=True) # Setup axes if (('axes' not in records[k]) or (len(fa.axes) == 0)): - records[k]['axes'] = [plt.subplot(rows, columns, ii + 1) for ii in range(0, N[2])] + records[k]['axes'] = [ + plt.subplot(rows, columns, ii + 1) + for ii in range(0, N[2]) + ] for ii in range(0, N[2]): ax = records[k]['axes'][ii] @@ -395,12 +426,14 @@ def plot_history(records, output_root='.', save_figures=True, show_figures=True) ax.plot(t, x[:, 0, ii], label='min') ax.plot(t, x[:, 1, ii], label='max') ax.set_xlabel(records['time']['label']) - ax.set_ylabel('%s (dim=%i)' % (records[k]['label'], ii)) + ax.set_ylabel('%s (dim=%i)' % + (records[k]['label'], ii)) plt.legend(loc=2) records[k]['fhandle'].tight_layout(pad=1.5) if save_figures: fname = k[k.rfind('/') + 1:] - plt.savefig('%s/%s.png' % (output_root, fname), format='png') + plt.savefig('%s/%s.png' % (output_root, fname), + format='png') if show_figures: plt.show() diff --git a/timehistory_package/timehistory/plot_time_history.py b/timehistory_package/timehistory/plot_time_history.py index 839288c..311822e 100644 --- a/timehistory_package/timehistory/plot_time_history.py +++ b/timehistory_package/timehistory/plot_time_history.py @@ -17,7 +17,11 @@ def isiterable(obj): return True -def getHistorySeries(database, variable, setname, indices=None, components=None): +def getHistorySeries(database, + variable, + setname, + indices=None, + components=None): """ Retrieve a series of time history structures suitable for plotting in addition to the specific set index and component for the time series @@ -35,22 +39,32 @@ def getHistorySeries(database, variable, setname, indices=None, components=None) set_regex = re.compile(variable + '(.*?)', re.IGNORECASE) if setname is not None: set_regex = re.compile(variable + '\s*' + str(setname), re.IGNORECASE) - time_regex = re.compile('Time', re.IGNORECASE) # need to make this per-set, thought that was in already? + time_regex = re.compile( + 'Time', re.IGNORECASE + ) # need to make this per-set, thought that was in already? 
set_match = list(filter(set_regex.match, database.keys())) time_match = list(filter(time_regex.match, database.keys())) if len(set_match) == 0: - print(f"Error: can't locate time history data for variable/set described by regex {set_regex.pattern}") + print( + f"Error: can't locate time history data for variable/set described by regex {set_regex.pattern}" + ) return None if len(time_match) == 0: - print(f"Error: can't locate time history data for set time variable described by regex {time_regex.pattern}") + print( + f"Error: can't locate time history data for set time variable described by regex {time_regex.pattern}" + ) return None if len(set_match) > 1: - print(f"Warning: variable/set specification matches multiple datasets: {', '.join(set_match)}") + print( + f"Warning: variable/set specification matches multiple datasets: {', '.join(set_match)}" + ) if len(time_match) > 1: - print(f"Warning: set specification matches multiple time datasets: {', '.join(time_match)}") + print( + f"Warning: set specification matches multiple time datasets: {', '.join(time_match)}" + ) set_match = set_match[0] time_match = time_match[0] @@ -67,10 +81,15 @@ def getHistorySeries(database, variable, setname, indices=None, components=None) if type(indices) is int: indices = list(indices) if isiterable(indices): - oob_idxs = list(filter(lambda idx: not 0 <= idx < data_series.shape[1], indices)) + oob_idxs = list( + filter(lambda idx: not 0 <= idx < data_series.shape[1], + indices)) if len(oob_idxs) > 0: - print(f"Error: The specified indices: ({', '.join(oob_idxs)}) " + "\n\t" + - f" are out of the dataset index range: [0,{data_series.shape[1]})") + print( + f"Error: The specified indices: ({', '.join(oob_idxs)}) " + + "\n\t" + + f" are out of the dataset index range: [0,{data_series.shape[1]})" + ) indices = list(set(indices) - set(oob_idxs)) else: print(f"Error: unsupported indices type: {type(indices)}") @@ -81,29 +100,42 @@ def getHistorySeries(database, variable, setname, indices=None, components=None) if type(components) is int: components = list(components) if isiterable(components): - oob_comps = list(filter(lambda comp: not 0 <= comp < data_series.shape[2], components)) + oob_comps = list( + filter(lambda comp: not 0 <= comp < data_series.shape[2], + components)) if len(oob_comps) > 0: - print(f"Error: The specified components: ({', '.join(oob_comps)}) " + "\n\t" + - " is out of the dataset component range: [0,{data_series.shape[1]})") + print( + f"Error: The specified components: ({', '.join(oob_comps)}) " + + "\n\t" + + " is out of the dataset component range: [0,{data_series.shape[1]})" + ) components = list(set(components) - set(oob_comps)) else: print(f"Error: unsupported components type: {type(components)}") else: components = range(data_series.shape[2]) - return [(time_series[:, 0], data_series[:, idx, comp], idx, comp) for idx in indices for comp in components] + return [(time_series[:, 0], data_series[:, idx, comp], idx, comp) + for idx in indices for comp in components] def commandLinePlotGen(): parser = argparse.ArgumentParser( - description="A script that parses geosx HDF5 time-history files and produces time-history plots using matplotlib" + description= + "A script that parses geosx HDF5 time-history files and produces time-history plots using matplotlib" ) - parser.add_argument("filename", metavar="history_file", type=str, help="The time history file to parse") - - parser.add_argument("variable", - metavar="variable_name", + parser.add_argument("filename", + metavar="history_file", 
type=str, - help="Which time-history variable collected by GEOSX to generate a plot file for.") + help="The time history file to parse") + + parser.add_argument( + "variable", + metavar="variable_name", + type=str, + help= + "Which time-history variable collected by GEOSX to generate a plot file for." + ) parser.add_argument( "--sets", @@ -116,19 +148,23 @@ def commandLinePlotGen(): "Which index set of time-history data collected by GEOSX to generate a plot file for, may be specified multiple times with different indices/components for each set." ) - parser.add_argument("--indices", - metavar="index", - type=int, - default=[], - nargs="+", - help="An optional list of specific indices in the most-recently specified set.") + parser.add_argument( + "--indices", + metavar="index", + type=int, + default=[], + nargs="+", + help= + "An optional list of specific indices in the most-recently specified set." + ) - parser.add_argument("--components", - metavar="int", - type=int, - default=[], - nargs="+", - help="An optional list of specific variable components") + parser.add_argument( + "--components", + metavar="int", + type=int, + default=[], + nargs="+", + help="An optional list of specific variable components") args = parser.parse_args() result = 0 @@ -139,11 +175,13 @@ def commandLinePlotGen(): else: with h5w(args.filename, mode='r') as database: for setname in args.sets: - ds = getHistorySeries(database, args.variable, setname, args.indices, args.components) + ds = getHistorySeries(database, args.variable, setname, + args.indices, args.components) if ds is None: result = -1 break - figname = args.variable + ("_" + setname if setname is not None else "") + figname = args.variable + ("_" + setname + if setname is not None else "") fig, ax = plt.subplots() ax.set_title(figname) for d in ds: diff --git a/yapf.cfg b/yapf.cfg new file mode 100644 index 0000000..5f77c30 --- /dev/null +++ b/yapf.cfg @@ -0,0 +1,5 @@ +[style] +based_on_style = pep8 +spaces_before_comment = 4 +split_before_logical_operator = true +column_limit = 120
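Note (not part of the patch): the yapf.cfg added above is the style definition the rest of this patch was formatted against. The snippet below is a minimal sketch of how that style file can be exercised from Python; the sample string, the names sample/formatted/changed, and the assumption that yapf is installed and that yapf.cfg sits in the working directory are all illustrative. Recent yapf releases expose FormatCode, which returns the reformatted text together with a flag saying whether anything changed, which is the basis of a formatting check.

# Illustrative sketch only: check a snippet against the [style] section in yapf.cfg.
# Assumes yapf is installed and yapf.cfg is present in the current directory.
from yapf.yapflib.yapf_api import FormatCode

sample = "def add( a,b ):\n    return a+b\n"    # deliberately mis-formatted input

# FormatCode applies the configured style and reports whether the input needed changes.
formatted, changed = FormatCode(sample, style_config='yapf.cfg')

print(formatted)
print('needs reformatting' if changed else 'already conforms to yapf.cfg')

The command-line equivalent would be something along the lines of yapf --recursive --style yapf.cfg --diff . to check a tree, or --in-place to rewrite files.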