diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 132f41d..df0729e 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -30,10 +30,10 @@ jobs: python -m pip install --upgrade pip python -m pip install pytest yapf toml python -m pip install . - # - name: Lint with yapf - # working-directory: ./${{ matrix.package-name }} - # run: | - # yapf -r --diff . + - name: Lint with yapf + working-directory: ./${{ matrix.package-name }} + run: | + yapf -r --diff . --style ../.style.yapf # - name: Test with pytest # working-directory: ./${{ matrix.package-name }} # run: | diff --git a/.style.yapf b/.style.yapf new file mode 100644 index 0000000..90bd62e --- /dev/null +++ b/.style.yapf @@ -0,0 +1,6 @@ +[style] +based_on_style = pep8 +spaces_before_comment = 2 +split_before_logical_operator = true +column_limit = 120 +space_inside_brackets = true diff --git a/docs/conf.py b/docs/conf.py index 06f479f..f1cb479 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -18,16 +18,10 @@ # Add python modules to be documented python_root = '..' -python_modules = ('geosx_mesh_tools_package', - 'geosx_xml_tools_package', - 'geosx_mesh_doctor', - 'geos_ats_package', - 'hdf5_wrapper_package', - 'pygeosx_tools_package', - 'timehistory_package') +python_modules = ( 'geosx_mesh_tools_package', 'geosx_xml_tools_package', 'geosx_mesh_doctor', 'geos_ats_package', + 'hdf5_wrapper_package', 'pygeosx_tools_package', 'timehistory_package' ) for m in python_modules: - sys.path.insert(0, os.path.abspath(os.path.join(python_root, m))) - + sys.path.insert( 0, os.path.abspath( os.path.join( python_root, m ) ) ) # -- Project information ----------------------------------------------------- @@ -40,7 +34,6 @@ # The full version, including alpha/beta/rc tags release = u'' - # -- General configuration --------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. @@ -51,18 +44,11 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx_design', - 'sphinx.ext.todo', - 'sphinx.ext.autodoc', - 'sphinx.ext.doctest', - 'sphinx.ext.imgmath', - 'sphinxarg.ext', - 'sphinx.ext.napoleon', - 'sphinxcontrib.programoutput' + 'sphinx_design', 'sphinx.ext.todo', 'sphinx.ext.autodoc', 'sphinx.ext.doctest', 'sphinx.ext.imgmath', + 'sphinxarg.ext', 'sphinx.ext.napoleon', 'sphinxcontrib.programoutput' ] - -autodoc_mock_imports = ["pygeosx", "pylvarray", "meshio", "lxml", "mpi4py", "h5py", "ats", "scipy"] +autodoc_mock_imports = [ "pygeosx", "pylvarray", "meshio", "lxml", "mpi4py", "h5py", "ats", "scipy" ] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: @@ -83,14 +69,13 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path . -exclude_patterns = [u'_build', 'Thumbs.db', '.DS_Store', 'cmake/*'] +exclude_patterns = [ u'_build', 'Thumbs.db', '.DS_Store', 'cmake/*' ] todo_include_todos = True # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' - # -- Theme options ---------------------------------------------- extensions += [ 'sphinx_rtd_theme', @@ -98,19 +83,14 @@ html_theme = "sphinx_rtd_theme" -html_theme_options = { - 'navigation_depth': -1, - 'collapse_navigation': False -} +html_theme_options = { 'navigation_depth': -1, 'collapse_navigation': False } -html_static_path = ['./_static'] +html_static_path = [ './_static' ] html_css_files = [ 'theme_overrides.css', ] - # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. htmlhelp_basename = 'geosPythonPackagesDoc' - diff --git a/geos_ats_package/geos_ats/__init__.py b/geos_ats_package/geos_ats/__init__.py index 51abd36..8c72442 100644 --- a/geos_ats_package/geos_ats/__init__.py +++ b/geos_ats_package/geos_ats/__init__.py @@ -3,4 +3,4 @@ # Add the machines module to the ats.atsMachines submodule, # So that ats can find our custom definitions at runtime -sys.modules['ats.atsMachines.machines'] = machines +sys.modules[ 'ats.atsMachines.machines' ] = machines diff --git a/geos_ats_package/geos_ats/command_line_parsers.py b/geos_ats_package/geos_ats/command_line_parsers.py index 95332d4..1e9b44c 100644 --- a/geos_ats_package/geos_ats/command_line_parsers.py +++ b/geos_ats_package/geos_ats/command_line_parsers.py @@ -36,49 +36,58 @@ def build_command_line_parser(): - parser = argparse.ArgumentParser(description="Runs GEOS integrated tests") + parser = argparse.ArgumentParser( description="Runs GEOS integrated tests" ) - parser.add_argument("geos_bin_dir", type=str, help="GEOS binary directory.") + parser.add_argument( "geos_bin_dir", type=str, help="GEOS binary directory." ) - parser.add_argument("-w", "--workingDir", type=str, help="Initial working directory") + parser.add_argument( "-w", "--workingDir", type=str, help="Initial working directory" ) - action_names = ','.join(action_options.keys()) - parser.add_argument("-a", "--action", type=str, default="run", help=f"Test actions options ({action_names})") + action_names = ','.join( action_options.keys() ) + parser.add_argument( "-a", "--action", type=str, default="run", help=f"Test actions options ({action_names})" ) - check_names = ','.join(check_options.keys()) - parser.add_argument("-c", "--check", type=str, default="all", help=f"Test check options ({check_names})") + check_names = ','.join( check_options.keys() ) + parser.add_argument( "-c", "--check", type=str, default="all", help=f"Test check options ({check_names})" ) - verbosity_names = ','.join(verbose_options.keys()) - parser.add_argument("-v", "--verbose", type=str, default="info", help=f"Log verbosity options ({verbosity_names})") + verbosity_names = ','.join( verbose_options.keys() ) + parser.add_argument( "-v", + "--verbose", + type=str, + default="info", + help=f"Log verbosity options ({verbosity_names})" ) - parser.add_argument("-d", "--detail", action="store_true", default=False, help="Show detailed action/check options") + parser.add_argument( "-d", + "--detail", + action="store_true", + default=False, + help="Show detailed action/check options" ) - parser.add_argument("-i", "--info", action="store_true", default=False, help="Info on various topics") + parser.add_argument( "-i", "--info", action="store_true", default=False, help="Info on various topics" ) - parser.add_argument("-r", - "--restartCheckOverrides", - nargs='+', - action='append', - help='Restart check parameter override (name value)', - default=[]) + parser.add_argument( "-r", + "--restartCheckOverrides", + nargs='+', + action='append', + help='Restart check parameter override (name value)', + default=[] ) - parser.add_argument("--salloc", - default=True, - help="Used by the chaosM machine to first allocate nodes with salloc, before running the tests") + parser.add_argument( + "--salloc", + default=True, + help="Used by the chaosM machine to first allocate nodes with salloc, before running the tests" ) parser.add_argument( "--sallocoptions", type=str, default="", - help="Used to override all command-line options for salloc. No other options with be used or added.") + help="Used to override all command-line options for salloc. No other options with be used or added." ) - parser.add_argument("--ats", nargs='+', default=[], action="append", help="pass arguments to ats") + parser.add_argument( "--ats", nargs='+', default=[], action="append", help="pass arguments to ats" ) - parser.add_argument("--machine", default=None, help="name of the machine") + parser.add_argument( "--machine", default=None, help="name of the machine" ) - parser.add_argument("--machine-dir", default=None, help="Search path for machine definitions") + parser.add_argument( "--machine-dir", default=None, help="Search path for machine definitions" ) - parser.add_argument("-l", "--logs", type=str, default=None) + parser.add_argument( "-l", "--logs", type=str, default=None ) parser.add_argument( "--failIfTestsFail", @@ -87,14 +96,14 @@ def build_command_line_parser(): help="geos_ats normally exits with 0. This will cause it to exit with an error code if there was a failed test." ) - parser.add_argument("-n", "-N", "--numNodes", type=int, default="2") + parser.add_argument( "-n", "-N", "--numNodes", type=int, default="2" ) - parser.add_argument("ats_targets", type=str, nargs='*', help="ats files or directories.") + parser.add_argument( "ats_targets", type=str, nargs='*', help="ats files or directories." ) return parser -def parse_command_line_arguments(args): +def parse_command_line_arguments( args ): parser = build_command_line_parser() options, unkown_args = parser.parse_known_args() exit_flag = False @@ -103,7 +112,7 @@ def parse_command_line_arguments(args): check = options.check if check not in check_options: print( - f"Selected check option ({check}) not recognized. Try running with --help/--details for more information") + f"Selected check option ({check}) not recognized. Try running with --help/--details for more information" ) exit_flag = True action = options.action @@ -115,15 +124,15 @@ def parse_command_line_arguments(args): verbose = options.verbose if verbose not in verbose_options: - print(f"Selected verbose option ({verbose}) not recognized") + print( f"Selected verbose option ({verbose}) not recognized" ) exit_flag = True # Print detailed information if options.detail: - for option_type, details in zip(['action', 'check'], [action_options, check_options]): - print(f'\nAvailable {option_type} options:') + for option_type, details in zip( [ 'action', 'check' ], [ action_options, check_options ] ): + print( f'\nAvailable {option_type} options:' ) for k, v in details.items(): - print(f' {k}: {v}') + print( f' {k}: {v}' ) exit_flag = True if exit_flag: @@ -132,20 +141,20 @@ def parse_command_line_arguments(args): return options -def patch_parser(parser): +def patch_parser( parser ): - def add_option_patch(*xargs, **kwargs): + def add_option_patch( *xargs, **kwargs ): """ Convert type string to actual type instance """ - tmp = kwargs.get('type', str) - type_map = {'string': str} - if isinstance(tmp, str): + tmp = kwargs.get( 'type', str ) + type_map = { 'string': str } + if isinstance( tmp, str ): if tmp in type_map: - tmp = type_map[tmp] + tmp = type_map[ tmp ] else: - tmp = locate(tmp) - kwargs['type'] = tmp - parser.add_argument(*xargs, **kwargs) + tmp = locate( tmp ) + kwargs[ 'type' ] = tmp + parser.add_argument( *xargs, **kwargs ) parser.add_option = add_option_patch diff --git a/geos_ats_package/geos_ats/common_utilities.py b/geos_ats_package/geos_ats/common_utilities.py index 0ad4ad8..3e32db1 100644 --- a/geos_ats_package/geos_ats/common_utilities.py +++ b/geos_ats_package/geos_ats/common_utilities.py @@ -10,88 +10,88 @@ # Common code for displaying information to the user. ################################################################################ -logger = logging.getLogger('geos_ats') +logger = logging.getLogger( 'geos_ats' ) -def Error(msg): - raise RuntimeError("Error: %s" % msg) +def Error( msg ): + raise RuntimeError( "Error: %s" % msg ) -def Log(msg): - import ats # type: ignore[import] +def Log( msg ): + import ats # type: ignore[import] testmode = False try: - testmode = ats.tests.AtsTest.getOptions().get("testmode") + testmode = ats.tests.AtsTest.getOptions().get( "testmode" ) except AttributeError as e: - logger.debug(e) + logger.debug( e ) if testmode: - ats.log("ALEATS: " + msg, echo=True) + ats.log( "ALEATS: " + msg, echo=True ) else: - ats.log(msg, echo=True) + ats.log( msg, echo=True ) -class TextTable(object): +class TextTable( object ): - def __init__(self, columns): + def __init__( self, columns ): self.table = [] self.sep = " : " self.indent = " " self.columns = columns - self.colmax = [None] * columns + self.colmax = [ None ] * columns self.maxwidth = self._getwidth() self.rowbreak = None self.rowbreakstyle = " " - def _getwidth(self): + def _getwidth( self ): maxwidth = 100 if os.name == "posix": try: - sttyout = subprocess.Popen(["stty", "size"], stdout=subprocess.PIPE).communicate()[0] - maxwidth = int(sttyout.split()[1]) + sttyout = subprocess.Popen( [ "stty", "size" ], stdout=subprocess.PIPE ).communicate()[ 0 ] + maxwidth = int( sttyout.split()[ 1 ] ) except: # If the stty size approach does not work, the use a default maxwidth - logger.debug("Using default maxwidth") + logger.debug( "Using default maxwidth" ) return maxwidth - def setHeader(self, *row): - assert (len(row) == self.columns) - self.table.insert(0, row) - self.table.insert(1, None) + def setHeader( self, *row ): + assert ( len( row ) == self.columns ) + self.table.insert( 0, row ) + self.table.insert( 1, None ) - def addRowBreak(self): - self.table.append(None) + def addRowBreak( self ): + self.table.append( None ) - def addRow(self, *row): - assert (len(row) == self.columns) - self.table.append(row) + def addRow( self, *row ): + assert ( len( row ) == self.columns ) + self.table.append( row ) - def setColMax(self, colindex, max): - self.colmax[colindex] = max + def setColMax( self, colindex, max ): + self.colmax[ colindex ] = max - def printTable(self, outfile=sys.stdout): + def printTable( self, outfile=sys.stdout ): table_str = '' - if len(self.table) == 0: + if len( self.table ) == 0: return # find the max column sizes colWidth = [] - for i in range(self.columns): - colWidth.append(max([len(str(row[i])) for row in self.table if row is not None])) + for i in range( self.columns ): + colWidth.append( max( [ len( str( row[ i ] ) ) for row in self.table if row is not None ] ) ) # adjust the colWidths down if colmax is step - for i in range(self.columns): - if self.colmax[i] is not None: - if colWidth[i] > self.colmax[i]: - colWidth[i] = self.colmax[i] + for i in range( self.columns ): + if self.colmax[ i ] is not None: + if colWidth[ i ] > self.colmax[ i ]: + colWidth[ i ] = self.colmax[ i ] # last column is floating - total = sum(colWidth) + self.columns * (1 + len(self.sep)) + len(self.indent) + total = sum( colWidth ) + self.columns * ( 1 + len( self.sep ) ) + len( self.indent ) if total > self.maxwidth: - colWidth[-1] = max(10, self.maxwidth - (total - colWidth[-1])) + colWidth[ -1 ] = max( 10, self.maxwidth - ( total - colWidth[ -1 ] ) ) # output the table rowbreakindex = 0 @@ -99,14 +99,14 @@ def printTable(self, outfile=sys.stdout): # row break controls. # if row is None then this is a break - addBreak = (row is None) or (self.rowbreak and rowbreakindex > 0 and rowbreakindex % self.rowbreak == 0) + addBreak = ( row is None ) or ( self.rowbreak and rowbreakindex > 0 and rowbreakindex % self.rowbreak == 0 ) if addBreak: table_str += self.indent - for i in range(self.columns): + for i in range( self.columns ): if i < self.columns - 1: table_str += f"{self.rowbreakstyle * colWidth[i]}{self.sep}" else: - table_str += self.rowbreakstyle * colWidth[i] + table_str += self.rowbreakstyle * colWidth[ i ] table_str += '\n' if row is None: @@ -118,27 +118,27 @@ def printTable(self, outfile=sys.stdout): # determine how many lines are needed by each column of this row. lines = [] - for i in range(self.columns): - if isinstance(row[i], str): - drow = textwrap.dedent(row[i]) + for i in range( self.columns ): + if isinstance( row[ i ], str ): + drow = textwrap.dedent( row[ i ] ) else: - drow = str(row[i]) + drow = str( row[ i ] ) if i == self.columns - 1: - lines.append(textwrap.wrap(drow, colWidth[i], break_long_words=False)) + lines.append( textwrap.wrap( drow, colWidth[ i ], break_long_words=False ) ) else: - lines.append(textwrap.wrap(drow, colWidth[i], break_long_words=True)) + lines.append( textwrap.wrap( drow, colWidth[ i ], break_long_words=True ) ) - maxlines = max([len(x) for x in lines]) + maxlines = max( [ len( x ) for x in lines ] ) # output the row - for j in range(maxlines): + for j in range( maxlines ): table_str += self.indent - for i in range(self.columns): - if len(lines[i]) > j: - entry = lines[i][j].ljust(colWidth[i]) + for i in range( self.columns ): + if len( lines[ i ] ) > j: + entry = lines[ i ][ j ].ljust( colWidth[ i ] ) else: - entry = " ".ljust(colWidth[i]) + entry = " ".ljust( colWidth[ i ] ) if i < self.columns - 1: table_str += f"{entry}{self.sep}" @@ -147,106 +147,108 @@ def printTable(self, outfile=sys.stdout): table_str += '\n' - outfile.write(table_str) + outfile.write( table_str ) -class InfoTopic(object): +class InfoTopic( object ): - def __init__(self, topic, outfile=sys.stdout): + def __init__( self, topic, outfile=sys.stdout ): self.topic = topic self.subtopics = [] self.outfile = outfile - def addTopic(self, topic, brief, function): - self.subtopics.append((topic, brief, function)) + def addTopic( self, topic, brief, function ): + self.subtopics.append( ( topic, brief, function ) ) - def startBanner(self): - self.outfile.write("=" * 80 + '\n') - self.outfile.write(self.topic.center(80)) - self.outfile.write("\n" + "=" * 80 + '\n') + def startBanner( self ): + self.outfile.write( "=" * 80 + '\n' ) + self.outfile.write( self.topic.center( 80 ) ) + self.outfile.write( "\n" + "=" * 80 + '\n' ) - def endBanner(self): - self.outfile.write("." * 80 + '\n') + def endBanner( self ): + self.outfile.write( "." * 80 + '\n' ) - def findTopic(self, topicName): + def findTopic( self, topicName ): for topic in self.subtopics: - if topic[0] == topicName: + if topic[ 0 ] == topicName: return topic return None - def displayMenu(self): + def displayMenu( self ): self.startBanner() - table = TextTable(3) - for i, topic in enumerate(self.subtopics): - table.addRow(i, topic[0], topic[1]) + table = TextTable( 3 ) + for i, topic in enumerate( self.subtopics ): + table.addRow( i, topic[ 0 ], topic[ 1 ] ) - table.addRow(i + 1, "exit", "") + table.addRow( i + 1, "exit", "" ) table.printTable() import ats - if ats.tests.AtsTest.getOptions().get("testmode"): + if ats.tests.AtsTest.getOptions().get( "testmode" ): return while True: - logger.info("Enter a topic: ") + logger.info( "Enter a topic: " ) sys.stdout.flush() try: line = sys.stdin.readline() except KeyboardInterrupt as e: - logger.debug(e) + logger.debug( e ) return None value = line.strip() - topic = self.findTopic(value) + topic = self.findTopic( value ) if topic: return topic try: - index = int(value) - if index >= 0 and index < len(self.subtopics): - return self.subtopics[index] - if index == len(self.subtopics): + index = int( value ) + if index >= 0 and index < len( self.subtopics ): + return self.subtopics[ index ] + if index == len( self.subtopics ): return None except ValueError as e: - logger.debug(e) + logger.debug( e ) - def process(self, args): + def process( self, args ): - if len(args) == 0: + if len( args ) == 0: topic = self.displayMenu() if topic is not None: - topic[2]() + topic[ 2 ]() else: - topicName = args[0] - topic = self.findTopic(topicName) + topicName = args[ 0 ] + topic = self.findTopic( topicName ) if topic: - topic[2](*args[1:]) + topic[ 2 ]( *args[ 1: ] ) else: - logger.warning(f"unknown topic: {topicName}") + logger.warning( f"unknown topic: {topicName}" ) -def removeLogDirectories(dir): +def removeLogDirectories( dir ): # look for subdirs containing 'ats.log' and 'geos_ats.config' # look for symlinks that point to such a directory - files = os.listdir(dir) + files = os.listdir( dir ) deldir = [] for f in files: - ff = os.path.join(dir, f) - if os.path.isdir(ff) and not os.path.islink(ff): + ff = os.path.join( dir, f ) + if os.path.isdir( ff ) and not os.path.islink( ff ): tests = [ - all([os.path.exists(os.path.join(ff, "ats.log")), - os.path.exists(os.path.join(ff, "geos_ats.config"))]), - f.find("TestLogs.") == 0 + all( [ + os.path.exists( os.path.join( ff, "ats.log" ) ), + os.path.exists( os.path.join( ff, "geos_ats.config" ) ) + ] ), + f.find( "TestLogs." ) == 0 ] - if any(tests): - deldir.append(ff) - shutil.rmtree(ff) + if any( tests ): + deldir.append( ff ) + shutil.rmtree( ff ) for f in files: - ff = os.path.join(dir, f) - if os.path.islink(ff): - pointsto = os.path.realpath(ff) + ff = os.path.join( dir, f ) + if os.path.islink( ff ): + pointsto = os.path.realpath( ff ) if pointsto in deldir: - os.remove(ff) + os.remove( ff ) diff --git a/geos_ats_package/geos_ats/configuration_record.py b/geos_ats_package/geos_ats/configuration_record.py index bfb721e..78f1ab8 100644 --- a/geos_ats_package/geos_ats/configuration_record.py +++ b/geos_ats_package/geos_ats/configuration_record.py @@ -10,12 +10,12 @@ ################################################################################ # Get the active logger instance -logger = logging.getLogger('geos_ats') +logger = logging.getLogger( 'geos_ats' ) -class ConfigItem(object): +class ConfigItem( object ): - def __init__(self, name, type, default, doc, public): + def __init__( self, name, type, default, doc, public ): self.name = name self.type = type self.default = default @@ -24,258 +24,259 @@ def __init__(self, name, type, default, doc, public): self.public = public -class Config(object): +class Config( object ): - def __init__(self): - self.__dict__["_items"] = {} + def __init__( self ): + self.__dict__[ "_items" ] = {} - def set(self, name, value): + def set( self, name, value ): # error checking - item = self._items[name] + item = self._items[ name ] try: if item.type == str: - value = item.type(value) + value = item.type( value ) else: - if isinstance(value, str): - value = item.type(eval(value)) + if isinstance( value, str ): + value = item.type( eval( value ) ) else: - value = item.type(value) + value = item.type( value ) except ValueError: - Error("Attempted to set config.%s (which is %s) with %s" % (name, str(item.type), str(value))) + Error( "Attempted to set config.%s (which is %s) with %s" % ( name, str( item.type ), str( value ) ) ) - item.value = item.type(value) + item.value = item.type( value ) - def copy_values(self, target): - logger.debug("Copying command line options to config:") - target_dict = vars(target) + def copy_values( self, target ): + logger.debug( "Copying command line options to config:" ) + target_dict = vars( target ) for k in self._items.keys(): if k in target_dict: - logger.debug(f" {k}: {target_dict[k]}") - self.set(k, target_dict[k]) + logger.debug( f" {k}: {target_dict[k]}" ) + self.set( k, target_dict[ k ] ) - def get(self, name): + def get( self, name ): # error checking - return self._items[name].value + return self._items[ name ].value - def add(self, name, type, default, doc, public=True): - item = ConfigItem(name, type, default, doc, public) - self._items[item.name] = item + def add( self, name, type, default, doc, public=True ): + item = ConfigItem( name, type, default, doc, public ) + self._items[ item.name ] = item - def checkname(self, name): + def checkname( self, name ): if name not in self.__dict__: - matches = difflib.get_close_matches(name, self._items.keys()) - if len(matches) == 0: - Error("Unknown config name: %s. " - "See 'geos_ats -i config' for the complete list." % (name)) + matches = difflib.get_close_matches( name, self._items.keys() ) + if len( matches ) == 0: + Error( "Unknown config name: %s. " + "See 'geos_ats -i config' for the complete list." % ( name ) ) else: - Error("Unknown config name: %s. " - "Perhaps you meant '%s'. " - "See 'geos_ats -i config' for the complete list." % (name, matches[0])) + Error( "Unknown config name: %s. " + "Perhaps you meant '%s'. " + "See 'geos_ats -i config' for the complete list." % ( name, matches[ 0 ] ) ) - def __setattr__(self, name, value): + def __setattr__( self, name, value ): if name in self._items: - self.set(name, value) + self.set( name, value ) else: - self.checkname(name) + self.checkname( name ) - def __getattr__(self, name): + def __getattr__( self, name ): if name in self._items: - return self._items[name].value + return self._items[ name ].value else: - self.checkname(name) + self.checkname( name ) # The global config object config = Config() # Global testTimings object -globalTestTimings = {} # type: ignore[var-annotated] +globalTestTimings = {} # type: ignore[var-annotated] # Depth of testconfig recursion configDepth = 0 -def infoConfigShow(public, outfile=sys.stdout): - topic = InfoTopic("config show", outfile) +def infoConfigShow( public, outfile=sys.stdout ): + topic = InfoTopic( "config show", outfile ) topic.startBanner() - import ats # type: ignore[import] + import ats # type: ignore[import] - keys = sorted(config._items.keys()) - table = TextTable(3) + keys = sorted( config._items.keys() ) + table = TextTable( 3 ) for k in keys: - item = config._items[k] - if (public and item.public) or (not public): + item = config._items[ k ] + if ( public and item.public ) or ( not public ): if item.default == item.value: diff = " " else: diff = "*" - table.addRow(item.name, diff, item.value) + table.addRow( item.name, diff, item.value ) - table.printTable(outfile) + table.printTable( outfile ) - cf = ats.tests.AtsTest.getOptions().get("configFile") - outfile.write(f"\nConfig file: {cf}") + cf = ats.tests.AtsTest.getOptions().get( "configFile" ) + outfile.write( f"\nConfig file: {cf}" ) - configOverride = ats.tests.AtsTest.getOptions().get("configOverride", {}) + configOverride = ats.tests.AtsTest.getOptions().get( "configOverride", {} ) if configOverride: - outfile.write("\nCommand line overrides:") - table = TextTable(1) + outfile.write( "\nCommand line overrides:" ) + table = TextTable( 1 ) for key, value in configOverride.items(): - table.addRow(key) - table.printTable(outfile) + table.addRow( key ) + table.printTable( outfile ) topic.endBanner() -def infoConfigDocumentation(public): +def infoConfigDocumentation( public ): - topic = InfoTopic("config doc") + topic = InfoTopic( "config doc" ) topic.startBanner() - keys = sorted(config._items.keys()) - table = TextTable(4) - table.addRow("[NAME]", "[TYPE]", "[DEFAULT]", "[DOC]") + keys = sorted( config._items.keys() ) + table = TextTable( 4 ) + table.addRow( "[NAME]", "[TYPE]", "[DEFAULT]", "[DOC]" ) for k in keys: - item = config._items[k] - if (public and item.public) or (not public): - table.addRow(item.name, item.type.__name__, item.default, item.doc) + item = config._items[ k ] + if ( public and item.public ) or ( not public ): + table.addRow( item.name, item.type.__name__, item.default, item.doc ) - table.colmax[2] = 20 + table.colmax[ 2 ] = 20 table.printTable() topic.endBanner() -def infoConfig(*args): +def infoConfig( *args ): - menu = InfoTopic("config") - menu.addTopic("show", "Show all the config options", lambda *x: infoConfigShow(True)) - menu.addTopic("doc", "Documentation for the config options", lambda *x: infoConfigDocumentation(True)) - menu.addTopic("showall", "Show all the config options (including the internal options)", - lambda: infoConfigShow(False)) - menu.addTopic("docall", "Documentation for the config options (including the internal options)", - lambda: infoConfigDocumentation(False)) - menu.process(args) + menu = InfoTopic( "config" ) + menu.addTopic( "show", "Show all the config options", lambda *x: infoConfigShow( True ) ) + menu.addTopic( "doc", "Documentation for the config options", lambda *x: infoConfigDocumentation( True ) ) + menu.addTopic( "showall", "Show all the config options (including the internal options)", + lambda: infoConfigShow( False ) ) + menu.addTopic( "docall", "Documentation for the config options (including the internal options)", + lambda: infoConfigDocumentation( False ) ) + menu.process( args ) -def initializeConfig(configFile, configOverride, options): +def initializeConfig( configFile, configOverride, options ): # determine the directory where geos_ats is located. Used to find # location of other programs. - geos_atsdir = os.path.realpath(os.path.dirname(__file__)) + geos_atsdir = os.path.realpath( os.path.dirname( __file__ ) ) # configfile - config.add("testbaseline_dir", str, "", "Base directory that contains all the baselines") + config.add( "testbaseline_dir", str, "", "Base directory that contains all the baselines" ) - config.add("geos_bin_dir", str, "", "Directory that contains 'geos' and related executables.") + config.add( "geos_bin_dir", str, "", "Directory that contains 'geos' and related executables." ) - config.add("userscript_path", str, "", - "Directory that contains scripts for testing, searched after test directory and executable_path.") + config.add( "userscript_path", str, "", + "Directory that contains scripts for testing, searched after test directory and executable_path." ) - config.add("clean_on_pass", bool, False, "If True, then after a TestCase passes, " - "all temporary files are removed.") + config.add( "clean_on_pass", bool, False, "If True, then after a TestCase passes, " + "all temporary files are removed." ) # geos options - config.add("geos_default_args", str, "-i", - "A string containing arguments that will always appear on the geos commandline") + config.add( "geos_default_args", str, "-i", + "A string containing arguments that will always appear on the geos commandline" ) # reporting - config.add("report_html", bool, True, "True if HTML formatted results will be generated with the report action") - config.add("report_html_file", str, "test_results.html", "Location to write the html report") - config.add("report_html_periodic", bool, True, "True to update the html file during the periodic reports") - config.add("browser_command", str, "firefox -no-remote", "Command to use to launch a browser to view html results") - config.add("browser", bool, False, "If True, then launch the browser_command to view the report_html_file") - config.add("report_doc_dir", str, os.path.normpath(os.path.join(geos_atsdir, "..", "doc")), - "Location to the test doc directory (used with html reports)") - config.add("report_doc_link", bool, True, "Link against docgen (used with html reports)") - config.add("report_doc_remake", bool, False, - "Remake test documentation, even if it already exists (used with html reports)") + config.add( "report_html", bool, True, "True if HTML formatted results will be generated with the report action" ) + config.add( "report_html_file", str, "test_results.html", "Location to write the html report" ) + config.add( "report_html_periodic", bool, True, "True to update the html file during the periodic reports" ) + config.add( "browser_command", str, "firefox -no-remote", + "Command to use to launch a browser to view html results" ) + config.add( "browser", bool, False, "If True, then launch the browser_command to view the report_html_file" ) + config.add( "report_doc_dir", str, os.path.normpath( os.path.join( geos_atsdir, "..", "doc" ) ), + "Location to the test doc directory (used with html reports)" ) + config.add( "report_doc_link", bool, True, "Link against docgen (used with html reports)" ) + config.add( "report_doc_remake", bool, False, + "Remake test documentation, even if it already exists (used with html reports)" ) - config.add("report_text", bool, True, "True if you want text results to be generated with the report action") - config.add("report_text_file", str, "test_results.txt", "Location to write the text report") - config.add("report_text_echo", bool, True, "If True, echo the report to stdout") - config.add("report_wait", bool, False, "Wait until all tests are complete before reporting") + config.add( "report_text", bool, True, "True if you want text results to be generated with the report action" ) + config.add( "report_text_file", str, "test_results.txt", "Location to write the text report" ) + config.add( "report_text_echo", bool, True, "If True, echo the report to stdout" ) + config.add( "report_wait", bool, False, "Wait until all tests are complete before reporting" ) - config.add("report_ini", bool, True, "True if you want ini results to be generated with the report action") - config.add("report_ini_file", str, "test_results.ini", "Location to write the ini report") + config.add( "report_ini", bool, True, "True if you want ini results to be generated with the report action" ) + config.add( "report_ini_file", str, "test_results.ini", "Location to write the ini report" ) - config.add("report_notations", type([]), [], "Lines of text that are inserted into the reports.") + config.add( "report_notations", type( [] ), [], "Lines of text that are inserted into the reports." ) - config.add("report_notbuilt_regexp", str, "(not built into this version)", - "Regular expression that must appear in output to indicate that feature is not built.") + config.add( "report_notbuilt_regexp", str, "(not built into this version)", + "Regular expression that must appear in output to indicate that feature is not built." ) - config.add("checkmessages_always_ignore_regexp", type([]), ["not available in this version"], - "Regular expression to ignore in all checkmessages steps.") + config.add( "checkmessages_always_ignore_regexp", type( [] ), [ "not available in this version" ], + "Regular expression to ignore in all checkmessages steps." ) - config.add("checkmessages_never_ignore_regexp", type([]), ["not yet implemented"], - "Regular expression to not ignore in all checkmessages steps.") + config.add( "checkmessages_never_ignore_regexp", type( [] ), [ "not yet implemented" ], + "Regular expression to not ignore in all checkmessages steps." ) - config.add("report_timing", bool, False, "True if you want timing file to be generated with the report action") - config.add("report_timing_overwrite", bool, False, - "True if you want timing file to overwrite existing timing file rather than augment it") + config.add( "report_timing", bool, False, "True if you want timing file to be generated with the report action" ) + config.add( "report_timing_overwrite", bool, False, + "True if you want timing file to overwrite existing timing file rather than augment it" ) # timing and priority - config.add("priority", str, "equal", "Method of prioritization of tests: [\"equal\", \"processors\",\"timing\"]") - config.add("timing_file", str, "timing.txt", "Location of timing file") + config.add( "priority", str, "equal", "Method of prioritization of tests: [\"equal\", \"processors\",\"timing\"]" ) + config.add( "timing_file", str, "timing.txt", "Location of timing file" ) # batch - config.add("batch_dryrun", bool, False, - "If true, the batch jobs will not be submitted, but the batch scripts will be created") - config.add("batch_interactive", bool, False, "If true, the batch jobs will be treated as interactive jobs") - config.add("batch_bank", str, "", "The name of the bank to use") - config.add("batch_ppn", int, 0, "Number of processors per node") - config.add("batch_partition", str, "", "the batch partition, if not specified the default will be used.") - config.add("batch_queue", str, "pbatch", "the batch queue.") - config.add("batch_header", type([]), [], "Additional lines to add to the batch header") + config.add( "batch_dryrun", bool, False, + "If true, the batch jobs will not be submitted, but the batch scripts will be created" ) + config.add( "batch_interactive", bool, False, "If true, the batch jobs will be treated as interactive jobs" ) + config.add( "batch_bank", str, "", "The name of the bank to use" ) + config.add( "batch_ppn", int, 0, "Number of processors per node" ) + config.add( "batch_partition", str, "", "the batch partition, if not specified the default will be used." ) + config.add( "batch_queue", str, "pbatch", "the batch queue." ) + config.add( "batch_header", type( [] ), [], "Additional lines to add to the batch header" ) # retry - config.add("max_retry", int, 2, "Maximum number of times to retry failed runs.") - config.add("retry_err_regexp", str, - "(launch failed|Failure in initializing endpoint|channel initialization failed)", - "Regular expression that must appear in error log in order to retry.") + config.add( "max_retry", int, 2, "Maximum number of times to retry failed runs." ) + config.add( "retry_err_regexp", str, + "(launch failed|Failure in initializing endpoint|channel initialization failed)", + "Regular expression that must appear in error log in order to retry." ) # timeout - config.add("default_timelimit", str, "30m", - "This sets a default timelimit for all test steps which do not explicitly set a timelimit.") - config.add("override_timelimit", bool, False, - "If true, the value used for the default time limit will override the time limit for each test step.") + config.add( "default_timelimit", str, "30m", + "This sets a default timelimit for all test steps which do not explicitly set a timelimit." ) + config.add( "override_timelimit", bool, False, + "If true, the value used for the default time limit will override the time limit for each test step." ) # Decomposition Multiplication config.add( "decomp_factor", int, 1, "This sets the multiplication factor to be applied to the decomposition and number of procs of all eligible tests." ) - config.add("override_np", int, 0, "If non-zero, maximum number of processors to use for each test step.") + config.add( "override_np", int, 0, "If non-zero, maximum number of processors to use for each test step." ) # global environment variables - config.add("environment", dict, {}, "Additional environment variables to use during testing") + config.add( "environment", dict, {}, "Additional environment variables to use during testing" ) # General check config - for check in ("restartcheck", ): + for check in ( "restartcheck", ): config.add( "%s_enabled" % check, bool, True, "If True, this check has the possibility of running, " "but might not run depending on the '--check' command line option. " - "If False, this check will never be run.") + "If False, this check will never be run." ) - for check in ("hdf5_dif.py", ): - config.add("%s_script" % check, - str, - os.path.join(geos_atsdir, "helpers/%s.py" % check), - "Location to the %s frontend script." % check, - public=False) + for check in ( "hdf5_dif.py", ): + config.add( "%s_script" % check, + str, + os.path.join( geos_atsdir, "helpers/%s.py" % check ), + "Location to the %s frontend script." % check, + public=False ) # Checks: Restartcheck - config.add("restart_skip_missing", bool, False, "Determines whether new/missing fields are ignored") - config.add("restart_exclude_pattern", list, [], "A list of field names to ignore in restart files") + config.add( "restart_skip_missing", bool, False, "Determines whether new/missing fields are ignored" ) + config.add( "restart_exclude_pattern", list, [], "A list of field names to ignore in restart files" ) # Checks: Curvecheck - config.add("curvecheck_enabled", bool, True, "Determines whether curvecheck steps are run.") - config.add("curvecheck_tapestry_mode", bool, False, - "Provide temporary backwards compatibility for nighty and weekly suites until they are using geos_ats") - config.add("curvecheck_absolute", float, 1e-5, "absolute tolerance") - config.add("curvecheck_relative", float, 1e-5, "relative tolerance") + config.add( "curvecheck_enabled", bool, True, "Determines whether curvecheck steps are run." ) + config.add( "curvecheck_tapestry_mode", bool, False, + "Provide temporary backwards compatibility for nighty and weekly suites until they are using geos_ats" ) + config.add( "curvecheck_absolute", float, 1e-5, "absolute tolerance" ) + config.add( "curvecheck_relative", float, 1e-5, "relative tolerance" ) config.add( "curvecheck_failtype", str, "composite", "String that represents failure check. 'composite or relative' will fail curvecheck if either the composite error or relative error is too high. 'absolute and slope' will fail only if both the absolute error check and the slope error check fail. The default value is 'composite'." @@ -288,74 +289,74 @@ def initializeConfig(configFile, configOverride, options): "curvecheck_delete_temps", bool, True, "Curvecheck generates a number of temporary data files that are used to create the images for the html file. If this parameter is true, curvecheck will delete these temporary files. By default, the parameter is true." ) - config.add("gnuplot_executable", str, os.path.join("/usr", "bin", "gnuplot"), "Location to gnuplot") + config.add( "gnuplot_executable", str, os.path.join( "/usr", "bin", "gnuplot" ), "Location to gnuplot" ) # Rebaseline: config.add( "rebaseline_undo", bool, False, "If True, and the action is set to 'rebaseline'," - " this option will undo (revert) a previous rebaseline.") - config.add("rebaseline_ask", bool, True, "If True, the rebaseline will not occur until the user has anwered an" - " 'are you sure?' question") + " this option will undo (revert) a previous rebaseline." ) + config.add( "rebaseline_ask", bool, True, "If True, the rebaseline will not occur until the user has anwered an" + " 'are you sure?' question" ) # test modifier - config.add("testmodifier", str, "", "Name of a test modifier to apply") + config.add( "testmodifier", str, "", "Name of a test modifier to apply" ) # filters - config.add("filter_maxprocessors", int, -1, "If not -1, Run only those tests where the number of" - " processors is less than or equal to this value") + config.add( "filter_maxprocessors", int, -1, "If not -1, Run only those tests where the number of" + " processors is less than or equal to this value" ) # machines - config.add("machine_options", list, [], "Arguments to pass to the machine module") + config.add( "machine_options", list, [], "Arguments to pass to the machine module" ) - config.add("script_launch", int, 0, "Whether to launch scripts (and other serial steps) on compute nodes") - config.add("openmpi_install", str, "", "Location to the openmpi installation") - config.add("openmpi_maxprocs", int, 0, "Number of maximum processors openmpi") - config.add("openmpi_procspernode", int, 1, "Number of processors per node for openmpi") + config.add( "script_launch", int, 0, "Whether to launch scripts (and other serial steps) on compute nodes" ) + config.add( "openmpi_install", str, "", "Location to the openmpi installation" ) + config.add( "openmpi_maxprocs", int, 0, "Number of maximum processors openmpi" ) + config.add( "openmpi_procspernode", int, 1, "Number of processors per node for openmpi" ) config.add( "openmpi_precommand", str, "", "A string that will be" " prepended to each command. If the substring '%(np)s' is present," " it will be replaced by then number of processors required for the" " test. If the substring '%(J)s' is present, it will be replaced by" - " the unique name of the test.") - config.add("openmpi_args", str, "", "A string of arguments to mpirun") + " the unique name of the test." ) + config.add( "openmpi_args", str, "", "A string of arguments to mpirun" ) config.add( "openmpi_terminate", str, "", "A string that will be" " called upon abnormal termination. If the substring '%(J)s' is present," - " it will be replaced by the unique name of the test.") + " it will be replaced by the unique name of the test." ) - config.add("windows_mpiexe", str, "", "Location to mpiexe") - config.add("windows_nompi", bool, False, "Run executables on nompi processor") - config.add("windows_oversubscribe", int, 1, - "Multiplier to number of processors to allow oversubscription of processors") + config.add( "windows_mpiexe", str, "", "Location to mpiexe" ) + config.add( "windows_nompi", bool, False, "Run executables on nompi processor" ) + config.add( "windows_oversubscribe", int, 1, + "Multiplier to number of processors to allow oversubscription of processors" ) # populate the config with overrides from the command line for key, value in configOverride.items(): try: - setattr(config, key, value) + setattr( config, key, value ) except RuntimeError as e: # this allows for the testconfig file to define it's own # config options that can be overridden at the command line. - logger.debug(e) + logger.debug( e ) # Setup the config dict if configFile: - logger.warning("Config file override currently not available") + logger.warning( "Config file override currently not available" ) ## override the config file from the command line for key, value in configOverride.items(): - setattr(config, key, value) + setattr( config, key, value ) # validate prioritization scheme - if config.priority.lower().startswith("eq"): + if config.priority.lower().startswith( "eq" ): config.priority = "equal" - elif config.priority.lower().startswith("proc"): + elif config.priority.lower().startswith( "proc" ): config.priority = "processors" - elif config.priority.lower().startswith("tim"): + elif config.priority.lower().startswith( "tim" ): config.priority = "timing" else: - Error("priority '%s' is not valid" % config.priority) + Error( "priority '%s' is not valid" % config.priority ) ## environment variables for k, v in config.environment.items(): - os.environ[k] = v + os.environ[ k ] = v diff --git a/geos_ats_package/geos_ats/environment_setup.py b/geos_ats_package/geos_ats/environment_setup.py index 145e548..509373f 100644 --- a/geos_ats_package/geos_ats/environment_setup.py +++ b/geos_ats_package/geos_ats/environment_setup.py @@ -4,53 +4,57 @@ import argparse -def setup_ats(src_path, build_path, ats_xargs, ats_machine, ats_machine_dir): - bin_dir = os.path.join(build_path, "bin") - geos_ats_fname = os.path.join(bin_dir, "run_geos_ats") - ats_dir = os.path.abspath(os.path.join(src_path, "integratedTests", "tests", "allTests")) - test_path = os.path.join(build_path, "integratedTests") - link_path = os.path.join(test_path, "integratedTests") - run_script_fname = os.path.join(test_path, "geos_ats.sh") - log_dir = os.path.join(test_path, "TestResults") +def setup_ats( src_path, build_path, ats_xargs, ats_machine, ats_machine_dir ): + bin_dir = os.path.join( build_path, "bin" ) + geos_ats_fname = os.path.join( bin_dir, "run_geos_ats" ) + ats_dir = os.path.abspath( os.path.join( src_path, "integratedTests", "tests", "allTests" ) ) + test_path = os.path.join( build_path, "integratedTests" ) + link_path = os.path.join( test_path, "integratedTests" ) + run_script_fname = os.path.join( test_path, "geos_ats.sh" ) + log_dir = os.path.join( test_path, "TestResults" ) # Create a symbolic link to test directory - if os.path.islink(link_path): - print('integratedTests symlink already exists') + if os.path.islink( link_path ): + print( 'integratedTests symlink already exists' ) else: - os.symlink(ats_dir, link_path) + os.symlink( ats_dir, link_path ) # Build extra arguments that should be passed to ATS - joined_args = [' '.join(x) for x in ats_xargs] - ats_args = ' '.join([f'--ats {x}' for x in joined_args]) + joined_args = [ ' '.join( x ) for x in ats_xargs ] + ats_args = ' '.join( [ f'--ats {x}' for x in joined_args ] ) if ats_machine: ats_args += f' --machine {ats_machine}' if ats_machine_dir: ats_args += f' --machine-dir {ats_machine_dir}' # Write the bash script to run ats. - with open(run_script_fname, "w") as g: - g.write("#!/bin/bash\n") - g.write(f"{geos_ats_fname} {bin_dir} --workingDir {ats_dir} --logs {log_dir} {ats_args} \"$@\"\n") + with open( run_script_fname, "w" ) as g: + g.write( "#!/bin/bash\n" ) + g.write( f"{geos_ats_fname} {bin_dir} --workingDir {ats_dir} --logs {log_dir} {ats_args} \"$@\"\n" ) # Make the script executable - st = os.stat(run_script_fname) - os.chmod(run_script_fname, st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH) + st = os.stat( run_script_fname ) + os.chmod( run_script_fname, st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH ) def main(): # Cmake may combine the final arguments into a string literal # Manually unpack those before parsing - final_arg = sys.argv.pop(-1) - sys.argv.extend(final_arg.split()) - - parser = argparse.ArgumentParser(description="Setup ATS script") - parser.add_argument("src_path", type=str, help="GEOS src path") - parser.add_argument("build_path", type=str, help="GEOS build path") - parser.add_argument("--ats", nargs='+', default=[], action="append", help="Arguments that should be passed to ats") - parser.add_argument("--machine", type=str, default='', help="ATS machine name") - parser.add_argument("--machine-dir", type=str, default='', help="ATS machine directory") + final_arg = sys.argv.pop( -1 ) + sys.argv.extend( final_arg.split() ) + + parser = argparse.ArgumentParser( description="Setup ATS script" ) + parser.add_argument( "src_path", type=str, help="GEOS src path" ) + parser.add_argument( "build_path", type=str, help="GEOS build path" ) + parser.add_argument( "--ats", + nargs='+', + default=[], + action="append", + help="Arguments that should be passed to ats" ) + parser.add_argument( "--machine", type=str, default='', help="ATS machine name" ) + parser.add_argument( "--machine-dir", type=str, default='', help="ATS machine directory" ) options, unkown_args = parser.parse_known_args() - setup_ats(options.src_path, options.build_path, options.ats, options.machine, options.machine_dir) + setup_ats( options.src_path, options.build_path, options.ats, options.machine, options.machine_dir ) if __name__ == '__main__': diff --git a/geos_ats_package/geos_ats/geos_ats_debug.py b/geos_ats_package/geos_ats/geos_ats_debug.py index 1535a18..25f368c 100644 --- a/geos_ats_package/geos_ats/geos_ats_debug.py +++ b/geos_ats_package/geos_ats/geos_ats_debug.py @@ -3,29 +3,29 @@ import glob from pathlib import Path -mod_path = Path(__file__).resolve().parents[1] -sys.path.append(os.path.abspath(mod_path)) +mod_path = Path( __file__ ).resolve().parents[ 1 ] +sys.path.append( os.path.abspath( mod_path ) ) from geos_ats import main -def debug_geosats(build_root='~/GEOS/build-quartz-gcc@12-release', extra_args=[]): +def debug_geosats( build_root='~/GEOS/build-quartz-gcc@12-release', extra_args=[] ): # Search for and parse the ats script - build_root = os.path.expanduser(build_root) - ats_script = os.path.join(build_root, 'integratedTests', 'geos_ats.sh') - if not os.path.isfile(ats_script): + build_root = os.path.expanduser( build_root ) + ats_script = os.path.join( build_root, 'integratedTests', 'geos_ats.sh' ) + if not os.path.isfile( ats_script ): raise InputError( - 'Could not find geos_ats.sh at the expected location... Make sure to run \"make ats_environment\"') + 'Could not find geos_ats.sh at the expected location... Make sure to run \"make ats_environment\"' ) - with open(ats_script, 'r') as f: + with open( ats_script, 'r' ) as f: header = f.readline() ats_args = f.readline().split() - sys.argv.extend(ats_args[1:-1]) - sys.argv.extend(extra_args) + sys.argv.extend( ats_args[ 1:-1 ] ) + sys.argv.extend( extra_args ) main.main() -if (__name__ == '__main__'): +if ( __name__ == '__main__' ): # debug_geosats(extra_args=['-a', 'veryclean']) # debug_geosats(extra_args=['-a', 'rebaselinefailed']) debug_geosats() diff --git a/geos_ats_package/geos_ats/helpers/curve_check.py b/geos_ats_package/geos_ats/helpers/curve_check.py index e2c8bc2..bdf2213 100644 --- a/geos_ats_package/geos_ats/helpers/curve_check.py +++ b/geos_ats_package/geos_ats/helpers/curve_check.py @@ -19,7 +19,7 @@ DEFAULT_SET_NAME = 'empty_setName' -def interpolate_values_time(ta, xa, tb): +def interpolate_values_time( ta, xa, tb ): """ Interpolate array values in time @@ -31,25 +31,25 @@ def interpolate_values_time(ta, xa, tb): Returns: np.ndarray: Interpolated value array """ - N = list(np.shape(xa)) - M = len(tb) + N = list( np.shape( xa ) ) + M = len( tb ) - if (len(N) == 1): - return interp1d(ta, xa)(tb) + if ( len( N ) == 1 ): + return interp1d( ta, xa )( tb ) else: # Reshape the input array so that we can work on the non-time axes - S = np.product(N[1:]) - xc = np.reshape(xa, (N[0], S)) - xd = np.zeros((len(tb), S)) - for ii in range(S): - xd[:, ii] = interp1d(ta, xc[:, ii])(tb) + S = np.product( N[ 1: ] ) + xc = np.reshape( xa, ( N[ 0 ], S ) ) + xd = np.zeros( ( len( tb ), S ) ) + for ii in range( S ): + xd[ :, ii ] = interp1d( ta, xc[ :, ii ] )( tb ) # Return the array to it's expected shape - N[0] = M - return np.reshape(xd, N) + N[ 0 ] = M + return np.reshape( xd, N ) -def evaluate_external_script(script, fn, data): +def evaluate_external_script( script, fn, data ): """ Evaluate an external script to produce the curve @@ -60,23 +60,23 @@ def evaluate_external_script(script, fn, data): Returns: np.ndarray: Curve values """ - script = os.path.abspath(script) - if os.path.isfile(script): - module_name = os.path.split(script)[1] - module_name = module_name[:module_name.rfind('.')] - spec = importlib.util.spec_from_file_location(module_name, script) - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - spec.loader.exec_module(module) - if hasattr(module, fn): - return getattr(module, fn)(**data) + script = os.path.abspath( script ) + if os.path.isfile( script ): + module_name = os.path.split( script )[ 1 ] + module_name = module_name[ :module_name.rfind( '.' ) ] + spec = importlib.util.spec_from_file_location( module_name, script ) + module = importlib.util.module_from_spec( spec ) + sys.modules[ module_name ] = module + spec.loader.exec_module( module ) + if hasattr( module, fn ): + return getattr( module, fn )( **data ) else: - raise Exception(f'External script does not contain the expected function ({fn})') + raise Exception( f'External script does not contain the expected function ({fn})' ) else: - raise FileNotFoundError(f'Could not find script: {script}') + raise FileNotFoundError( f'Could not find script: {script}' ) -def check_diff(parameter_name, set_name, target, baseline, tolerance, errors, modifier='baseline'): +def check_diff( parameter_name, set_name, target, baseline, tolerance, errors, modifier='baseline' ): """ Compute the L2-norm of the diff and compare to the set tolerance @@ -92,14 +92,14 @@ def check_diff(parameter_name, set_name, target, baseline, tolerance, errors, mo np.ndarray: Interpolated value array """ dx = target - baseline - diff = np.sqrt(np.sum(dx * dx)) / dx.size - if (diff > tolerance): + diff = np.sqrt( np.sum( dx * dx ) ) / dx.size + if ( diff > tolerance ): errors.append( f'{modifier}_{parameter_name}_{set_name} diff exceeds tolerance: ||t-b||/N={diff}, {modifier}_tolerance={tolerance}' ) -def curve_check_figure(parameter_name, location_str, set_name, data, data_sizes, output_root, ncol, units_time): +def curve_check_figure( parameter_name, location_str, set_name, data, data_sizes, output_root, ncol, units_time ): """ Generate figures associated with the curve check @@ -135,61 +135,61 @@ def curve_check_figure(parameter_name, location_str, set_name, data, data_sizes, value_key = f'{parameter_name} {set_name}' location_key = f'{parameter_name} {location_str} {set_name}' - s = data_sizes[parameter_name][set_name] - N = list(s[list(s.keys())[0]]) - nrow = int(np.ceil(float(N[2]) / ncol)) - time_scale = unit_map[units_time] + s = data_sizes[ parameter_name ][ set_name ] + N = list( s[ list( s.keys() )[ 0 ] ] ) + nrow = int( np.ceil( float( N[ 2 ] ) / ncol ) ) + time_scale = unit_map[ units_time ] horizontal_label = f'Time ({units_time})' # Create the figure - fig = plt.figure(figsize=(8, 6)) - for ii in range(N[2]): - ax = plt.subplot(nrow, ncol, ii + 1) + fig = plt.figure( figsize=( 8, 6 ) ) + for ii in range( N[ 2 ] ): + ax = plt.subplot( nrow, ncol, ii + 1 ) for k in s.keys(): - t = np.squeeze(data[k][time_key]) / time_scale - x = data[k][value_key][:, :, ii] - position = data[k][location_key][0, :, 0] + t = np.squeeze( data[ k ][ time_key ] ) / time_scale + x = data[ k ][ value_key ][ :, :, ii ] + position = data[ k ][ location_key ][ 0, :, 0 ] - if (N[1] == 1): - ax.plot(t, x, label=k, **style[k]) + if ( N[ 1 ] == 1 ): + ax.plot( t, x, label=k, **style[ k ] ) else: - cmap = plt.get_cmap('jet') - if N[0] > N[1]: + cmap = plt.get_cmap( 'jet' ) + if N[ 0 ] > N[ 1 ]: # Timestep axis - for jj in range(N[1]): + for jj in range( N[ 1 ] ): try: - c = cmap(float(jj) / N[1]) + c = cmap( float( jj ) / N[ 1 ] ) kwargs = {} - if (jj == 0): - kwargs['label'] = k - ax.plot(t, x[:, jj], color=c, **style[k], **kwargs) + if ( jj == 0 ): + kwargs[ 'label' ] = k + ax.plot( t, x[ :, jj ], color=c, **style[ k ], **kwargs ) except Exception as e: - print(f'Error rendering curve {value_key}: {str(e)}') + print( f'Error rendering curve {value_key}: {str(e)}' ) else: # Spatial axis horizontal_label = 'X (m)' - for jj in range(N[0]): + for jj in range( N[ 0 ] ): try: - c = cmap(float(jj) / N[0]) + c = cmap( float( jj ) / N[ 0 ] ) kwargs = {} - if (jj == 0): - kwargs['label'] = k - ax.plot(position, x[jj, :], color=c, **style[k], **kwargs) + if ( jj == 0 ): + kwargs[ 'label' ] = k + ax.plot( position, x[ jj, : ], color=c, **style[ k ], **kwargs ) except Exception as e: - print(f'Error rendering curve {value_key}: {str(e)}') + print( f'Error rendering curve {value_key}: {str(e)}' ) # Set labels - ax.set_xlabel(horizontal_label) - ax.set_ylabel(value_key) + ax.set_xlabel( horizontal_label ) + ax.set_ylabel( value_key ) # ax.set_xlim(t[[0, -1]]) - ax.legend(loc=2) + ax.legend( loc=2 ) plt.tight_layout() - fig.savefig(os.path.join(output_root, f'{parameter_name}_{set_name}'), dpi=200) + fig.savefig( os.path.join( output_root, f'{parameter_name}_{set_name}' ), dpi=200 ) -def compare_time_history_curves(fname, baseline, curve, tolerance, output, output_n_column, units_time, - script_instructions): +def compare_time_history_curves( fname, baseline, curve, tolerance, output, output_n_column, units_time, + script_instructions ): """ Compute time history curves @@ -207,88 +207,88 @@ def compare_time_history_curves(fname, baseline, curve, tolerance, output, outpu tuple: warnings, errors """ # Setup - files = {'target': fname, 'baseline': baseline} + files = { 'target': fname, 'baseline': baseline } warnings = [] errors = [] - location_string_options = ['ReferencePosition', 'elementCenter'] + location_string_options = [ 'ReferencePosition', 'elementCenter' ] location_strings = {} tol = {} - if len(curve) != len(tolerance): + if len( curve ) != len( tolerance ): raise Exception( - f'Curvecheck inputs must be of the same length: curves ({len(curve)}) and tolerance ({len(tolerance)})') + f'Curvecheck inputs must be of the same length: curves ({len(curve)}) and tolerance ({len(tolerance)})' ) # Load data and check sizes data = {} data_sizes = {} for k, f in files.items(): - if os.path.isfile(f): - data[k] = hdf5_wrapper.hdf5_wrapper(f).get_copy() + if os.path.isfile( f ): + data[ k ] = hdf5_wrapper.hdf5_wrapper( f ).get_copy() else: - errors.append(f'{k} file not found: {f}') + errors.append( f'{k} file not found: {f}' ) continue - for (p, s), t in zip(curve, tolerance): + for ( p, s ), t in zip( curve, tolerance ): if s == DEFAULT_SET_NAME: key = f'{p}' else: key = f'{p} {s}' - if f'{p} Time' not in data[k].keys(): - errors.append(f'Value not found in {k} file: {p}') + if f'{p} Time' not in data[ k ].keys(): + errors.append( f'Value not found in {k} file: {p}' ) continue - if key not in data[k].keys(): - errors.append(f'Set not found in {k} file: {s}') + if key not in data[ k ].keys(): + errors.append( f'Set not found in {k} file: {s}' ) continue # Check for a location string (this may not be consistent across the same file) - for kb in data[k].keys(): + for kb in data[ k ].keys(): for kc in location_string_options: - if (kc in kb) and (p in kb): - location_strings[p] = kc + if ( kc in kb ) and ( p in kb ): + location_strings[ p ] = kc if p not in location_strings: - test_keys = ', '.join(location_string_options) - all_keys = ', '.join(data[k].keys()) + test_keys = ', '.join( location_string_options ) + all_keys = ', '.join( data[ k ].keys() ) errors.append( f'Could not find location string for parameter: {p}, search_options=({test_keys}), all_options={all_keys}' ) # Check data sizes in the initial loop to make later logic easier if p not in data_sizes: - data_sizes[p] = {} - tol[p] = {} + data_sizes[ p ] = {} + tol[ p ] = {} - if s not in data_sizes[p]: - data_sizes[p][s] = {} - tol[p][s] = float(t[0]) + if s not in data_sizes[ p ]: + data_sizes[ p ][ s ] = {} + tol[ p ][ s ] = float( t[ 0 ] ) - data_sizes[p][s][k] = list(np.shape(data[k][key])) + data_sizes[ p ][ s ][ k ] = list( np.shape( data[ k ][ key ] ) ) # Record requested tolerance if p not in tol: - tol[p] = {} - if s not in tol[p]: - tol[p][s] = t + tol[ p ] = {} + if s not in tol[ p ]: + tol[ p ][ s ] = t # Generate script-based curve - if script_instructions and (len(data) > 0): - data['script'] = {} + if script_instructions and ( len( data ) > 0 ): + data[ 'script' ] = {} try: for script, fn, p, s in script_instructions: - k = location_strings[p] - data['script'][f'{p} Time'] = data['target'][f'{p} Time'] + k = location_strings[ p ] + data[ 'script' ][ f'{p} Time' ] = data[ 'target' ][ f'{p} Time' ] key = f'{p} {k}' key2 = f'{p}' if s != DEFAULT_SET_NAME: key += f' {s}' key2 += f' {s}' - data['script'][key] = data['target'][key] - data['script'][key2] = evaluate_external_script(script, fn, data['target']) - data_sizes[p][s]['script'] = list(np.shape(data['script'][key2])) + data[ 'script' ][ key ] = data[ 'target' ][ key ] + data[ 'script' ][ key2 ] = evaluate_external_script( script, fn, data[ 'target' ] ) + data_sizes[ p ][ s ][ 'script' ] = list( np.shape( data[ 'script' ][ key2 ] ) ) except Exception as e: - errors.append(str(e)) + errors.append( str( e ) ) # Reshape data if necessary so that they have a predictable number of dimensions for k in data.keys(): @@ -296,12 +296,12 @@ def compare_time_history_curves(fname, baseline, curve, tolerance, output, outpu key = f'{p}' if s != DEFAULT_SET_NAME: key += f' {s}' - if (len(data_sizes[p][s][k]) == 1): - data[k][key] = np.reshape(data[k][key], (-1, 1, 1)) - data_sizes[p][s][k].append(1) - elif (len(data_sizes[p][s][k]) == 2): - data[k][key] = np.expand_dims(data[k][key], -1) - data_sizes[p][s][k].append(1) + if ( len( data_sizes[ p ][ s ][ k ] ) == 1 ): + data[ k ][ key ] = np.reshape( data[ k ][ key ], ( -1, 1, 1 ) ) + data_sizes[ p ][ s ][ k ].append( 1 ) + elif ( len( data_sizes[ p ][ s ][ k ] ) == 2 ): + data[ k ][ key ] = np.expand_dims( data[ k ][ key ], -1 ) + data_sizes[ p ][ s ][ k ].append( 1 ) # Check data diffs size_err = '{}_{} values have different sizes: target=({},{},{}) baseline=({},{},{})' @@ -311,33 +311,34 @@ def compare_time_history_curves(fname, baseline, curve, tolerance, output, outpu if s != DEFAULT_SET_NAME: key += f' {s}' - if (('baseline' in set_sizes) and ('target' in set_sizes)): - xa = data['target'][key] - xb = data['baseline'][key] - if set_sizes['target'] == set_sizes['baseline']: - check_diff(p, s, xa, xb, tol[p][s], errors) + if ( ( 'baseline' in set_sizes ) and ( 'target' in set_sizes ) ): + xa = data[ 'target' ][ key ] + xb = data[ 'baseline' ][ key ] + if set_sizes[ 'target' ] == set_sizes[ 'baseline' ]: + check_diff( p, s, xa, xb, tol[ p ][ s ], errors ) else: - warnings.append(size_err.format(p, s, *set_sizes['target'], *set_sizes['baseline'])) + warnings.append( size_err.format( p, s, *set_sizes[ 'target' ], *set_sizes[ 'baseline' ] ) ) # Check whether the data can be interpolated - if (len(set_sizes['baseline']) == 1) or (set_sizes['target'][1:] == set_sizes['baseline'][1:]): - warnings.append(f'Interpolating target curve in time: {p}_{s}') - ta = data['target'][f'{p} Time'] - tb = data['baseline'][f'{p} Time'] - xc = interpolate_values_time(ta, xa, tb) - check_diff(p, s, xc, xb, tol[p][s], errors) + if ( len( set_sizes[ 'baseline' ] ) == 1 ) or ( set_sizes[ 'target' ][ 1: ] + == set_sizes[ 'baseline' ][ 1: ] ): + warnings.append( f'Interpolating target curve in time: {p}_{s}' ) + ta = data[ 'target' ][ f'{p} Time' ] + tb = data[ 'baseline' ][ f'{p} Time' ] + xc = interpolate_values_time( ta, xa, tb ) + check_diff( p, s, xc, xb, tol[ p ][ s ], errors ) else: - errors.append(f'Cannot perform a curve check for {p}_{s}') - if (('script' in set_sizes) and ('target' in set_sizes)): - xa = data['target'][key] - xb = data['script'][key] - check_diff(p, s, xa, xb, tol[p][s], errors, modifier='script') + errors.append( f'Cannot perform a curve check for {p}_{s}' ) + if ( ( 'script' in set_sizes ) and ( 'target' in set_sizes ) ): + xa = data[ 'target' ][ key ] + xb = data[ 'script' ][ key ] + check_diff( p, s, xa, xb, tol[ p ][ s ], errors, modifier='script' ) # Render figures - output = os.path.expanduser(output) - os.makedirs(output, exist_ok=True) + output = os.path.expanduser( output ) + os.makedirs( output, exist_ok=True ) for p, set_data in data_sizes.items(): for s, set_sizes in set_data.items(): - curve_check_figure(p, location_strings[p], s, data, data_sizes, output, output_n_column, units_time) + curve_check_figure( p, location_strings[ p ], s, data, data_sizes, output, output_n_column, units_time ) return warnings, errors @@ -351,68 +352,68 @@ def curve_check_parser(): """ # Custom action class - class PairAction(argparse.Action): - - def __call__(self, parser, namespace, values, option_string=None): - pairs = getattr(namespace, self.dest) - if len(values) == 1: - pairs.append((values[0], DEFAULT_SET_NAME)) - elif len(values) == 2: - pairs.append((values[0], values[1])) + class PairAction( argparse.Action ): + + def __call__( self, parser, namespace, values, option_string=None ): + pairs = getattr( namespace, self.dest ) + if len( values ) == 1: + pairs.append( ( values[ 0 ], DEFAULT_SET_NAME ) ) + elif len( values ) == 2: + pairs.append( ( values[ 0 ], values[ 1 ] ) ) else: - raise Exception('Only a single value or a pair of values are expected') + raise Exception( 'Only a single value or a pair of values are expected' ) - setattr(namespace, self.dest, pairs) + setattr( namespace, self.dest, pairs ) # Custom action class - class ScriptAction(argparse.Action): + class ScriptAction( argparse.Action ): - def __call__(self, parser, namespace, values, option_string=None): + def __call__( self, parser, namespace, values, option_string=None ): - scripts = getattr(namespace, self.dest) - scripts.append(values) - N = len(values) - if (N < 3) or (N > 4): - raise Exception('The -s option requires 3 or 4 inputs') + scripts = getattr( namespace, self.dest ) + scripts.append( values ) + N = len( values ) + if ( N < 3 ) or ( N > 4 ): + raise Exception( 'The -s option requires 3 or 4 inputs' ) elif N == 3: - values.append(DEFAULT_SET_NAME) + values.append( DEFAULT_SET_NAME ) - setattr(namespace, self.dest, scripts) + setattr( namespace, self.dest, scripts ) parser = argparse.ArgumentParser() - parser.add_argument("filename", help="Path to the time history file") - parser.add_argument("baseline", help="Path to the baseline file") - parser.add_argument('-c', - '--curve', - nargs='+', - action=PairAction, - help='Curves to check (value) or (value, setname)', - default=[]) - parser.add_argument("-t", - "--tolerance", - nargs='+', - action='append', - help=f"The tolerance for each curve check diffs (||x-y||/N)", - default=[]) - parser.add_argument("-w", - "--Werror", - action="store_true", - help="Force all warnings to be errors, default is False.", - default=False) - parser.add_argument("-o", "--output", help="Output figures to this directory", default='./curve_check_figures') - unit_choices = list(unit_map.keys()) - parser.add_argument("-n", "--n-column", help="Number of columns for the output figure", default=1) - parser.add_argument("-u", - "--units-time", - help=f"Time units for plots (default=seconds)", - choices=unit_choices, - default='seconds') - parser.add_argument("-s", - "--script", - nargs='+', - action=ScriptAction, - help='Python script instructions (path, function, value, setname)', - default=[]) + parser.add_argument( "filename", help="Path to the time history file" ) + parser.add_argument( "baseline", help="Path to the baseline file" ) + parser.add_argument( '-c', + '--curve', + nargs='+', + action=PairAction, + help='Curves to check (value) or (value, setname)', + default=[] ) + parser.add_argument( "-t", + "--tolerance", + nargs='+', + action='append', + help=f"The tolerance for each curve check diffs (||x-y||/N)", + default=[] ) + parser.add_argument( "-w", + "--Werror", + action="store_true", + help="Force all warnings to be errors, default is False.", + default=False ) + parser.add_argument( "-o", "--output", help="Output figures to this directory", default='./curve_check_figures' ) + unit_choices = list( unit_map.keys() ) + parser.add_argument( "-n", "--n-column", help="Number of columns for the output figure", default=1 ) + parser.add_argument( "-u", + "--units-time", + help=f"Time units for plots (default=seconds)", + choices=unit_choices, + default='seconds' ) + parser.add_argument( "-s", + "--script", + nargs='+', + action=ScriptAction, + help='Python script instructions (path, function, value, setname)', + default=[] ) return parser @@ -423,22 +424,22 @@ def main(): """ parser = curve_check_parser() args = parser.parse_args() - warnings, errors = compare_time_history_curves(args.filename, args.baseline, args.curve, args.tolerance, - args.output, args.n_column, args.units_time, args.script) + warnings, errors = compare_time_history_curves( args.filename, args.baseline, args.curve, args.tolerance, + args.output, args.n_column, args.units_time, args.script ) # Write errors/warnings to the screen if args.Werror: - errors.extend(warnings[:]) + errors.extend( warnings[ : ] ) warnings = [] - if len(warnings): - print('Curve check warnings:') - print('\n'.join(warnings)) + if len( warnings ): + print( 'Curve check warnings:' ) + print( '\n'.join( warnings ) ) - if len(errors): - print('Curve check errors:') - print('\n'.join(errors)) - raise Exception(f'Curve check produced {len(errors)} errors!') + if len( errors ): + print( 'Curve check errors:' ) + print( '\n'.join( errors ) ) + raise Exception( f'Curve check produced {len(errors)} errors!' ) if __name__ == '__main__': diff --git a/geos_ats_package/geos_ats/helpers/permute_array.py b/geos_ats_package/geos_ats/helpers/permute_array.py index bbc9491..e1cc9c6 100644 --- a/geos_ats_package/geos_ats/helpers/permute_array.py +++ b/geos_ats_package/geos_ats/helpers/permute_array.py @@ -1,44 +1,44 @@ -import numpy as np # type: ignore[import] +import numpy as np # type: ignore[import] import logging -logger = logging.getLogger('geos_ats') +logger = logging.getLogger( 'geos_ats' ) -def permuteArray(data, shape, permutation): - if len(shape.shape) != 1: - msg = "The shape must be a 1D array, not %s" % len(shape.shape) +def permuteArray( data, shape, permutation ): + if len( shape.shape ) != 1: + msg = "The shape must be a 1D array, not %s" % len( shape.shape ) return None, msg - if len(permutation.shape) != 1: - msg = "The permutation must be a 1D array, not %s" % len(permutation.shape) + if len( permutation.shape ) != 1: + msg = "The permutation must be a 1D array, not %s" % len( permutation.shape ) return None, msg if shape.size != permutation.size: - msg = "The shape and permutation arrays must have the same length. %s != %s" % (shape.size, permutation.size) + msg = "The shape and permutation arrays must have the same length. %s != %s" % ( shape.size, permutation.size ) return None, msg - if np.prod(shape) != data.size: - msg = "The shape is %s which yields a total size of %s but the real size is %s." % (shape, np.prod(shape), - data.size) + if np.prod( shape ) != data.size: + msg = "The shape is %s which yields a total size of %s but the real size is %s." % ( shape, np.prod( shape ), + data.size ) return None, msg - if np.any(np.sort(permutation) != np.arange(shape.size)): + if np.any( np.sort( permutation ) != np.arange( shape.size ) ): msg = "The permutation is not valid: %s" % permutation return None, msg - shape_in_memory = np.empty_like(shape) - for i in range(shape.size): - shape_in_memory[i] = shape[permutation[i]] + shape_in_memory = np.empty_like( shape ) + for i in range( shape.size ): + shape_in_memory[ i ] = shape[ permutation[ i ] ] - data = data.reshape(shape_in_memory) + data = data.reshape( shape_in_memory ) - reverse_permutation = np.empty_like(permutation) - for i in range(permutation.size): - reverse_permutation[permutation[i]] = i + reverse_permutation = np.empty_like( permutation ) + for i in range( permutation.size ): + reverse_permutation[ permutation[ i ] ] = i - data = np.transpose(data, reverse_permutation) - if np.any(data.shape != shape): - msg = "Reshaping failed. Shape is %s but should be %s" % (data.shape, shape) + data = np.transpose( data, reverse_permutation ) + if np.any( data.shape != shape ): + msg = "Reshaping failed. Shape is %s but should be %s" % ( data.shape, shape ) return None, msg return data, None @@ -46,46 +46,46 @@ def permuteArray(data, shape, permutation): if __name__ == "__main__": - def testPermuteArray(shape, permutation): - original_data = np.arange(np.prod(shape)).reshape(shape) - transposed_data = original_data.transpose(permutation) - - reshaped_data, error_msg = permuteArray(transposed_data.flatten(), shape, permutation) - assert (error_msg is None) - assert (np.all(original_data == reshaped_data)) - - testPermuteArray(np.array([2, 3]), np.array([0, 1])) - testPermuteArray(np.array([2, 3]), np.array([1, 0])) - - testPermuteArray(np.array([2, 3, 4]), np.array([0, 1, 2])) - testPermuteArray(np.array([2, 3, 4]), np.array([1, 0, 2])) - testPermuteArray(np.array([2, 3, 4]), np.array([0, 2, 1])) - testPermuteArray(np.array([2, 3, 4]), np.array([2, 0, 1])) - testPermuteArray(np.array([2, 3, 4]), np.array([1, 2, 0])) - testPermuteArray(np.array([2, 3, 4]), np.array([2, 1, 0])) - - testPermuteArray(np.array([2, 3, 4, 5]), np.array([0, 1, 2, 3])) - testPermuteArray(np.array([2, 3, 4, 5]), np.array([1, 0, 2, 3])) - testPermuteArray(np.array([2, 3, 4, 5]), np.array([0, 2, 1, 3])) - testPermuteArray(np.array([2, 3, 4, 5]), np.array([2, 0, 1, 3])) - testPermuteArray(np.array([2, 3, 4, 5]), np.array([1, 2, 0, 3])) - testPermuteArray(np.array([2, 3, 4, 5]), np.array([2, 1, 0, 3])) - testPermuteArray(np.array([2, 3, 4, 5]), np.array([0, 1, 3, 2])) - testPermuteArray(np.array([2, 3, 4, 5]), np.array([1, 0, 3, 2])) - testPermuteArray(np.array([2, 3, 4, 5]), np.array([0, 2, 3, 1])) - testPermuteArray(np.array([2, 3, 4, 5]), np.array([2, 0, 3, 1])) - testPermuteArray(np.array([2, 3, 4, 5]), np.array([1, 2, 3, 0])) - testPermuteArray(np.array([2, 3, 4, 5]), np.array([2, 1, 3, 0])) - testPermuteArray(np.array([2, 3, 4, 5]), np.array([0, 3, 1, 2])) - testPermuteArray(np.array([2, 3, 4, 5]), np.array([1, 3, 0, 2])) - testPermuteArray(np.array([2, 3, 4, 5]), np.array([0, 3, 2, 1])) - testPermuteArray(np.array([2, 3, 4, 5]), np.array([2, 3, 0, 1])) - testPermuteArray(np.array([2, 3, 4, 5]), np.array([1, 3, 2, 0])) - testPermuteArray(np.array([2, 3, 4, 5]), np.array([2, 3, 1, 0])) - testPermuteArray(np.array([2, 3, 4, 5]), np.array([3, 0, 1, 2])) - testPermuteArray(np.array([2, 3, 4, 5]), np.array([3, 1, 0, 2])) - testPermuteArray(np.array([2, 3, 4, 5]), np.array([3, 0, 2, 1])) - testPermuteArray(np.array([2, 3, 4, 5]), np.array([3, 2, 0, 1])) - testPermuteArray(np.array([2, 3, 4, 5]), np.array([3, 1, 2, 0])) - testPermuteArray(np.array([2, 3, 4, 5]), np.array([3, 2, 1, 0])) - logger.info("Success") + def testPermuteArray( shape, permutation ): + original_data = np.arange( np.prod( shape ) ).reshape( shape ) + transposed_data = original_data.transpose( permutation ) + + reshaped_data, error_msg = permuteArray( transposed_data.flatten(), shape, permutation ) + assert ( error_msg is None ) + assert ( np.all( original_data == reshaped_data ) ) + + testPermuteArray( np.array( [ 2, 3 ] ), np.array( [ 0, 1 ] ) ) + testPermuteArray( np.array( [ 2, 3 ] ), np.array( [ 1, 0 ] ) ) + + testPermuteArray( np.array( [ 2, 3, 4 ] ), np.array( [ 0, 1, 2 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4 ] ), np.array( [ 1, 0, 2 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4 ] ), np.array( [ 0, 2, 1 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4 ] ), np.array( [ 2, 0, 1 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4 ] ), np.array( [ 1, 2, 0 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4 ] ), np.array( [ 2, 1, 0 ] ) ) + + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 0, 1, 2, 3 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 1, 0, 2, 3 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 0, 2, 1, 3 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 2, 0, 1, 3 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 1, 2, 0, 3 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 2, 1, 0, 3 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 0, 1, 3, 2 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 1, 0, 3, 2 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 0, 2, 3, 1 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 2, 0, 3, 1 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 1, 2, 3, 0 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 2, 1, 3, 0 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 0, 3, 1, 2 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 1, 3, 0, 2 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 0, 3, 2, 1 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 2, 3, 0, 1 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 1, 3, 2, 0 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 2, 3, 1, 0 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 3, 0, 1, 2 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 3, 1, 0, 2 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 3, 0, 2, 1 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 3, 2, 0, 1 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 3, 1, 2, 0 ] ) ) + testPermuteArray( np.array( [ 2, 3, 4, 5 ] ), np.array( [ 3, 2, 1, 0 ] ) ) + logger.info( "Success" ) diff --git a/geos_ats_package/geos_ats/helpers/restart_check.py b/geos_ats_package/geos_ats/helpers/restart_check.py index 261c3bf..6172045 100644 --- a/geos_ats_package/geos_ats/helpers/restart_check.py +++ b/geos_ats_package/geos_ats/helpers/restart_check.py @@ -1,6 +1,6 @@ -import h5py # type: ignore[import] -from mpi4py import MPI # type: ignore[import] -import numpy as np # type: ignore[import] +import h5py # type: ignore[import] +from mpi4py import MPI # type: ignore[import] +import numpy as np # type: ignore[import] import sys import os import re @@ -8,30 +8,30 @@ import logging from pathlib import Path try: - from geos_ats.helpers.permute_array import permuteArray # type: ignore[import] + from geos_ats.helpers.permute_array import permuteArray # type: ignore[import] except ImportError: # Fallback method to be used if geos_ats isn't found - from permute_array import permuteArray # type: ignore[import] + from permute_array import permuteArray # type: ignore[import] RTOL_DEFAULT = 0.0 ATOL_DEFAULT = 0.0 -EXCLUDE_DEFAULT = [".*/commandLine", ".*/schema$", ".*/globalToLocalMap", ".*/timeHistoryOutput.*/restart"] -logger = logging.getLogger('geos_ats') +EXCLUDE_DEFAULT = [ ".*/commandLine", ".*/schema$", ".*/globalToLocalMap", ".*/timeHistoryOutput.*/restart" ] +logger = logging.getLogger( 'geos_ats' ) -def write(output, msg): +def write( output, msg ): """ Write MSG to both stdout and OUTPUT. OUTPUT [in/out]: File stream to write to. MSG [in]: Message to write. """ - msg = str(msg) - sys.stdout.write(msg) + msg = str( msg ) + sys.stdout.write( msg ) sys.stdout.flush() - output.write(msg) + output.write( msg ) -def h5PathJoin(p1, p2): +def h5PathJoin( p1, p2 ): if p1 == "/": return "/" + p2 if p1 == "": @@ -40,21 +40,21 @@ def h5PathJoin(p1, p2): return p1 + "/" + p2 -class FileComparison(object): +class FileComparison( object ): """ Class that compares two hdf5 files. """ - def __init__(self, - file_path, - baseline_path, - rtol, - atol, - regex_expressions, - output, - warnings_are_errors, - skip_missing, - diff_file=None): + def __init__( self, + file_path, + baseline_path, + rtol, + atol, + regex_expressions, + output, + warnings_are_errors, + skip_missing, + diff_file=None ): """ FILE_PATH [in]: The path of the first file to compare. BASELINE_PATH [in]: The path of the baseline file to compare against. @@ -75,38 +75,38 @@ def __init__(self, self.diff_file = diff_file self.different = False - assert (self.rtol >= 0.0) - assert (self.atol >= 0.0) + assert ( self.rtol >= 0.0 ) + assert ( self.atol >= 0.0 ) - def filesDiffer(self): + def filesDiffer( self ): try: - with h5py.File(self.file_path, "r") as file, h5py.File(self.baseline_path, "r") as base_file: + with h5py.File( self.file_path, "r" ) as file, h5py.File( self.baseline_path, "r" ) as base_file: self.file_path = file.filename self.baseline_path = base_file.filename - self.output.write("\nRank %s is comparing %s with %s \n" % - (MPI.COMM_WORLD.Get_rank(), self.file_path, self.baseline_path)) - self.compareGroups(file, base_file) + self.output.write( "\nRank %s is comparing %s with %s \n" % + ( MPI.COMM_WORLD.Get_rank(), self.file_path, self.baseline_path ) ) + self.compareGroups( file, base_file ) except IOError as e: - self.logger.debug(e) - self.output.write(str(e)) + self.logger.debug( e ) + self.output.write( str( e ) ) self.different = True return self.different - def add_links(self, path, message): + def add_links( self, path, message ): # When comparing the root groups self.diff_file is none. if self.diff_file is None: return - base_name = os.path.basename(self.file_path) + base_name = os.path.basename( self.file_path ) diff_group_name = base_name + "/" + path - diff_group = self.diff_file.create_group(diff_group_name) - diff_group.create_dataset("message", data=message) - diff_group["run"] = h5py.ExternalLink(self.file_path, path) - diff_group["baseline"] = h5py.ExternalLink(self.baseline_path, path) + diff_group = self.diff_file.create_group( diff_group_name ) + diff_group.create_dataset( "message", data=message ) + diff_group[ "run" ] = h5py.ExternalLink( self.file_path, path ) + diff_group[ "baseline" ] = h5py.ExternalLink( self.baseline_path, path ) - def errorMsg(self, path, message, add_to_diff=False): + def errorMsg( self, path, message, add_to_diff=False ): """ Issue an error which occurred at PATH in the files with the contents of MESSAGE. Sets self.different to True and rites the error to both stdout and OUTPUT. @@ -117,14 +117,14 @@ def errorMsg(self, path, message, add_to_diff=False): self.different = True msg = '*' * 80 + "\n" msg += "Error: %s\n" % path - msg += "\t" + "\n\t".join(message.split("\n"))[:-1] + msg += "\t" + "\n\t".join( message.split( "\n" ) )[ :-1 ] msg += '*' * 80 + "\n" - self.output.write(msg) + self.output.write( msg ) if add_to_diff: - self.add_links(path, message) + self.add_links( path, message ) - def warningMsg(self, path, message): + def warningMsg( self, path, message ): """ Issue a warning which occurred at PATH in the files with the contents of MESSAGE. Writes the warning to both stdout and OUTPUT. If WARNINGS_ARE_ERRORS then this @@ -134,26 +134,26 @@ def warningMsg(self, path, message): MESSAGE [in]: The warning message. """ if self.warnings_are_errors: - return self.errorMsg(path, message) + return self.errorMsg( path, message ) msg = '*' * 80 + "\n" msg += "Warning: %s\n" % path - msg += "\t" + "\n\t".join(message.split("\n"))[:-1] + msg += "\t" + "\n\t".join( message.split( "\n" ) )[ :-1 ] msg += '*' * 80 + "\n" - self.output.write(msg) + self.output.write( msg ) - def isExcluded(self, path): + def isExcluded( self, path ): """ Return True iff path matches any of the regex expressions in self.regex_expressions. PATH [in]: The path to match. """ for regex in self.regex_expressions: - if regex.match(path) is not None: + if regex.match( path ) is not None: return True return False - def compareFloatScalars(self, path, val, base_val): + def compareFloatScalars( self, path, val, base_val ): """ Compare floating point scalars. @@ -161,12 +161,12 @@ def compareFloatScalars(self, path, val, base_val): VAL [in]: The value to compare. BASE_VAL [in]: The baseline value to compare against. """ - dif = abs(val - base_val) - if dif > self.atol and dif > self.rtol * abs(base_val): - msg = "Scalar values of types %s and %s differ: %s, %s.\n" % (val.dtype, base_val.dtype, val, base_val) - self.errorMsg(path, msg, True) + dif = abs( val - base_val ) + if dif > self.atol and dif > self.rtol * abs( base_val ): + msg = "Scalar values of types %s and %s differ: %s, %s.\n" % ( val.dtype, base_val.dtype, val, base_val ) + self.errorMsg( path, msg, True ) - def compareIntScalars(self, path, val, base_val): + def compareIntScalars( self, path, val, base_val ): """ Compare integer scalars. @@ -175,10 +175,10 @@ def compareIntScalars(self, path, val, base_val): BASE_VAL [in]: The baseline value to compare against. """ if val != base_val: - msg = "Scalar values of types %s and %s differ: %s, %s.\n" % (val.dtype, base_val.dtype, val, base_val) - self.errorMsg(path, msg, True) + msg = "Scalar values of types %s and %s differ: %s, %s.\n" % ( val.dtype, base_val.dtype, val, base_val ) + self.errorMsg( path, msg, True ) - def compareStringScalars(self, path, val, base_val): + def compareStringScalars( self, path, val, base_val ): """ Compare string scalars. @@ -187,10 +187,10 @@ def compareStringScalars(self, path, val, base_val): BASE_VAL [in]: The baseline value to compare against. """ if val != base_val: - msg = "Scalar values of types %s and %s differ: %s, %s.\n" % (val.dtype, base_val.dtype, val, base_val) - self.errorMsg(path, msg, True) + msg = "Scalar values of types %s and %s differ: %s, %s.\n" % ( val.dtype, base_val.dtype, val, base_val ) + self.errorMsg( path, msg, True ) - def compareFloatArrays(self, path, arr, base_arr): + def compareFloatArrays( self, path, arr, base_arr ): """ Compares two arrays ARR and BASEARR of floating point values. Entries x1 and x2 are considered equal iff: @@ -221,26 +221,26 @@ def compareFloatArrays(self, path, arr, base_arr): """ # If we have zero tolerance then just call the compareIntArrays function. if self.rtol == 0.0 and self.atol == 0.0: - return self.compareIntArrays(path, arr, base_arr) + return self.compareIntArrays( path, arr, base_arr ) # If the shapes are different they can't be compared. if arr.shape != base_arr.shape: - msg = "Datasets have different shapes and therefore can't be compared: %s, %s.\n" % (arr.shape, - base_arr.shape) - self.errorMsg(path, msg, True) + msg = "Datasets have different shapes and therefore can't be compared: %s, %s.\n" % ( arr.shape, + base_arr.shape ) + self.errorMsg( path, msg, True ) return # First create a copy of the data in the datasets. - arr_cpy = np.copy(arr) - base_arr_cpy = np.copy(base_arr) + arr_cpy = np.copy( arr ) + base_arr_cpy = np.copy( base_arr ) # Now compute the difference and store the result in ARR1_CPY # which is appropriately renamed DIFFERENCE. - difference = np.subtract(arr, base_arr, out=arr_cpy) - np.abs(difference, out=difference) + difference = np.subtract( arr, base_arr, out=arr_cpy ) + np.abs( difference, out=difference ) # Take the absolute value of BASE_ARR_CPY and rename it to ABS_BASE_ARR - abs_base_arr = np.abs(base_arr_cpy, out=base_arr_cpy) + abs_base_arr = np.abs( base_arr_cpy, out=base_arr_cpy ) # max_abs_base_arr = np.max( abs_base_arr ) # comm = MPI.COMM_WORLD @@ -252,15 +252,15 @@ def compareFloatArrays(self, path, arr, base_arr): absTol = self.atol # Get the indices of the max absolute and relative error - max_absolute_index = np.unravel_index(np.argmax(difference), difference.shape) + max_absolute_index = np.unravel_index( np.argmax( difference ), difference.shape ) - relative_difference = difference / (abs_base_arr + 1e-20) + relative_difference = difference / ( abs_base_arr + 1e-20 ) # If the absolute tolerance is not zero, replace all nan's with zero. if self.atol != 0: - relative_difference = np.nan_to_num(relative_difference, 0) + relative_difference = np.nan_to_num( relative_difference, 0 ) - max_relative_index = np.unravel_index(np.argmax(relative_difference), relative_difference.shape) + max_relative_index = np.unravel_index( np.argmax( relative_difference ), relative_difference.shape ) if self.rtol != 0.0: relative_difference /= self.rtol @@ -268,69 +268,69 @@ def compareFloatArrays(self, path, arr, base_arr): if self.rtol == 0.0: difference /= absTol q = difference - absolute_limited = np.ones(q.shape, dtype=bool) + absolute_limited = np.ones( q.shape, dtype=bool ) elif self.atol == 0.0: q = relative_difference - absolute_limited = np.zeros(q.shape, dtype=bool) + absolute_limited = np.zeros( q.shape, dtype=bool ) else: # Multiply ABS_BASE_ARR by RTOL and rename it to RTOL_ABS_BASE - rtol_abs_base = np.multiply(self.rtol, abs_base_arr, out=abs_base_arr) + rtol_abs_base = np.multiply( self.rtol, abs_base_arr, out=abs_base_arr ) # Calculate which entries are limited by the relative tolerance. relative_limited = rtol_abs_base > absTol # Rename DIFFERENCE to Q where we will store the scaling parameter q. q = difference - q[relative_limited] = relative_difference[relative_limited] + q[ relative_limited ] = relative_difference[ relative_limited ] # Compute q for the entries which are limited by the absolute tolerance. - absolute_limited = np.logical_not(relative_limited, out=relative_limited) - q[absolute_limited] /= absTol + absolute_limited = np.logical_not( relative_limited, out=relative_limited ) + q[ absolute_limited ] /= absTol # If the maximum q value is greater than 1.0 than issue an error. - if np.max(q) > 1.0: - offenders = np.greater(q, 1.0) - n_offenders = np.sum(offenders) + if np.max( q ) > 1.0: + offenders = np.greater( q, 1.0 ) + n_offenders = np.sum( offenders ) - absolute_offenders = np.logical_and(offenders, absolute_limited, out=offenders) - q_num_absolute = np.sum(absolute_offenders) + absolute_offenders = np.logical_and( offenders, absolute_limited, out=offenders ) + q_num_absolute = np.sum( absolute_offenders ) if q_num_absolute > 0: absolute_qs = q * absolute_offenders - q_max_absolute = np.max(absolute_qs) - q_max_absolute_index = np.unravel_index(np.argmax(absolute_qs), absolute_qs.shape) - q_mean_absolute = np.mean(absolute_qs) - q_std_absolute = np.std(absolute_qs) - - offenders = np.greater(q, 1.0, out=offenders) - relative_limited = np.logical_not(absolute_limited, out=absolute_limited) - relative_offenders = np.logical_and(offenders, relative_limited, out=offenders) - q_num_relative = np.sum(relative_offenders) + q_max_absolute = np.max( absolute_qs ) + q_max_absolute_index = np.unravel_index( np.argmax( absolute_qs ), absolute_qs.shape ) + q_mean_absolute = np.mean( absolute_qs ) + q_std_absolute = np.std( absolute_qs ) + + offenders = np.greater( q, 1.0, out=offenders ) + relative_limited = np.logical_not( absolute_limited, out=absolute_limited ) + relative_offenders = np.logical_and( offenders, relative_limited, out=offenders ) + q_num_relative = np.sum( relative_offenders ) if q_num_relative > 0: relative_qs = q * relative_offenders - q_max_relative = np.max(relative_qs) - q_max_relative_index = np.unravel_index(np.argmax(relative_qs), q.shape) - q_mean_relative = np.mean(relative_qs) - q_std_relative = np.std(relative_qs) + q_max_relative = np.max( relative_qs ) + q_max_relative_index = np.unravel_index( np.argmax( relative_qs ), q.shape ) + q_mean_relative = np.mean( relative_qs ) + q_std_relative = np.std( relative_qs ) message = "Arrays of types %s and %s have %d values of which %d fail both the relative and absolute tests.\n" % ( - arr.dtype, base_arr.dtype, offenders.size, n_offenders) + arr.dtype, base_arr.dtype, offenders.size, n_offenders ) message += "\tMax absolute difference is at index %s: value = %s, base_value = %s\n" % ( - max_absolute_index, arr[max_absolute_index], base_arr[max_absolute_index]) + max_absolute_index, arr[ max_absolute_index ], base_arr[ max_absolute_index ] ) message += "\tMax relative difference is at index %s: value = %s, base_value = %s\n" % ( - max_relative_index, arr[max_relative_index], base_arr[max_relative_index]) + max_relative_index, arr[ max_relative_index ], base_arr[ max_relative_index ] ) message += "Statistics of the q values greater than 1.0 defined by absolute tolerance: N = %d\n" % q_num_absolute if q_num_absolute > 0: - message += "\tmax = %s, mean = %s, std = %s\n" % (q_max_absolute, q_mean_absolute, q_std_absolute) + message += "\tmax = %s, mean = %s, std = %s\n" % ( q_max_absolute, q_mean_absolute, q_std_absolute ) message += "\tmax is at index %s, value = %s, base_value = %s\n" % ( - q_max_absolute_index, arr[q_max_absolute_index], base_arr[q_max_absolute_index]) + q_max_absolute_index, arr[ q_max_absolute_index ], base_arr[ q_max_absolute_index ] ) message += "Statistics of the q values greater than 1.0 defined by relative tolerance: N = %d\n" % q_num_relative if q_num_relative > 0: - message += "\tmax = %s, mean = %s, std = %s\n" % (q_max_relative, q_mean_relative, q_std_relative) + message += "\tmax = %s, mean = %s, std = %s\n" % ( q_max_relative, q_mean_relative, q_std_relative ) message += "\tmax is at index %s, value = %s, base_value = %s\n" % ( - q_max_relative_index, arr[q_max_relative_index], base_arr[q_max_relative_index]) - self.errorMsg(path, message, True) + q_max_relative_index, arr[ q_max_relative_index ], base_arr[ q_max_relative_index ] ) + self.errorMsg( path, message, True ) - def compareIntArrays(self, path, arr, base_arr): + def compareIntArrays( self, path, arr, base_arr ): """ Compare two integer datasets. Exact equality is used as the acceptance criteria. @@ -340,34 +340,34 @@ def compareIntArrays(self, path, arr, base_arr): """ # If the shapes are different they can't be compared. if arr.shape != base_arr.shape: - msg = "Datasets have different shapes and therefore can't be compared: %s, %s.\n" % (arr.shape, - base_arr.shape) - self.errorMsg(path, msg, True) + msg = "Datasets have different shapes and therefore can't be compared: %s, %s.\n" % ( arr.shape, + base_arr.shape ) + self.errorMsg( path, msg, True ) return # Create a copy of the arrays. # Calculate the absolute difference. - difference = np.subtract(arr, base_arr) - np.abs(difference, out=difference) + difference = np.subtract( arr, base_arr ) + np.abs( difference, out=difference ) offenders = difference != 0.0 - n_offenders = np.sum(offenders) + n_offenders = np.sum( offenders ) if n_offenders != 0: - max_index = np.unravel_index(np.argmax(difference), difference.shape) - max_difference = difference[max_index] - offenders_mean = np.mean(difference[offenders]) - offenders_std = np.std(difference[offenders]) + max_index = np.unravel_index( np.argmax( difference ), difference.shape ) + max_difference = difference[ max_index ] + offenders_mean = np.mean( difference[ offenders ] ) + offenders_std = np.std( difference[ offenders ] ) message = "Arrays of types %s and %s have %s values of which %d have differing values.\n" % ( - arr.dtype, base_arr.dtype, offenders.size, n_offenders) + arr.dtype, base_arr.dtype, offenders.size, n_offenders ) message += "Statistics of the differences greater than 0:\n" - message += "\tmax_index = %s, max = %s, mean = %s, std = %s\n" % (max_index, max_difference, offenders_mean, - offenders_std) - self.errorMsg(path, message, True) + message += "\tmax_index = %s, max = %s, mean = %s, std = %s\n" % ( max_index, max_difference, + offenders_mean, offenders_std ) + self.errorMsg( path, message, True ) - def compareStringArrays(self, path, arr, base_arr): + def compareStringArrays( self, path, arr, base_arr ): """ Compare two string datasets. Exact equality is used as the acceptance criteria. @@ -375,13 +375,13 @@ def compareStringArrays(self, path, arr, base_arr): ARR [in]: The hdf5 Dataset to compare. BASE_ARR [in]: The hdf5 Dataset to compare against. """ - if arr.shape != base_arr.shape or np.any(arr[:] != base_arr[:]): + if arr.shape != base_arr.shape or np.any( arr[ : ] != base_arr[ : ] ): message = "String arrays differ.\n" - message += "String to compare: %s\n" % "".join(arr[:]) - message += "Baseline string : %s\n" % "".join(base_arr[:]) - self.errorMsg(path, message, True) + message += "String to compare: %s\n" % "".join( arr[ : ] ) + message += "Baseline string : %s\n" % "".join( base_arr[ : ] ) + self.errorMsg( path, message, True ) - def compareData(self, path, arr, base_arr): + def compareData( self, path, arr, base_arr ): """ Compare the numerical portion of two datasets. @@ -390,19 +390,19 @@ def compareData(self, path, arr, base_arr): BASE_ARR [in]: The hdf5 Dataset to compare against. """ # Get the type of comparison to do. - np_floats = set(['f', 'c']) - np_ints = set(['?', 'b', 'B', 'i', 'u', 'm', 'M', 'V']) + np_floats = set( [ 'f', 'c' ] ) + np_ints = set( [ '?', 'b', 'B', 'i', 'u', 'm', 'M', 'V' ] ) np_numeric = np_floats | np_ints - np_strings = set(['S', 'a', 'U']) + np_strings = set( [ 'S', 'a', 'U' ] ) int_compare = arr.dtype.kind in np_ints and base_arr.dtype.kind in np_ints - float_compare = not int_compare and (arr.dtype.kind in np_numeric and base_arr.dtype.kind in np_numeric) + float_compare = not int_compare and ( arr.dtype.kind in np_numeric and base_arr.dtype.kind in np_numeric ) string_compare = arr.dtype.kind in np_strings and base_arr.dtype.kind in np_strings # If the datasets have different types issue a warning. if arr.dtype != base_arr.dtype: - msg = "Datasets have different types: %s, %s.\n" % (arr.dtype, base_arr.dtype) - self.warningMsg(path, msg) + msg = "Datasets have different types: %s, %s.\n" % ( arr.dtype, base_arr.dtype ) + self.warningMsg( path, msg ) # Handle empty datasets if arr.shape is None and base_arr.shape is None: @@ -412,40 +412,41 @@ def compareData(self, path, arr, base_arr): if arr.size == 0 and base_arr.size == 0: return elif arr.size is None and base_arr.size is not None: - self.errorMsg(path, "File to compare has an empty dataset where the baseline's dataset is not empty.\n") + self.errorMsg( path, "File to compare has an empty dataset where the baseline's dataset is not empty.\n" ) elif base_arr.size is None and arr.size is not None: - self.warningMsg(path, "Baseline has an empty dataset where the file to compare's dataset is not empty.\n") + self.warningMsg( path, "Baseline has an empty dataset where the file to compare's dataset is not empty.\n" ) # If either of the datasets is a scalar convert it to an array. if arr.shape == (): - arr = np.array([arr]) + arr = np.array( [ arr ] ) if base_arr.shape == (): - base_arr = np.array([base_arr]) + base_arr = np.array( [ base_arr ] ) # If the datasets only contain one value call the compare scalar functions. if arr.size == 1 and base_arr.size == 1: - val = arr[:].flat[0] - base_val = base_arr[:].flat[0] + val = arr[ : ].flat[ 0 ] + base_val = base_arr[ : ].flat[ 0 ] if float_compare: - return self.compareFloatScalars(path, val, base_val) + return self.compareFloatScalars( path, val, base_val ) elif int_compare: - return self.compareIntScalars(path, val, base_val) + return self.compareIntScalars( path, val, base_val ) elif string_compare: - return self.compareStringScalars(path, val, base_val) + return self.compareStringScalars( path, val, base_val ) else: - return self.warningMsg(path, "Unrecognized type combination: %s %s.\n" % (arr.dtype, base_arr.dtype)) + return self.warningMsg( path, + "Unrecognized type combination: %s %s.\n" % ( arr.dtype, base_arr.dtype ) ) # Do the actual comparison. if float_compare: - return self.compareFloatArrays(path, arr, base_arr) + return self.compareFloatArrays( path, arr, base_arr ) elif int_compare: - return self.compareIntArrays(path, arr, base_arr) + return self.compareIntArrays( path, arr, base_arr ) elif string_compare: - return self.compareStringArrays(path, arr, base_arr) + return self.compareStringArrays( path, arr, base_arr ) else: - return self.warningMsg(path, "Unrecognized type combination: %s %s.\n" % (arr.dtype, base_arr.dtype)) + return self.warningMsg( path, "Unrecognized type combination: %s %s.\n" % ( arr.dtype, base_arr.dtype ) ) - def compareAttributes(self, path, attrs, base_attrs): + def compareAttributes( self, path, attrs, base_attrs ): """ Compare two sets of attributes. @@ -453,215 +454,216 @@ def compareAttributes(self, path, attrs, base_attrs): ATTRS [in]: The hdf5 AttributeManager to compare. BASE_ATTRS [in]: The hdf5 AttributeManager to compare against. """ - for attrName in set(list(attrs.keys()) + list(base_attrs.keys())): + for attrName in set( list( attrs.keys() ) + list( base_attrs.keys() ) ): if attrName not in attrs: msg = "Attribute %s is in the baseline file but not the file to compare.\n" % attrName - self.errorMsg(path, msg) + self.errorMsg( path, msg ) continue if attrName not in base_attrs: msg = "Attribute %s is in the file to compare but not the baseline file.\n" % attrName - self.warningMsg(path, msg) + self.warningMsg( path, msg ) continue attrsPath = path + ".attrs[" + attrName + "]" - self.compareData(attrsPath, attrs[attrName], base_attrs[attrName]) + self.compareData( attrsPath, attrs[ attrName ], base_attrs[ attrName ] ) - def compareDatasets(self, dset, base_dset): + def compareDatasets( self, dset, base_dset ): """ Compare two datasets. DSET [in]: The Dataset to compare. BASE_DSET [in]: The Dataset to compare against. """ - assert isinstance(dset, h5py.Dataset) - assert isinstance(base_dset, h5py.Dataset) + assert isinstance( dset, h5py.Dataset ) + assert isinstance( base_dset, h5py.Dataset ) path = dset.name - self.compareAttributes(path, dset.attrs, base_dset.attrs) + self.compareAttributes( path, dset.attrs, base_dset.attrs ) - self.compareData(path, dset, base_dset) + self.compareData( path, dset, base_dset ) - def canCompare(self, group, base_group, name): + def canCompare( self, group, base_group, name ): name_in_group = name in group name_in_base_group = name in base_group if not name_in_group and not name_in_base_group: return False - elif self.isExcluded(h5PathJoin(group.name, name)): + elif self.isExcluded( h5PathJoin( group.name, name ) ): return False if not name_in_group: msg = "Group has a child '%s' in the baseline file but not the file to compare.\n" % name if not self.skip_missing: - self.errorMsg(base_group.name, msg) + self.errorMsg( base_group.name, msg ) return False if not name_in_base_group: msg = "Group has a child '%s' in the file to compare but not the baseline file.\n" % name if not self.skip_missing: - self.errorMsg(group.name, msg) + self.errorMsg( group.name, msg ) return False return True - def compareLvArrays(self, group, base_group, other_children_to_check): - if self.canCompare(group, base_group, "__dimensions__") and self.canCompare( - group, base_group, "__permutation__") and self.canCompare(group, base_group, "__values__"): - other_children_to_check.remove("__dimensions__") - other_children_to_check.remove("__permutation__") - other_children_to_check.remove("__values__") + def compareLvArrays( self, group, base_group, other_children_to_check ): + if self.canCompare( group, base_group, "__dimensions__" ) and self.canCompare( + group, base_group, "__permutation__" ) and self.canCompare( group, base_group, "__values__" ): + other_children_to_check.remove( "__dimensions__" ) + other_children_to_check.remove( "__permutation__" ) + other_children_to_check.remove( "__values__" ) - dimensions = group["__dimensions__"][:] - base_dimensions = base_group["__dimensions__"][:] + dimensions = group[ "__dimensions__" ][ : ] + base_dimensions = base_group[ "__dimensions__" ][ : ] - if len(dimensions.shape) != 1: - msg = "The dimensions of an LvArray must itself be a 1D array not %s\n" % len(dimensions.shape) - self.errorMsg(group.name, msg) + if len( dimensions.shape ) != 1: + msg = "The dimensions of an LvArray must itself be a 1D array not %s\n" % len( dimensions.shape ) + self.errorMsg( group.name, msg ) - if dimensions.shape != base_dimensions.shape or np.any(dimensions != base_dimensions): + if dimensions.shape != base_dimensions.shape or np.any( dimensions != base_dimensions ): msg = "Cannot compare LvArrays because they have different dimensions. Dimensions = %s, base dimensions = %s\n" % ( - dimensions, base_dimensions) - self.errorMsg(group.name, msg) + dimensions, base_dimensions ) + self.errorMsg( group.name, msg ) return True - permutation = group["__permutation__"][:] - base_permutation = base_group["__permutation__"][:] + permutation = group[ "__permutation__" ][ : ] + base_permutation = base_group[ "__permutation__" ][ : ] - if len(permutation.shape) != 1: - msg = "The permutation of an LvArray must itself be a 1D array not %s\n" % len(permutation.shape) - self.errorMsg(group.name, msg) + if len( permutation.shape ) != 1: + msg = "The permutation of an LvArray must itself be a 1D array not %s\n" % len( permutation.shape ) + self.errorMsg( group.name, msg ) - if permutation.shape != dimensions.shape or np.any(np.sort(permutation) != np.arange(permutation.size)): + if permutation.shape != dimensions.shape or np.any( + np.sort( permutation ) != np.arange( permutation.size ) ): msg = "LvArray in the file to compare has an invalid permutation. Dimensions = %s, Permutation = %s\n" % ( - dimensions, permutation) - self.errorMsg(group.name, msg) + dimensions, permutation ) + self.errorMsg( group.name, msg ) return True if base_permutation.shape != base_dimensions.shape or np.any( - np.sort(base_permutation) != np.arange(base_permutation.size)): + np.sort( base_permutation ) != np.arange( base_permutation.size ) ): msg = "LvArray in the baseline has an invalid permutation. Dimensions = %s, Permutation = %s\n" % ( - base_dimensions, base_permutation) - self.errorMsg(group.name, msg) + base_dimensions, base_permutation ) + self.errorMsg( group.name, msg ) return True - values = group["__values__"][:] - base_values = base_group["__values__"][:] + values = group[ "__values__" ][ : ] + base_values = base_group[ "__values__" ][ : ] - values, errorMsg = permuteArray(values, dimensions, permutation) + values, errorMsg = permuteArray( values, dimensions, permutation ) if values is None: msg = "Failed to permute the LvArray: %s\n" % errorMsg - self.errorMsg(group.name, msg) + self.errorMsg( group.name, msg ) return True - base_values, errorMsg = permuteArray(base_values, base_dimensions, base_permutation) + base_values, errorMsg = permuteArray( base_values, base_dimensions, base_permutation ) if base_values is None: msg = "Failed to permute the baseline LvArray: %s\n" % errorMsg - self.errorMsg(group.name, msg) + self.errorMsg( group.name, msg ) return True - self.compareData(group.name, values, base_values) + self.compareData( group.name, values, base_values ) return True return False - def compareGroups(self, group, base_group): + def compareGroups( self, group, base_group ): """ Compare hdf5 groups. GROUP [in]: The Group to compare. BASE_GROUP [in]: The Group to compare against. """ - assert (isinstance(group, (h5py.Group, h5py.File))) - assert (isinstance(base_group, (h5py.Group, h5py.File))) + assert ( isinstance( group, ( h5py.Group, h5py.File ) ) ) + assert ( isinstance( base_group, ( h5py.Group, h5py.File ) ) ) path = group.name # Compare the attributes in the two groups. - self.compareAttributes(path, group.attrs, base_group.attrs) + self.compareAttributes( path, group.attrs, base_group.attrs ) - children_to_check = set(list(group.keys()) + list(base_group.keys())) - self.compareLvArrays(group, base_group, children_to_check) + children_to_check = set( list( group.keys() ) + list( base_group.keys() ) ) + self.compareLvArrays( group, base_group, children_to_check ) # Compare the sub groups and datasets. for name in children_to_check: - if self.canCompare(group, base_group, name): - item1 = group[name] - item2 = base_group[name] - if not isinstance(item1, type(item2)): + if self.canCompare( group, base_group, name ): + item1 = group[ name ] + item2 = base_group[ name ] + if not isinstance( item1, type( item2 ) ): msg = "Child %s has differing types in the file to compare and the baseline: %s, %s.\n" % ( - name, type(item1), type(item2)) - self.errorMsg(path, msg) + name, type( item1 ), type( item2 ) ) + self.errorMsg( path, msg ) continue - if isinstance(item1, h5py.Group): - self.compareGroups(item1, item2) - elif isinstance(item1, h5py.Dataset): - self.compareDatasets(item1, item2) + if isinstance( item1, h5py.Group ): + self.compareGroups( item1, item2 ) + elif isinstance( item1, h5py.Dataset ): + self.compareDatasets( item1, item2 ) else: - self.warningMsg(path, "Child %s has unknown type: %s.\n" % (name, type(item1))) + self.warningMsg( path, "Child %s has unknown type: %s.\n" % ( name, type( item1 ) ) ) -def findFiles(file_pattern, baseline_pattern, comparison_args): +def findFiles( file_pattern, baseline_pattern, comparison_args ): # Find the matching files. - file_path = findMaxMatchingFile(file_pattern) + file_path = findMaxMatchingFile( file_pattern ) if file_path is None: - raise ValueError("No files found matching %s." % file_pattern) + raise ValueError( "No files found matching %s." % file_pattern ) - baseline_path = findMaxMatchingFile(baseline_pattern) + baseline_path = findMaxMatchingFile( baseline_pattern ) if baseline_path is None: - raise ValueError("No files found matching %s." % baseline_pattern) + raise ValueError( "No files found matching %s." % baseline_pattern ) # Get the output path. - output_base_path = os.path.splitext(file_path)[0] + output_base_path = os.path.splitext( file_path )[ 0 ] output_path = output_base_path + ".restartcheck" # Open the output file and diff file files_to_compare = None - with open(output_path, 'w') as output_file: - comparison_args["output"] = output_file - writeHeader(file_pattern, file_path, baseline_pattern, baseline_path, comparison_args) + with open( output_path, 'w' ) as output_file: + comparison_args[ "output" ] = output_file + writeHeader( file_pattern, file_path, baseline_pattern, baseline_path, comparison_args ) # Check if comparing root files. - if file_path.endswith(".root") and baseline_path.endswith(".root"): - p = [re.compile("/file_pattern"), re.compile("/protocol/version")] - comp = FileComparison(file_path, baseline_path, 0.0, 0.0, p, output_file, True, False) + if file_path.endswith( ".root" ) and baseline_path.endswith( ".root" ): + p = [ re.compile( "/file_pattern" ), re.compile( "/protocol/version" ) ] + comp = FileComparison( file_path, baseline_path, 0.0, 0.0, p, output_file, True, False ) if comp.filesDiffer(): - write(output_file, "The root files are different, cannot compare data files.\n") + write( output_file, "The root files are different, cannot compare data files.\n" ) return output_base_path, None else: - write(output_file, "The root files are similar.\n") + write( output_file, "The root files are similar.\n" ) # Get the number of files and the file patterns. # We know the number of files are the same from the above comparison. - with h5py.File(file_path, "r") as f: - numberOfFiles = f["number_of_files"][0] - file_data_pattern = "".join(f["file_pattern"][:].tobytes().decode('ascii')[:-1]) + with h5py.File( file_path, "r" ) as f: + numberOfFiles = f[ "number_of_files" ][ 0 ] + file_data_pattern = "".join( f[ "file_pattern" ][ : ].tobytes().decode( 'ascii' )[ :-1 ] ) - with h5py.File(baseline_path, "r") as f: - baseline_data_pattern = "".join(f["file_pattern"][:].tobytes().decode('ascii')[:-1]) + with h5py.File( baseline_path, "r" ) as f: + baseline_data_pattern = "".join( f[ "file_pattern" ][ : ].tobytes().decode( 'ascii' )[ :-1 ] ) # Get the paths to the data files. files_to_compare = [] - for i in range(numberOfFiles): - path_to_data = os.path.join(os.path.dirname(file_path), file_data_pattern % i) - path_to_baseline_data = os.path.join(os.path.dirname(baseline_path), baseline_data_pattern % i) - files_to_compare += [(path_to_data, path_to_baseline_data)] + for i in range( numberOfFiles ): + path_to_data = os.path.join( os.path.dirname( file_path ), file_data_pattern % i ) + path_to_baseline_data = os.path.join( os.path.dirname( baseline_path ), baseline_data_pattern % i ) + files_to_compare += [ ( path_to_data, path_to_baseline_data ) ] else: - files_to_compare = [(file_path, baseline_path)] + files_to_compare = [ ( file_path, baseline_path ) ] return output_base_path, files_to_compare -def gatherOutput(output_file, output_base_path, n_files): - for i in range(n_files): - output_path = "%s.%d.restartcheck" % (output_base_path, i) - with open(output_path, "r") as file: +def gatherOutput( output_file, output_base_path, n_files ): + for i in range( n_files ): + output_path = "%s.%d.restartcheck" % ( output_base_path, i ) + with open( output_path, "r" ) as file: for line in file: - write(output_file, line) + write( output_file, line ) -def findMaxMatchingFile(file_path): +def findMaxMatchingFile( file_path ): """ Given a path FILE_PATH where the base name of FILE_PATH is treated as a regular expression find and return the path of the greatest matching file/folder or None if no match is found. @@ -673,26 +675,26 @@ def findMaxMatchingFile(file_path): 'test/plot_*.hdf5' will return the file with the greatest name in the ./test directory that begins with 'plot' and ends with '.hdf5'. """ - file_directory, pattern = os.path.split(file_path) + file_directory, pattern = os.path.split( file_path ) if file_directory == "": file_directory = "." - if not os.path.isdir(file_directory): + if not os.path.isdir( file_directory ): return None - pattern = re.compile(pattern) + pattern = re.compile( pattern ) max_match = "" - for file in os.listdir(file_directory): - if pattern.match(file) is not None: - max_match = max(file, max_match) + for file in os.listdir( file_directory ): + if pattern.match( file ) is not None: + max_match = max( file, max_match ) if max_match == "": return None - return os.path.join(file_directory, max_match) + return os.path.join( file_directory, max_match ) -def writeHeader(file_pattern, file_path, baseline_pattern, baseline_path, args): +def writeHeader( file_pattern, file_path, baseline_pattern, baseline_path, args ): """ Write the header. @@ -702,15 +704,15 @@ def writeHeader(file_pattern, file_path, baseline_pattern, baseline_path, args): BASELINE_PATH [in]: THE path to the file to compare against. ARGS [in]: A dictionary of arguments to FileComparison. """ - output = args["output"] - msg = "Comparison of file %s from pattern %s\n" % (file_path, file_pattern) - msg += "Baseline file %s from pattern %s\n" % (baseline_path, baseline_pattern) - msg += "Relative tolerance: %s\n" % args["rtol"] - msg += "Absolute tolerance: %s\n" % args["atol"] + output = args[ "output" ] + msg = "Comparison of file %s from pattern %s\n" % ( file_path, file_pattern ) + msg += "Baseline file %s from pattern %s\n" % ( baseline_path, baseline_pattern ) + msg += "Relative tolerance: %s\n" % args[ "rtol" ] + msg += "Absolute tolerance: %s\n" % args[ "atol" ] msg += "Output file: %s\n" % output.name - msg += "Excluded groups: %s\n" % list(map(lambda e: e.pattern, args["regex_expressions"])) - msg += "Warnings are errors: %s\n\n" % args["warnings_are_errors"] - write(output, msg) + msg += "Excluded groups: %s\n" % list( map( lambda e: e.pattern, args[ "regex_expressions" ] ) ) + msg += "Warnings are errors: %s\n\n" % args[ "warnings_are_errors" ] + write( output, msg ) def main(): @@ -727,103 +729,103 @@ def main(): n_ranks = comm.Get_size() parser = argparse.ArgumentParser() - parser.add_argument("file_pattern", help="The pattern used to find the file to compare.") - parser.add_argument("baseline_pattern", help="The pattern used to find the baseline file.") - parser.add_argument("-r", - "--relative", - type=float, - help="The relative tolerance for floating point differences, default is %s." % RTOL_DEFAULT, - default=RTOL_DEFAULT) - parser.add_argument("-a", - "--absolute", - type=float, - help="The absolute tolerance for floating point differences, default is %s." % ATOL_DEFAULT, - default=ATOL_DEFAULT) - parser.add_argument("-e", - "--exclude", - action='append', - help="Regular expressions specifying which groups to skip, default is %s." % EXCLUDE_DEFAULT, - default=EXCLUDE_DEFAULT) - parser.add_argument("-m", - "--skip-missing", - action="store_true", - help="Ignore values that are missing from either the baseline or target file.", - default=False) - parser.add_argument("-w", - "--Werror", - action="store_true", - help="Force all warnings to be errors, default is False.", - default=False) + parser.add_argument( "file_pattern", help="The pattern used to find the file to compare." ) + parser.add_argument( "baseline_pattern", help="The pattern used to find the baseline file." ) + parser.add_argument( "-r", + "--relative", + type=float, + help="The relative tolerance for floating point differences, default is %s." % RTOL_DEFAULT, + default=RTOL_DEFAULT ) + parser.add_argument( "-a", + "--absolute", + type=float, + help="The absolute tolerance for floating point differences, default is %s." % ATOL_DEFAULT, + default=ATOL_DEFAULT ) + parser.add_argument( "-e", + "--exclude", + action='append', + help="Regular expressions specifying which groups to skip, default is %s." % EXCLUDE_DEFAULT, + default=EXCLUDE_DEFAULT ) + parser.add_argument( "-m", + "--skip-missing", + action="store_true", + help="Ignore values that are missing from either the baseline or target file.", + default=False ) + parser.add_argument( "-w", + "--Werror", + action="store_true", + help="Force all warnings to be errors, default is False.", + default=False ) args = parser.parse_args() # Check the command line arguments if args.relative < 0.0: - raise ValueError("Relative tolerance cannot be less than 0.0.") + raise ValueError( "Relative tolerance cannot be less than 0.0." ) if args.absolute < 0.0: - raise ValueError("Absolute tolerance cannot be less than 0.0.") + raise ValueError( "Absolute tolerance cannot be less than 0.0." ) # Extract the command line arguments. file_pattern = args.file_pattern baseline_pattern = args.baseline_pattern comparison_args = {} - comparison_args["rtol"] = args.relative - comparison_args["atol"] = args.absolute - comparison_args["regex_expressions"] = list(map(re.compile, args.exclude)) - comparison_args["warnings_are_errors"] = args.Werror - comparison_args["skip_missing"] = args.skip_missing + comparison_args[ "rtol" ] = args.relative + comparison_args[ "atol" ] = args.absolute + comparison_args[ "regex_expressions" ] = list( map( re.compile, args.exclude ) ) + comparison_args[ "warnings_are_errors" ] = args.Werror + comparison_args[ "skip_missing" ] = args.skip_missing if rank == 0: - output_base_path, files_to_compare = findFiles(file_pattern, baseline_pattern, comparison_args) + output_base_path, files_to_compare = findFiles( file_pattern, baseline_pattern, comparison_args ) else: output_base_path, files_to_compare = None, None - files_to_compare = comm.bcast(files_to_compare, root=0) - output_base_path = comm.bcast(output_base_path, root=0) + files_to_compare = comm.bcast( files_to_compare, root=0 ) + output_base_path = comm.bcast( output_base_path, root=0 ) if files_to_compare is None: return 1 differing_files = [] - for i in range(rank, len(files_to_compare), n_ranks): - output_path = "%s.%d.restartcheck" % (output_base_path, i) - diff_path = "%s.%d.diff.hdf5" % (output_base_path, i) - with open(output_path, 'w') as output_file, h5py.File(diff_path, "w") as diff_file: - comparison_args["output"] = output_file - comparison_args["diff_file"] = diff_file - file_path, baseline_path = files_to_compare[i] - - logger.info(f"About to compare {file_path} and {baseline_path}") - if FileComparison(file_path, baseline_path, **comparison_args).filesDiffer(): - differing_files += [files_to_compare[i]] - output_file.write("The files are different.\n") + for i in range( rank, len( files_to_compare ), n_ranks ): + output_path = "%s.%d.restartcheck" % ( output_base_path, i ) + diff_path = "%s.%d.diff.hdf5" % ( output_base_path, i ) + with open( output_path, 'w' ) as output_file, h5py.File( diff_path, "w" ) as diff_file: + comparison_args[ "output" ] = output_file + comparison_args[ "diff_file" ] = diff_file + file_path, baseline_path = files_to_compare[ i ] + + logger.info( f"About to compare {file_path} and {baseline_path}" ) + if FileComparison( file_path, baseline_path, **comparison_args ).filesDiffer(): + differing_files += [ files_to_compare[ i ] ] + output_file.write( "The files are different.\n" ) else: - output_file.write("The files are similar.\n") + output_file.write( "The files are similar.\n" ) - differing_files = comm.allgather(differing_files) + differing_files = comm.allgather( differing_files ) all_differing_files = [] for file_list in differing_files: all_differing_files += file_list - difference_found = len(all_differing_files) > 0 + difference_found = len( all_differing_files ) > 0 if rank == 0: output_path = output_base_path + ".restartcheck" - with open(output_path, 'a') as output_file: - gatherOutput(output_file, output_base_path, len(files_to_compare)) + with open( output_path, 'a' ) as output_file: + gatherOutput( output_file, output_base_path, len( files_to_compare ) ) if difference_found: write( output_file, "\nCompared %d pairs of files of which %d are different.\n" % - (len(files_to_compare), len(all_differing_files))) + ( len( files_to_compare ), len( all_differing_files ) ) ) for file_path, base_path in all_differing_files: - write(output_file, "\t" + file_path + " and " + base_path + "\n") + write( output_file, "\t" + file_path + " and " + base_path + "\n" ) return 1 else: - write(output_file, - "\nThe root files and the %d pairs of files compared are similar.\n" % len(files_to_compare)) + write( output_file, + "\nThe root files and the %d pairs of files compared are similar.\n" % len( files_to_compare ) ) return difference_found if __name__ == "__main__" and not sys.flags.interactive: - sys.exit(main()) + sys.exit( main() ) diff --git a/geos_ats_package/geos_ats/machine_utilities.py b/geos_ats_package/geos_ats/machine_utilities.py index 5e66b48..356de83 100644 --- a/geos_ats_package/geos_ats/machine_utilities.py +++ b/geos_ats_package/geos_ats/machine_utilities.py @@ -5,16 +5,16 @@ import os -def CheckForEarlyTimeOut(test, retval, fraction): +def CheckForEarlyTimeOut( test, retval, fraction ): if not retval: return retval, fraction else: - if (config.max_retry > 0) and (config.retry_err_regexp != "") and (not hasattr(test, "checkstart")): - sourceFile = getattr(test, "errname") - if os.path.exists(sourceFile): + if ( config.max_retry > 0 ) and ( config.retry_err_regexp != "" ) and ( not hasattr( test, "checkstart" ) ): + sourceFile = getattr( test, "errname" ) + if os.path.exists( sourceFile ): test.checkstart = 1 - with open(sourceFile) as f: + with open( sourceFile ) as f: erroutput = f.read() - if re.search(config.retry_err_regexp, erroutput): + if re.search( config.retry_err_regexp, erroutput ): return 1, fraction return 0, fraction diff --git a/geos_ats_package/geos_ats/machines/batchGeosatsMoab.py b/geos_ats_package/geos_ats/machines/batchGeosatsMoab.py index 4f453cf..fb1e4f0 100644 --- a/geos_ats_package/geos_ats/machines/batchGeosatsMoab.py +++ b/geos_ats_package/geos_ats/machines/batchGeosatsMoab.py @@ -1,36 +1,36 @@ #BATS:batchGeosatsMoab batchGeosatsMoab BatchGeosatsMoab -1 -from ats import machines, configuration, log, atsut, times, AtsTest # type: ignore[import] +from ats import machines, configuration, log, atsut, times, AtsTest # type: ignore[import] import subprocess, sys, os, shlex, time, socket, re -import utils, batchTemplate # type: ignore[import] -from batch import BatchMachine # type: ignore[import] +import utils, batchTemplate # type: ignore[import] +from batch import BatchMachine # type: ignore[import] import logging debug = configuration.debug -logger = logging.getLogger('geos_ats') +logger = logging.getLogger( 'geos_ats' ) -class BatchGeosatsMoab(BatchMachine): +class BatchGeosatsMoab( BatchMachine ): """The batch machine """ - def init(self): + def init( self ): - super(BatchGeosatsMoab, self).init() + super( BatchGeosatsMoab, self ).init() if "SLURM_NNODES" in os.environ.keys(): - self.ppn = int(os.getenv("SLURM_TASKS_PER_NODE", "1").split("(")[0]) + self.ppn = int( os.getenv( "SLURM_TASKS_PER_NODE", "1" ).split( "(" )[ 0 ] ) elif "SLURM_JOB_NUM_NODES" in os.environ.keys(): - self.ppn = int(os.getenv("SLURM_JOB_CPUS_PER_NODE", "1").split("(")[0]) + self.ppn = int( os.getenv( "SLURM_JOB_CPUS_PER_NODE", "1" ).split( "(" )[ 0 ] ) else: self.ppn = 0 self.numberTestsRunningMax = 2048 - def canRun(self, test): + def canRun( self, test ): return '' - def load(self, testlist): + def load( self, testlist ): """Receive a list of tests to possibly run. Submit the set of tests to batch. """ @@ -41,20 +41,20 @@ def load(self, testlist): # for each test group make an msub file if t.groupSerialNumber == 1: - testCase = getattr(t, "geos_atsTestCase", None) + testCase = getattr( t, "geos_atsTestCase", None ) if testCase: - batchFilename = os.path.join(testCase.dirnamefull, "batch_%s.msub" % testCase.name) - self.writeSubmitScript(batchFilename, testCase) - self.jobid = self.submitBatchScript(testCase.name, batchFilename) + batchFilename = os.path.join( testCase.dirnamefull, "batch_%s.msub" % testCase.name ) + self.writeSubmitScript( batchFilename, testCase ) + self.jobid = self.submitBatchScript( testCase.name, batchFilename ) - def writeSubmitScript(self, batchFilename, testCase): + def writeSubmitScript( self, batchFilename, testCase ): - fc = open(batchFilename, "w") + fc = open( batchFilename, "w" ) batch = testCase.batch # get references to the options and configuration options = AtsTest.getOptions() - config = options.get("config", None) + config = options.get( "config", None ) # ppn # 1. first check batch object @@ -69,10 +69,10 @@ def writeSubmitScript(self, batchFilename, testCase): ppn = self.ppn if ppn == 0: - raise RuntimeError(""" + raise RuntimeError( """ Unable to find the number of processors per node in BatchGeosatsMoab. Try setting batch_ppn= on the - command line.""") + command line.""" ) # Specifies parallel Lustre file system. gresLine = "" @@ -81,7 +81,7 @@ def writeSubmitScript(self, batchFilename, testCase): # determine the max number of processors in this job maxprocs = testCase.findMaxNumberOfProcessors() - minNodes = maxprocs / ppn + (maxprocs % ppn != 0) + minNodes = maxprocs / ppn + ( maxprocs % ppn != 0 ) # MSUB options msub_str = '#!/bin/csh' @@ -114,20 +114,20 @@ def writeSubmitScript(self, batchFilename, testCase): msub_str += f"\n\ncd {testCase.dirnamefull}" # pull out options to construct the command line - action = options.get("action") - checkoption = options.get("checkoption") - configFile = options.get("configFile") - configOverride = options.get("configOverride") - atsFlags = options.get("atsFlags") - geos_atsPath = options.get("geos_atsPath") - machine = options.get("machine") + action = options.get( "action" ) + checkoption = options.get( "checkoption" ) + configFile = options.get( "configFile" ) + configOverride = options.get( "configOverride" ) + atsFlags = options.get( "atsFlags" ) + geos_atsPath = options.get( "geos_atsPath" ) + machine = options.get( "machine" ) # construct the command line msub_str += f'\n{geos_atsPath} -a {action} -c {checkoption}' msub_str += f' -f {configFile} -N {minNodes:d} --machine={machine}' for key, value in configOverride.items(): - if key.startswith("batch"): + if key.startswith( "batch" ): continue msub_str += f' {key}="{value}"' @@ -138,26 +138,26 @@ def writeSubmitScript(self, batchFilename, testCase): msub_str += f" batch_interactive=True {testCase.name}" # Write and close the file - fc.write(msub_str) + fc.write( msub_str ) fc.close() - def submitBatchScript(self, testname, batchFilename): + def submitBatchScript( self, testname, batchFilename ): options = AtsTest.getOptions() - config = options.get("config", None) + config = options.get( "config", None ) if config and config.batch_dryrun: return - p = subprocess.Popen(["msub", batchFilename], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - out = p.communicate()[0] + p = subprocess.Popen( [ "msub", batchFilename ], stdout=subprocess.PIPE, stderr=subprocess.STDOUT ) + out = p.communicate()[ 0 ] if p.returncode: - raise RuntimeError(f"Error submitting {testname} to batch: {out}") + raise RuntimeError( f"Error submitting {testname} to batch: {out}" ) try: - jobid = int(out) - logger.info(f" Submitting {testname}, jobid = {jobid:d}") + jobid = int( out ) + logger.info( f" Submitting {testname}, jobid = {jobid:d}" ) except: err = f"Error submitting {testname} to batch: {out}" - logger.error(err) - raise RuntimeError(err) + logger.error( err ) + raise RuntimeError( err ) diff --git a/geos_ats_package/geos_ats/machines/bgqos_0_ASQ.py b/geos_ats_package/geos_ats/machines/bgqos_0_ASQ.py index 6c470db..4854a41 100644 --- a/geos_ats_package/geos_ats/machines/bgqos_0_ASQ.py +++ b/geos_ats_package/geos_ats/machines/bgqos_0_ASQ.py @@ -1,19 +1,19 @@ #ATS:bgqos_0_ASQ machines.bgqos_0_ASQ bgqos_0_ASQMachine 16 -from ats import machines, debug, atsut # type: ignore[import] +from ats import machines, debug, atsut # type: ignore[import] from ats import log, terminal from ats import configuration -from ats.atsut import RUNNING, TIMEDOUT, SKIPPED, BATCHED, INVALID, PASSED, FAILED, CREATED, FILTERED, HALTED, EXPECTED # type: ignore[import] -import utils # type: ignore[import] +from ats.atsut import RUNNING, TIMEDOUT, SKIPPED, BATCHED, INVALID, PASSED, FAILED, CREATED, FILTERED, HALTED, EXPECTED # type: ignore[import] +import utils # type: ignore[import] import time import sys -class bgqos_0_ASQMachine(machines.Machine): +class bgqos_0_ASQMachine( machines.Machine ): """The chaos family with processor scheduling. """ - def init(self): + def init( self ): self.npBusy = 0 @@ -24,35 +24,35 @@ def init(self): self.nodeProcAvailDic = {} - def addOptions(self, parser): + def addOptions( self, parser ): "Add options needed on this machine." - parser.add_option("--partition", - action="store", - type="string", - dest='partition', - default='pdebug', - help="Partition in which to run jobs with np > 0") - parser.add_option("--numNodes", - action="store", - type="int", - dest='numNodes', - default=-1, - help="Number of nodes to use") - parser.add_option("--srunOnlyWhenNecessary", - action="store_true", - dest='srun', - default=False, - help="Use srun only for np > 1") - parser.add_option("--removeSrunStep", - action="store_true", - dest='removeSrunStep', - default=True, - help="Set to use srun job step.") - - def examineOptions(self, options): + parser.add_option( "--partition", + action="store", + type="string", + dest='partition', + default='pdebug', + help="Partition in which to run jobs with np > 0" ) + parser.add_option( "--numNodes", + action="store", + type="int", + dest='numNodes', + default=-1, + help="Number of nodes to use" ) + parser.add_option( "--srunOnlyWhenNecessary", + action="store_true", + dest='srun', + default=False, + help="Use srun only for np > 1" ) + parser.add_option( "--removeSrunStep", + action="store_true", + dest='removeSrunStep', + default=True, + help="Set to use srun job step." ) + + def examineOptions( self, options ): "Examine options from command line, possibly override command line choices." # Grab option values. - super(bgqos_0_ASQMachine, self).examineOptions(options) + super( bgqos_0_ASQMachine, self ).examineOptions( options ) self.npMax = self.numberTestsRunningMax import os @@ -63,7 +63,7 @@ def examineOptions(self, options): if options.numNodes == -1: if 'SLURM_NNODES' in os.environ: - options.numNodes = int(os.environ['SLURM_NNODES']) + options.numNodes = int( os.environ[ 'SLURM_NNODES' ] ) else: options.numNodes = 1 @@ -78,24 +78,24 @@ def examineOptions(self, options): if not self.removeSrunStep: self.allNodeList = utils.getAllHostnames() - if len(self.allNodeList) == 0: + if len( self.allNodeList ) == 0: self.removeSrunStep = True else: - self.stepId, self.nodeStepNumDic = utils.setStepNumWithNode(len(self.allNodeList)) + self.stepId, self.nodeStepNumDic = utils.setStepNumWithNode( len( self.allNodeList ) ) for oneNode in self.allNodeList: - self.nodeProcAvailDic[oneNode] = self.npMax + self.nodeProcAvailDic[ oneNode ] = self.npMax self.stepInUse = self.stepToUse # Let's check if there exists a srun process - if len(self.allNodeList) > 0: - srunDefunct = utils.checkForSrunDefunct(self.allNodeList[0]) + if len( self.allNodeList ) > 0: + srunDefunct = utils.checkForSrunDefunct( self.allNodeList[ 0 ] ) self.numberMaxProcessors -= srunDefunct - self.nodeProcAvailDic[self.allNodeList[0]] -= srunDefunct + self.nodeProcAvailDic[ self.allNodeList[ 0 ] ] -= srunDefunct self.numberTestsRunningMax = self.numberMaxProcessors - def getResults(self): - results = super(bgqos_0_ASQMachine, self).getResults() + def getResults( self ): + results = super( bgqos_0_ASQMachine, self ).getResults() results.srunOnlyWhenNecessary = self.srunOnlyWhenNecessary results.partition = self.partition results.numNodes = self.numNodes @@ -106,25 +106,25 @@ def getResults(self): return results - def label(self): - return "BG/Q %d nodes %d processors per node." % (self.numNodes, self.npMax) + def label( self ): + return "BG/Q %d nodes %d processors per node." % ( self.numNodes, self.npMax ) - def calculateCommandList(self, test): + def calculateCommandList( self, test ): """Prepare for run of executable using a suitable command. First we get the plain command line that would be executed on a vanilla serial machine, then we modify it if necessary for use on this machines. """ - commandList = self.calculateBasicCommandList(test) + commandList = self.calculateBasicCommandList( test ) if self.srunOnlyWhenNecessary and test.np <= 1: return commandList - if test.options.get('checker_test'): + if test.options.get( 'checker_test' ): return commandList # namebase is a space-free version of the name test.jobname = f"t{test.np}_{test.serialNumber}{test.namebase}" - np = max(test.np, 1) - minNodes = np / self.npMax + (np % self.npMax != 0) + np = max( test.np, 1 ) + minNodes = np / self.npMax + ( np % self.npMax != 0 ) # # These should be removed @@ -156,25 +156,25 @@ def calculateCommandList(self, test): # return [ "srun", - "-N%i-%i" % (minNodes, minNodes), "-n", - str(np), "-p", self.partition, "--label", "-J", test.jobname + "-N%i-%i" % ( minNodes, minNodes ), "-n", + str( np ), "-p", self.partition, "--label", "-J", test.jobname ] + commandList - def canRun(self, test): + def canRun( self, test ): "Do some precalculations here to make canRunNow quicker." - test.requiredNP = max(test.np, 1) - test.numberOfNodesNeeded, r = divmod(test.requiredNP, self.npMax) + test.requiredNP = max( test.np, 1 ) + test.numberOfNodesNeeded, r = divmod( test.requiredNP, self.npMax ) if r: test.numberOfNodesNeeded += 1 if self.removeSrunStep: - test.requiredNP = max(test.np, self.npMax * test.numberOfNodesNeeded) + test.requiredNP = max( test.np, self.npMax * test.numberOfNodesNeeded ) if test.requiredNP > self.numberMaxProcessors: - return "Too many processors required, %d (limit is %d)" % (test.requiredNP, self.numberMaxProcessors) + return "Too many processors required, %d (limit is %d)" % ( test.requiredNP, self.numberMaxProcessors ) - def canRunNow(self, test): + def canRunNow( self, test ): "Is this machine able to run this test now? Return True/False" - if (self.npBusy + test.requiredNP) > self.numberMaxProcessors: + if ( self.npBusy + test.requiredNP ) > self.numberMaxProcessors: return False elif self.removeSrunStep: @@ -182,60 +182,61 @@ def canRunNow(self, test): return True - def noteLaunch(self, test): + def noteLaunch( self, test ): """A test has been launched.""" if not self.removeSrunStep: if test.srunRelativeNode < 0: - self.nodeProcAvailDic = utils.removeFromUsedTotalDicNoSrun(self.nodeProcAvailDic, self.nodeStepNumDic, - self.npMax, test.np, self.allNodeList) + self.nodeProcAvailDic = utils.removeFromUsedTotalDicNoSrun( self.nodeProcAvailDic, self.nodeStepNumDic, + self.npMax, test.np, self.allNodeList ) else: - self.nodeProcAvailDic = utils.removeFromUsedTotalDic(self.nodeProcAvailDic, self.nodeStepNumDic, - self.npMax, test.step, test.np, - test.numberOfNodesNeeded, test.numNodesToUse, - test.srunRelativeNode, self.stepId, - self.allNodeList) - self.npBusy += max(test.np, 1) + self.nodeProcAvailDic = utils.removeFromUsedTotalDic( self.nodeProcAvailDic, self.nodeStepNumDic, + self.npMax, test.step, test.np, + test.numberOfNodesNeeded, test.numNodesToUse, + test.srunRelativeNode, self.stepId, + self.allNodeList ) + self.npBusy += max( test.np, 1 ) else: # this is necessary when srun exclusive is used. - self.npBusy += max(test.np, test.numberOfNodesNeeded * self.npMax) + self.npBusy += max( test.np, test.numberOfNodesNeeded * self.npMax ) if debug(): - log(f"Max np={self.numberMaxProcessors}. Launched {test.name} with np= {test.np} tests, total proc in use = {self.npBusy}", - echo=True) + log( + f"Max np={self.numberMaxProcessors}. Launched {test.name} with np= {test.np} tests, total proc in use = {self.npBusy}", + echo=True ) self.scheduler.schedule( f"Max np= {self.numberMaxProcessors}. Launched {test.name} with np= {test.np} tests, total proc in use = self.npBusy" ) self.numberTestsRunning = self.npBusy - def noteEnd(self, test): + def noteEnd( self, test ): """A test has finished running. """ if not self.removeSrunStep: - self.npBusy -= max(test.np, 1) + self.npBusy -= max( test.np, 1 ) else: # this is necessary when srun exclusive is used. - self.npBusy -= max(test.np, test.numberOfNodesNeeded * self.npMax) + self.npBusy -= max( test.np, test.numberOfNodesNeeded * self.npMax ) if debug(): - log("Finished %s, #total proc in use = %d" % (test.name, self.npBusy), echo=True) - self.scheduler.schedule("Finished %s, #total proc in use = %d" % (test.name, self.npBusy)) + log( "Finished %s, #total proc in use = %d" % ( test.name, self.npBusy ), echo=True ) + self.scheduler.schedule( "Finished %s, #total proc in use = %d" % ( test.name, self.npBusy ) ) self.numberTestsRunning = self.npBusy - def periodicReport(self): + def periodicReport( self ): "Report on current status of tasks" # Let's also write out the tests that are waiting .... - super(bgqos_0_ASQMachine, self).periodicReport() - currentEligible = [t.name for t in self.scheduler.testlist() if t.status is atsut.CREATED] + super( bgqos_0_ASQMachine, self ).periodicReport() + currentEligible = [ t.name for t in self.scheduler.testlist() if t.status is atsut.CREATED ] - if len(currentEligible) > 1: - terminal("WAITING:", ", ".join(currentEligible[:5]), "... (more)") + if len( currentEligible ) > 1: + terminal( "WAITING:", ", ".join( currentEligible[ :5 ] ), "... (more)" ) - def kill(self, test): + def kill( self, test ): "Final cleanup if any." # kill the test # This is necessary -- killing the srun command itself is not enough to end the job... it is still running (squeue will show this) @@ -243,12 +244,12 @@ def kill(self, test): if test.status is RUNNING or test.status is TIMEDOUT: try: - retcode = subprocess.call("scancel" + " -n " + test.jobname, shell=True) + retcode = subprocess.call( "scancel" + " -n " + test.jobname, shell=True ) if retcode < 0: - log("---- kill() in bgqos_0_ASQ.py, command= scancel -n %s failed with return code -%d ----" % - (test.jobname, retcode), - echo=True) + log( "---- kill() in bgqos_0_ASQ.py, command= scancel -n %s failed with return code -%d ----" % + ( test.jobname, retcode ), + echo=True ) except OSError as e: - log("---- kill() in bgqos_0_ASQ.py, execution of command failed (scancel -n %s) failed: %s----" % - (test.jobname, e), - echo=True) + log( "---- kill() in bgqos_0_ASQ.py, execution of command failed (scancel -n %s) failed: %s----" % + ( test.jobname, e ), + echo=True ) diff --git a/geos_ats_package/geos_ats/machines/darwin.py b/geos_ats_package/geos_ats/machines/darwin.py index af7b8b5..944e59a 100644 --- a/geos_ats_package/geos_ats/machines/darwin.py +++ b/geos_ats_package/geos_ats/machines/darwin.py @@ -1,8 +1,8 @@ #ATS:darwin machines.darwin DarwinMachine 16 -from openmpi import OpenmpiMachine # type: ignore[import] +from openmpi import OpenmpiMachine # type: ignore[import] -class DarwinMachine(OpenmpiMachine): +class DarwinMachine( OpenmpiMachine ): "Darwin Machine." pass diff --git a/geos_ats_package/geos_ats/machines/geosAtsSlurmProcessorScheduled.py b/geos_ats_package/geos_ats/machines/geosAtsSlurmProcessorScheduled.py index 1770402..ef24cbb 100644 --- a/geos_ats_package/geos_ats/machines/geosAtsSlurmProcessorScheduled.py +++ b/geos_ats_package/geos_ats/machines/geosAtsSlurmProcessorScheduled.py @@ -8,16 +8,16 @@ from geos_ats.scheduler import scheduler from geos_ats.machine_utilities import CheckForEarlyTimeOut -from slurmProcessorScheduled import SlurmProcessorScheduled # type: ignore[import] +from slurmProcessorScheduled import SlurmProcessorScheduled # type: ignore[import] import subprocess import logging -class GeosAtsSlurmProcessorScheduled(SlurmProcessorScheduled): +class GeosAtsSlurmProcessorScheduled( SlurmProcessorScheduled ): - def init(self): - super(GeosAtsSlurmProcessorScheduled, self).init() - self.logger = logging.getLogger('geos_ats') + def init( self ): + super( GeosAtsSlurmProcessorScheduled, self ).init() + self.logger = logging.getLogger( 'geos_ats' ) if not self.runWithSalloc: try: # Try to get the number of processors per node via sinfo. @@ -25,23 +25,23 @@ def init(self): # CPUs (%c) is actually threads, so multiply sockets (%X) X cores (%Y) # to get actual number of processors (we ignore hyprethreading). sinfoCmd = 'sinfo -o"%X %Y"' - proc = subprocess.Popen(sinfoCmd, shell=True, stdout=subprocess.PIPE) - stdout_value = proc.communicate()[0] - (sockets, cores) = stdout_value.split('\n')[1].split() - self.npMaxH = int(sockets) * int(cores) + proc = subprocess.Popen( sinfoCmd, shell=True, stdout=subprocess.PIPE ) + stdout_value = proc.communicate()[ 0 ] + ( sockets, cores ) = stdout_value.split( '\n' )[ 1 ].split() + self.npMaxH = int( sockets ) * int( cores ) except: - self.logger.debug("Failed to identify npMaxH") + self.logger.debug( "Failed to identify npMaxH" ) else: self.npMaxH = self.npMax self.scheduler = scheduler() - def label(self): - return "GeosAtsSlurmProcessorScheduled: %d nodes, %d processors per node." % (self.numNodes, self.npMax) + def label( self ): + return "GeosAtsSlurmProcessorScheduled: %d nodes, %d processors per node." % ( self.numNodes, self.npMax ) - def checkForTimeOut(self, test): + def checkForTimeOut( self, test ): """ Check the time elapsed since test's start time. If greater then the timelimit, return true, else return false. test's end time is set if time elapsed exceeds time limit. Also return true if retry string if found.""" - retval, fraction = super(GeosAtsSlurmProcessorScheduled, self).checkForTimeOut(test) - return CheckForEarlyTimeOut(test, retval, fraction) + retval, fraction = super( GeosAtsSlurmProcessorScheduled, self ).checkForTimeOut( test ) + return CheckForEarlyTimeOut( test, retval, fraction ) diff --git a/geos_ats_package/geos_ats/machines/lassen.py b/geos_ats_package/geos_ats/machines/lassen.py index 1bdd93b..0694c5c 100644 --- a/geos_ats_package/geos_ats/machines/lassen.py +++ b/geos_ats_package/geos_ats/machines/lassen.py @@ -1,106 +1,106 @@ #ATS:SequentialMachine SELF lassenMachine 1 #ATS:lassen SELF lassenMachine 1 -from ats import machines # type: ignore[import] +from ats import machines # type: ignore[import] from ats import machines, debug, atsut from ats import log, terminal from ats import configuration -from ats.atsut import RUNNING, TIMEDOUT # type: ignore[import] +from ats.atsut import RUNNING, TIMEDOUT # type: ignore[import] from ats import AtsTest import os import subprocess import logging -class lassenMachine(machines.Machine): +class lassenMachine( machines.Machine ): """ run from a backend node on Lassen """ - def init(self): + def init( self ): self.numtests = 0 self.numProcsAvailable = 0 - self.logger = logging.getLogger('geos_ats') + self.logger = logging.getLogger( 'geos_ats' ) - def examineOptions(self, options): + def examineOptions( self, options ): "Examine options from command line, possibly override command line choices." - super(lassenMachine, self).examineOptions(options) + super( lassenMachine, self ).examineOptions( options ) # Get total cpu cores available, and convert to number of gpus! - self.numberMaxProcessors = int(os.getenv("LSB_MAX_NUM_PROCESSORS", "0")) - 1 + self.numberMaxProcessors = int( os.getenv( "LSB_MAX_NUM_PROCESSORS", "0" ) ) - 1 self.numberMaxGPUS = self.numberMaxProcessors / 10 self.numberTestsRunningMax = self.numberMaxProcessors self.numProcsAvailable = self.numberMaxProcessors - def getNumberOfProcessors(self): + def getNumberOfProcessors( self ): return self.numberMaxProcessors - def getNumberOfGPUS(self): + def getNumberOfGPUS( self ): return self.numberMaxGPUS - def addOptions(self, parser): + def addOptions( self, parser ): "Add options needed on this machine." - parser.add_option("--numNodes", - action="store", - type="int", - dest='numNodes', - default=1, - help="Number of nodes to use") + parser.add_option( "--numNodes", + action="store", + type="int", + dest='numNodes', + default=1, + help="Number of nodes to use" ) return - def calculateCommandList(self, test): + def calculateCommandList( self, test ): """Prepare for run of executable using a suitable command. First we get the plain command line that would be executed on a vanilla serial machine, then we modify it if necessary for use on this machines. """ options = AtsTest.getOptions() - basicCommands = self.calculateBasicCommandList(test) + basicCommands = self.calculateBasicCommandList( test ) commandList = [] ngpu = test.ngpu - commandList += ["lrun", "-n", "%d" % test.np, "--pack", "-g", "%d" % ngpu] + commandList += [ "lrun", "-n", "%d" % test.np, "--pack", "-g", "%d" % ngpu ] commandList += basicCommands return commandList - def canRun(self, test): + def canRun( self, test ): """Is this machine able to run the test interactively when resources become available? If so return ''. Otherwise return the reason it cannot be run here. """ - np = max(test.np, 1) + np = max( test.np, 1 ) if np > self.numberMaxProcessors: return f"Too many processors needed ({np})" gpusPerTask = test.ngpu if np * gpusPerTask > self.numberMaxGPUS: err = f"Too many gpus needed ({np * gpusPerTask:d})" - self.logger.error(err) + self.logger.error( err ) return err - def canRunNow(self, test): + def canRunNow( self, test ): """We let lrun do the scheduling so return true.""" return True - def noteLaunch(self, test): + def noteLaunch( self, test ): """A test has been launched.""" self.numtests += 1 - def noteEnd(self, test): + def noteEnd( self, test ): """A test has finished running. """ self.numtests -= 1 - def periodicReport(self): + def periodicReport( self ): "Report on current status of tasks" - terminal("-" * 80) - terminal("Running jobs:") - os.system("jslist -r") - terminal("Queued jobs:") - os.system("jslist -p") - terminal("-" * 80) - - def kill(self, test): + terminal( "-" * 80 ) + terminal( "Running jobs:" ) + os.system( "jslist -r" ) + terminal( "Queued jobs:" ) + os.system( "jslist -p" ) + terminal( "-" * 80 ) + + def kill( self, test ): "Final cleanup if any." if test.status is RUNNING or test.status is TIMEDOUT: @@ -109,11 +109,11 @@ def kill(self, test): try: test.child.terminate() except: - logger.info("Terminating job") + logger.info( "Terminating job" ) try: - retcode = subprocess.call("jskill all", shell=True) + retcode = subprocess.call( "jskill all", shell=True ) if retcode < 0: - log("command= %s failed with return code %d" % ("jskill all", retcode), echo=True) + log( "command= %s failed with return code %d" % ( "jskill all", retcode ), echo=True ) except: - logger.info("Killing job") + logger.info( "Killing job" ) diff --git a/geos_ats_package/geos_ats/machines/nersc.py b/geos_ats_package/geos_ats/machines/nersc.py index 1eb2443..e92a3e9 100644 --- a/geos_ats_package/geos_ats/machines/nersc.py +++ b/geos_ats_package/geos_ats/machines/nersc.py @@ -8,16 +8,16 @@ from geos_ats.scheduler import scheduler from geos_ats.machine_utilities import CheckForEarlyTimeOut -from slurmProcessorScheduled import SlurmProcessorScheduled # type: ignore[import] +from slurmProcessorScheduled import SlurmProcessorScheduled # type: ignore[import] import subprocess import logging -class Nersc(SlurmProcessorScheduled): +class Nersc( SlurmProcessorScheduled ): - def init(self): - super(Nersc, self).init() - self.logger = logging.getLogger('geos_ats') + def init( self ): + super( Nersc, self ).init() + self.logger = logging.getLogger( 'geos_ats' ) if not self.runWithSalloc: try: # Try to get the number of processors per node via sinfo. @@ -25,28 +25,28 @@ def init(self): # CPUs (%c) is actually threads, so multiply sockets (%X) X cores (%Y) # to get actual number of processors (we ignore hyprethreading). sinfoCmd = 'sinfo -o"%X %Y"' - proc = subprocess.Popen(sinfoCmd, shell=True, stdout=subprocess.PIPE) - stdout_value = proc.communicate()[0] - (sockets, cores) = stdout_value.split('\n')[1].split() - self.npMaxH = int(sockets) * int(cores) + proc = subprocess.Popen( sinfoCmd, shell=True, stdout=subprocess.PIPE ) + stdout_value = proc.communicate()[ 0 ] + ( sockets, cores ) = stdout_value.split( '\n' )[ 1 ].split() + self.npMaxH = int( sockets ) * int( cores ) except: - self.logger.debug("Failed to identify npMaxH") + self.logger.debug( "Failed to identify npMaxH" ) else: self.npMaxH = self.npMax self.scheduler = scheduler() - def label(self): - return "Nersc: %d nodes, %d processors per node." % (self.numNodes, self.npMax) + def label( self ): + return "Nersc: %d nodes, %d processors per node." % ( self.numNodes, self.npMax ) - def checkForTimeOut(self, test): + def checkForTimeOut( self, test ): """ Check the time elapsed since test's start time. If greater then the timelimit, return true, else return false. test's end time is set if time elapsed exceeds time limit. Also return true if retry string if found.""" - retval, fraction = super(Nersc, self).checkForTimeOut(test) - return CheckForEarlyTimeOut(test, retval, fraction) + retval, fraction = super( Nersc, self ).checkForTimeOut( test ) + return CheckForEarlyTimeOut( test, retval, fraction ) - def calculateCommandList(self, test): - command = super(Nersc, self).calculateCommandList(test) - command.remove("--mpibind=off") + def calculateCommandList( self, test ): + command = super( Nersc, self ).calculateCommandList( test ) + command.remove( "--mpibind=off" ) return command diff --git a/geos_ats_package/geos_ats/machines/openmpi.py b/geos_ats_package/geos_ats/machines/openmpi.py index 842564d..8304ca8 100644 --- a/geos_ats_package/geos_ats/machines/openmpi.py +++ b/geos_ats_package/geos_ats/machines/openmpi.py @@ -1,28 +1,28 @@ #ATS:openmpi machines.openmpi OpenmpiMachine 16 import os -import ats # type: ignore[import] +import ats # type: ignore[import] from ats import machines from ats import terminal from ats import log import shlex -from ats.atsut import RUNNING, TIMEDOUT # type: ignore[import] +from ats.atsut import RUNNING, TIMEDOUT # type: ignore[import] from ats import AtsTest import logging -class OpenmpiMachine(machines.Machine): +class OpenmpiMachine( machines.Machine ): "Openmpi Machine." - def init(self): + def init( self ): self.numtests = 0 self.maxtests = 0 self.numProcsAvailable = 0 - self.logger = logging.getLogger('geos_ats') + self.logger = logging.getLogger( 'geos_ats' ) - def examineOptions(self, options): + def examineOptions( self, options ): "Examine options from command line, possibly override command line choices." - super(OpenmpiMachine, self).examineOptions(options) + super( OpenmpiMachine, self ).examineOptions( options ) # openmpi_numnodes is actually number of jobs self.precommand = options.openmpi_precommand self.terminate = options.openmpi_terminate @@ -32,7 +32,7 @@ def examineOptions(self, options): mpirunCmd = options.openmpi_mpirun else: mpirunCmd = "mpirun" - self.mpirun = os.path.join(self.install, "bin", mpirunCmd) + self.mpirun = os.path.join( self.install, "bin", mpirunCmd ) self.openmpi_args = options.openmpi_args.split() # numberTestsRunningMax is actually the number of processors @@ -53,126 +53,126 @@ def examineOptions(self, options): self.openmpi_install = options.openmpi_install self.openmpi_mpirun = options.openmpi_mpirun - def getNumberOfProcessors(self): + def getNumberOfProcessors( self ): return self.numberMaxProcessors - def addOptions(self, parser): + def addOptions( self, parser ): "Add options needed on this machine." - parser.add_option("--openmpi_numnodes", - "--numNodes", - action="store", - type="int", - dest='openmpi_numnodes', - default=2, - help="Number of nodes to use") - - parser.add_option("--openmpi_maxprocs", - "--maxProcs", - action="store", - type="int", - dest='openmpi_maxprocs', - default=0, - help="Maximum number of processors to use") - - parser.add_option("--openmpi_procspernode", - "--procsPerNode", - action="store", - type="int", - dest='openmpi_procspernode', - default=1, - help="Number of processors per node") - - parser.add_option("--openmpi_precommand", - action="store", - type="str", - dest='openmpi_precommand', - default="", - help="Prepend to each command") - - parser.add_option("--openmpi_terminate", - action="store", - type="str", - dest='openmpi_terminate', - default="", - help="Command for abnormal termination") - - parser.add_option("--openmpi_install", - action="store", - type="str", - dest='openmpi_install', - default="", - help="Location of the openmpi install") - - parser.add_option("--openmpi_args", - action="store", - type="str", - dest='openmpi_args', - default="", - help="Arguments for openmpi mpirun command") - - parser.add_option("--openmpi_mpirun", - action="store", - type="str", - dest='openmpi_mpirun', - default="", - help="Set the mpi execution command") - - def calculateCommandList(self, test): + parser.add_option( "--openmpi_numnodes", + "--numNodes", + action="store", + type="int", + dest='openmpi_numnodes', + default=2, + help="Number of nodes to use" ) + + parser.add_option( "--openmpi_maxprocs", + "--maxProcs", + action="store", + type="int", + dest='openmpi_maxprocs', + default=0, + help="Maximum number of processors to use" ) + + parser.add_option( "--openmpi_procspernode", + "--procsPerNode", + action="store", + type="int", + dest='openmpi_procspernode', + default=1, + help="Number of processors per node" ) + + parser.add_option( "--openmpi_precommand", + action="store", + type="str", + dest='openmpi_precommand', + default="", + help="Prepend to each command" ) + + parser.add_option( "--openmpi_terminate", + action="store", + type="str", + dest='openmpi_terminate', + default="", + help="Command for abnormal termination" ) + + parser.add_option( "--openmpi_install", + action="store", + type="str", + dest='openmpi_install', + default="", + help="Location of the openmpi install" ) + + parser.add_option( "--openmpi_args", + action="store", + type="str", + dest='openmpi_args', + default="", + help="Arguments for openmpi mpirun command" ) + + parser.add_option( "--openmpi_mpirun", + action="store", + type="str", + dest='openmpi_mpirun', + default="", + help="Set the mpi execution command" ) + + def calculateCommandList( self, test ): """Prepare for run of executable using a suitable command. First we get the plain command line that would be executed on a vanilla serial machine, then we modify it if necessary for use on this machines. """ options = AtsTest.getOptions() - basicCommands = self.calculateBasicCommandList(test) + basicCommands = self.calculateBasicCommandList( test ) if self.precommand: import time - timeNow = time.strftime('%H%M%S', time.localtime()) - test.jobname = "t%d_%d%s%s" % (test.np, test.serialNumber, test.namebase, timeNow) - pre = self.precommand % {"np": test.np, "J": test.jobname} + timeNow = time.strftime( '%H%M%S', time.localtime() ) + test.jobname = "t%d_%d%s%s" % ( test.np, test.serialNumber, test.namebase, timeNow ) + pre = self.precommand % { "np": test.np, "J": test.jobname } commandList = pre.split() else: commandList = [] - commandList += [self.mpirun, "-n", "%d" % test.np] + commandList += [ self.mpirun, "-n", "%d" % test.np ] commandList += self.openmpi_args commandList += basicCommands return commandList - def canRun(self, test): + def canRun( self, test ): """Is this machine able to run the test interactively when resources become available? If so return ''. Otherwise return the reason it cannot be run here. """ - np = max(test.np, 1) + np = max( test.np, 1 ) if np > self.numberMaxProcessors: return "Too many processors needed (%d)" % np - def canRunNow(self, test): + def canRunNow( self, test ): "Is this machine able to run this test now? Return True/False" - np = max(test.np, 1) - return ((self.numtests < self.maxtests) and (self.numProcsAvailable >= np)) + np = max( test.np, 1 ) + return ( ( self.numtests < self.maxtests ) and ( self.numProcsAvailable >= np ) ) - def noteLaunch(self, test): + def noteLaunch( self, test ): """A test has been launched.""" - np = max(test.np, 1) + np = max( test.np, 1 ) self.numProcsAvailable -= np self.numtests += 1 - def noteEnd(self, test): + def noteEnd( self, test ): """A test has finished running. """ - np = max(test.np, 1) + np = max( test.np, 1 ) self.numProcsAvailable += np self.numtests -= 1 - def periodicReport(self): + def periodicReport( self ): "Report on current status of tasks" - terminal("-" * 80) - terminal("CURRENTLY RUNNING %d of %d tests." % (self.numtests, self.maxtests)) - terminal("-" * 80) - terminal("CURRENTLY UTILIZING %d processors (max %d)." % - (self.numberMaxProcessors - self.numProcsAvailable, self.numberMaxProcessors)) - terminal("-" * 80) - - def kill(self, test): + terminal( "-" * 80 ) + terminal( "CURRENTLY RUNNING %d of %d tests." % ( self.numtests, self.maxtests ) ) + terminal( "-" * 80 ) + terminal( "CURRENTLY UTILIZING %d processors (max %d)." % + ( self.numberMaxProcessors - self.numProcsAvailable, self.numberMaxProcessors ) ) + terminal( "-" * 80 ) + + def kill( self, test ): "Final cleanup if any." import subprocess @@ -182,14 +182,14 @@ def kill(self, test): try: test.child.terminate() except: - self.logger.debug("Terminating job`") + self.logger.debug( "Terminating job`" ) if self.terminate: try: - term = self.terminate % {"J": test.jobname} - retcode = subprocess.call(term, shell=True) + term = self.terminate % { "J": test.jobname } + retcode = subprocess.call( term, shell=True ) if retcode < 0: - log(f"---- kill() in openmpi.py, command= {term} failed with return code -{retcode} ----", - echo=True) + log( f"---- kill() in openmpi.py, command= {term} failed with return code -{retcode} ----", + echo=True ) except: - self.logger.debug("Terminating job`") + self.logger.debug( "Terminating job`" ) diff --git a/geos_ats_package/geos_ats/machines/summit.py b/geos_ats_package/geos_ats/machines/summit.py index 07b2c58..016d5f2 100644 --- a/geos_ats_package/geos_ats/machines/summit.py +++ b/geos_ats_package/geos_ats/machines/summit.py @@ -1,106 +1,106 @@ #ATS:SequentialMachine machines.summit summitMachine 1 #ATS:summit machines.summit summitMachine 1 -from ats import machines # type: ignore[import] +from ats import machines # type: ignore[import] from ats import machines, debug, atsut from ats import log, terminal from ats import configuration -from ats.atsut import RUNNING, TIMEDOUT # type: ignore[import] +from ats.atsut import RUNNING, TIMEDOUT # type: ignore[import] from ats import AtsTest import os import subprocess import logging -class summitMachine(machines.Machine): +class summitMachine( machines.Machine ): """ run using on Summit using jsrun. """ - def init(self): + def init( self ): self.numtests = 0 self.numProcsAvailable = 0 - self.logger = logging.getLogger('geos_ats') + self.logger = logging.getLogger( 'geos_ats' ) - def examineOptions(self, options): + def examineOptions( self, options ): "Examine options from command line, possibly override command line choices." - super(summitMachine, self).examineOptions(options) + super( summitMachine, self ).examineOptions( options ) # Get total cpu cores available, and convert to number of gpus! - self.numberMaxProcessors = int(os.getenv("LSB_MAX_NUM_PROCESSORS", "0")) - 1 - self.numberMaxGPUS = (self.numberMaxProcessors / 42) * 6 + self.numberMaxProcessors = int( os.getenv( "LSB_MAX_NUM_PROCESSORS", "0" ) ) - 1 + self.numberMaxGPUS = ( self.numberMaxProcessors / 42 ) * 6 self.numberTestsRunningMax = self.numberMaxProcessors self.numProcsAvailable = self.numberMaxProcessors - def getNumberOfProcessors(self): + def getNumberOfProcessors( self ): return self.numberMaxProcessors - def getNumberOfGPUS(self): + def getNumberOfGPUS( self ): return self.numberMaxGPUS - def addOptions(self, parser): + def addOptions( self, parser ): "Add options needed on this machine." - parser.add_option("--numNodes", - action="store", - type="int", - dest='numNodes', - default=1, - help="Number of nodes to use") + parser.add_option( "--numNodes", + action="store", + type="int", + dest='numNodes', + default=1, + help="Number of nodes to use" ) return - def calculateCommandList(self, test): + def calculateCommandList( self, test ): """Prepare for run of executable using a suitable command. First we get the plain command line that would be executed on a vanilla serial machine, then we modify it if necessary for use on this machines. """ options = AtsTest.getOptions() - basicCommands = self.calculateBasicCommandList(test) + basicCommands = self.calculateBasicCommandList( test ) commandList = [] ngpu = test.ngpu - commandList += ["jsrun", "--np", "%d" % test.np, "-g", "%d" % ngpu, "-c", "1", "-b", "rs"] + commandList += [ "jsrun", "--np", "%d" % test.np, "-g", "%d" % ngpu, "-c", "1", "-b", "rs" ] commandList += basicCommands return commandList - def canRun(self, test): + def canRun( self, test ): """Is this machine able to run the test interactively when resources become available? If so return ''. Otherwise return the reason it cannot be run here. """ - np = max(test.np, 1) + np = max( test.np, 1 ) if np > self.numberMaxProcessors: return f"Too many processors needed ({np})" gpusPerTask = test.ngpu if np * gpusPerTask > self.numberMaxGPUS: err = "Too many gpus needed ({np * gpusPerTask:d})" - self.logger(err) + self.logger( err ) return err - def canRunNow(self, test): + def canRunNow( self, test ): """We let lrun do the scheduling so return true.""" return True - def noteLaunch(self, test): + def noteLaunch( self, test ): """A test has been launched.""" self.numtests += 1 - def noteEnd(self, test): + def noteEnd( self, test ): """A test has finished running. """ self.numtests -= 1 - def periodicReport(self): + def periodicReport( self ): "Report on current status of tasks" - terminal("-" * 80) - terminal("Running jobs:") - os.system("jslist -r") - terminal("Queued jobs:") - os.system("jslist -p") - terminal("-" * 80) - - def kill(self, test): + terminal( "-" * 80 ) + terminal( "Running jobs:" ) + os.system( "jslist -r" ) + terminal( "Queued jobs:" ) + os.system( "jslist -p" ) + terminal( "-" * 80 ) + + def kill( self, test ): "Final cleanup if any." if test.status is RUNNING or test.status is TIMEDOUT: @@ -109,11 +109,11 @@ def kill(self, test): try: test.child.terminate() except: - self.logger.debug("Terminating job`") + self.logger.debug( "Terminating job`" ) try: - retcode = subprocess.call("jskill all", shell=True) + retcode = subprocess.call( "jskill all", shell=True ) if retcode < 0: - log("command= %s failed with return code %d" % ("jskill all", retcode), echo=True) + log( "command= %s failed with return code %d" % ( "jskill all", retcode ), echo=True ) except: - self.logger.debug("Terminating job`") + self.logger.debug( "Terminating job`" ) diff --git a/geos_ats_package/geos_ats/main.py b/geos_ats_package/geos_ats/main.py index 2f37158..f03b63f 100644 --- a/geos_ats_package/geos_ats/main.py +++ b/geos_ats_package/geos_ats/main.py @@ -8,12 +8,12 @@ import logging from geos_ats import command_line_parsers -test_actions = ("run", "rerun", "check", "continue") -report_actions = ("run", "rerun", "report", "continue") +test_actions = ( "run", "rerun", "check", "continue" ) +report_actions = ( "run", "rerun", "report", "continue" ) # Setup the logger -logging.basicConfig(level=logging.DEBUG, format='(%(asctime)s %(module)s:%(lineno)d) %(message)s') -logger = logging.getLogger('geos_ats') +logging.basicConfig( level=logging.DEBUG, format='(%(asctime)s %(module)s:%(lineno)d) %(message)s' ) +logger = logging.getLogger( 'geos_ats' ) # Job records current_subproc = None @@ -21,144 +21,144 @@ geos_atsStartTime = 0 -def check_ats_targets(options, testcases, configOverride, args): +def check_ats_targets( options, testcases, configOverride, args ): """ Determine which files, directories, or tests to run. Handle command line config options. """ - configOverride["executable_path"] = options.geos_bin_dir + configOverride[ "executable_path" ] = options.geos_bin_dir ats_files = [] for a in options.ats_targets: if "=" in a: - key, value = a.split("=") - configOverride[key] = value - args.remove(a) + key, value = a.split( "=" ) + configOverride[ key ] = value + args.remove( a ) elif not options.info: - if os.path.exists(a): - args.remove(a) - if os.path.isdir(a): - newfiles = glob.glob(os.path.join(a, "*.ats")) - ats_files.extend(newfiles) + if os.path.exists( a ): + args.remove( a ) + if os.path.isdir( a ): + newfiles = glob.glob( os.path.join( a, "*.ats" ) ) + ats_files.extend( newfiles ) else: - ats_files.append(a) + ats_files.append( a ) else: - testcases.append(a) + testcases.append( a ) else: if options.action in test_actions: - logger.error(f"The command line arg '{a}' is not recognized." - " An ats file or a directory name is expected.") - sys.exit(1) + logger.error( f"The command line arg '{a}' is not recognized." + " An ats file or a directory name is expected." ) + sys.exit( 1 ) # If no files were specified, look in the target directories - for d in ['.', 'integratedTests']: - if len(ats_files) == 0: - if os.path.isdir(d): - ats_files.extend(glob.glob(os.path.join(d, "*.ats"))) + for d in [ '.', 'integratedTests' ]: + if len( ats_files ) == 0: + if os.path.isdir( d ): + ats_files.extend( glob.glob( os.path.join( d, "*.ats" ) ) ) # prune out ats continue files. - for a in ats_files[:]: - if a.endswith("continue.ats"): - ats_files.remove(a) + for a in ats_files[ : ]: + if a.endswith( "continue.ats" ): + ats_files.remove( a ) return ats_files -def build_ats_arguments(options, ats_files, originalargv, config): +def build_ats_arguments( options, ats_files, originalargv, config ): # construct the argv to pass to the ATS: atsargv = [] - atsargv.append(originalargv[0]) - atsargv.append("--showGroupStartOnly") - atsargv.append("--logs=%s" % options.logs) + atsargv.append( originalargv[ 0 ] ) + atsargv.append( "--showGroupStartOnly" ) + atsargv.append( "--logs=%s" % options.logs ) if config.batch_interactive: - atsargv.append("--allInteractive") - atsargv.extend(config.machine_options) + atsargv.append( "--allInteractive" ) + atsargv.extend( config.machine_options ) for x in options.ats: # Add the appropriate argument indicators back based on their length - if len(x[0]) == 1: - x[0] = '-' + x[0] + if len( x[ 0 ] ) == 1: + x[ 0 ] = '-' + x[ 0 ] else: - x[0] = '--' + x[0] - atsargv.extend(x) + x[ 0 ] = '--' + x[ 0 ] + atsargv.extend( x ) - for f in os.environ.get('ATS_FILTER', '').split(','): - atsargv.extend(['-f', f]) + for f in os.environ.get( 'ATS_FILTER', '' ).split( ',' ): + atsargv.extend( [ '-f', f ] ) - atsargv.extend(ats_files) + atsargv.extend( ats_files ) sys.argv = atsargv -def write_log_dir_summary(logdir, originalargv): +def write_log_dir_summary( logdir, originalargv ): from geos_ats import configuration_record - with open(os.path.join(logdir, "geos_ats.config"), "w") as logconfig: - tmp = " ".join(originalargv[1:]) - logconfig.write(f'Run with: "{tmp}"\n') - configuration_record.infoConfigShow(True, logconfig) + with open( os.path.join( logdir, "geos_ats.config" ), "w" ) as logconfig: + tmp = " ".join( originalargv[ 1: ] ) + logconfig.write( f'Run with: "{tmp}"\n' ) + configuration_record.infoConfigShow( True, logconfig ) -def handleShutdown(signum, frame): +def handleShutdown( signum, frame ): if current_jobid is not None: term = "scancel -n %s" % current_jobid - subprocess.call(term, shell=True) - sys.exit(1) + subprocess.call( term, shell=True ) + sys.exit( 1 ) -def handle_salloc_relaunch(options, originalargv, configOverride): +def handle_salloc_relaunch( options, originalargv, configOverride ): tests = [ options.action in test_actions, options.salloc, options.machine - in ("SlurmProcessorScheduled", "GeosAtsSlurmProcessorScheduled"), "SLURM_JOB_ID" not in os.environ + in ( "SlurmProcessorScheduled", "GeosAtsSlurmProcessorScheduled" ), "SLURM_JOB_ID" not in os.environ ] - if all(tests): + if all( tests ): if options.sallocOptions != "": - sallocCommand = ["salloc"] + options.sallocOptions.split(" ") + sallocCommand = [ "salloc" ] + options.sallocOptions.split( " " ) else: - sallocCommand = ["salloc", "-ppdebug", "--exclusive", "-N", "%d" % options.numNodes] + sallocCommand = [ "salloc", "-ppdebug", "--exclusive", "-N", "%d" % options.numNodes ] if "testmodifier" in configOverride: - if configOverride["testmodifier"] == "memcheck": - p = subprocess.Popen(['sinfo', '-o', '%l', '-h', '-ppdebug'], stdout=subprocess.PIPE) + if configOverride[ "testmodifier" ] == "memcheck": + p = subprocess.Popen( [ 'sinfo', '-o', '%l', '-h', '-ppdebug' ], stdout=subprocess.PIPE ) out, err = p.communicate() - tarray = out.split(":") + tarray = out.split( ":" ) seconds = tarray.pop() minutes = tarray.pop() hours = 0 days = 0 - if len(tarray) > 0: + if len( tarray ) > 0: hours = tarray.pop() try: - days, hours = hours.split('-') + days, hours = hours.split( '-' ) except ValueError as e: - logger.debug(e) - limit = min(360, (24 * int(days) + int(hours)) * 60 + int(minutes)) - sallocCommand.extend(["-t", "%d" % limit]) + logger.debug( e ) + limit = min( 360, ( 24 * int( days ) + int( hours ) ) * 60 + int( minutes ) ) + sallocCommand.extend( [ "-t", "%d" % limit ] ) # generate a "unique" name for the salloc job so we can remove it later - timeNow = time.strftime('%H%M%S', time.localtime()) + timeNow = time.strftime( '%H%M%S', time.localtime() ) current_jobid = "geos_ats_%s" % timeNow # add the name to the arguments (this will override any previous name specification) - sallocCommand.extend(["-J", "%s" % current_jobid]) + sallocCommand.extend( [ "-J", "%s" % current_jobid ] ) # register our signal handler - signal.signal(signal.SIGTERM, handleShutdown) + signal.signal( signal.SIGTERM, handleShutdown ) command = sallocCommand # omit --workingDir on relaunch, as we have already changed directories - relaunchargv = [x for x in originalargv if not x.startswith("--workingDir")] + relaunchargv = [ x for x in originalargv if not x.startswith( "--workingDir" ) ] command += relaunchargv - command += ["--logs=%s" % options.logs] - p = subprocess.Popen(command) + command += [ "--logs=%s" % options.logs ] + p = subprocess.Popen( command ) p.wait() - sys.exit(p.returncode) + sys.exit( p.returncode ) def getLogDirBaseName(): return "TestLogs" -def create_log_directory(options): +def create_log_directory( options ): """ When the action will run tests (e.g. "run", "rerun", "check", "continue", then the LogDir is numbered, and saved. When the action does not run @@ -171,192 +171,193 @@ def create_log_directory(options): basename = getLogDirBaseName() index = 1 while True: - options.logs = "%s.%03d" % (basename, index) - if not os.path.exists(options.logs): + options.logs = "%s.%03d" % ( basename, index ) + if not os.path.exists( options.logs ): break index += 1 # make the options.logs - os.mkdir(options.logs) + os.mkdir( options.logs ) # make symlink try: - if os.path.exists(basename): - if os.path.islink(basename): - os.remove(basename) + if os.path.exists( basename ): + if os.path.islink( basename ): + os.remove( basename ) else: - logger.error(f"unable to replace {basename} with a symlink to {options.logs}") + logger.error( f"unable to replace {basename} with a symlink to {options.logs}" ) - if not os.path.exists(basename): - os.symlink(options.logs, basename) + if not os.path.exists( basename ): + os.symlink( options.logs, basename ) except: - logger.error("unable to name a symlink to to logdir") + logger.error( "unable to name a symlink to to logdir" ) else: if options.action in test_actions: - options.logs = "%s.%s" % (getLogDirBaseName(), options.action) + options.logs = "%s.%s" % ( getLogDirBaseName(), options.action ) elif options.info: - options.logs = "%s.info" % (getLogDirBaseName()) + options.logs = "%s.info" % ( getLogDirBaseName() ) else: - if not os.path.join(options.logs): - os.mkdir(options.logs) + if not os.path.join( options.logs ): + os.mkdir( options.logs ) -def check_timing_file(options, config): - if options.action in ["run", "rerun", "continue"]: +def check_timing_file( options, config ): + if options.action in [ "run", "rerun", "continue" ]: if config.timing_file: - if not os.path.isfile(config.timing_file): - logger.warning(f'Timing file does not exist {config.timing_file}') + if not os.path.isfile( config.timing_file ): + logger.warning( f'Timing file does not exist {config.timing_file}' ) return from geos_ats import configuration_record - with open(config.timing_file, "r") as filep: + with open( config.timing_file, "r" ) as filep: for line in filep: - if not line.startswith('#'): + if not line.startswith( '#' ): tokens = line.split() - configuration_record.globalTestTimings[tokens[0]] = int(tokens[1]) + configuration_record.globalTestTimings[ tokens[ 0 ] ] = int( tokens[ 1 ] ) -def append_test_end_step(machine): +def append_test_end_step( machine ): """ Add extra processing to the end of tests """ originalNoteEnd = machine.noteEnd - def noteEndWrapper(test): - test.geos_atsTestCase.status.noteEnd(test) - return originalNoteEnd(test) + def noteEndWrapper( test ): + test.geos_atsTestCase.status.noteEnd( test ) + return originalNoteEnd( test ) machine.noteEnd = noteEndWrapper -def check_working_dir(workingDir): +def check_working_dir( workingDir ): if workingDir: - if os.path.isdir(workingDir): - os.chdir(workingDir) + if os.path.isdir( workingDir ): + os.chdir( workingDir ) else: - logger.error(f"The requested working dir does not appear to exist: {workingDir}") + logger.error( f"The requested working dir does not appear to exist: {workingDir}" ) quit() -def infoOptions(title, options): +def infoOptions( title, options ): from geos_ats import common_utilities - topic = common_utilities.InfoTopic(title) + topic = common_utilities.InfoTopic( title ) topic.startBanner() - table = common_utilities.TextTable(2) + table = common_utilities.TextTable( 2 ) for opt, desc in options: - table.addRow(opt, desc) + table.addRow( opt, desc ) table.printTable() topic.endBanner() -def infoParagraph(title, paragraphs): +def infoParagraph( title, paragraphs ): from geos_ats import common_utilities - topic = common_utilities.InfoTopic(title) + topic = common_utilities.InfoTopic( title ) topic.startBanner() - table = common_utilities.TextTable(1) + table = common_utilities.TextTable( 1 ) for p in paragraphs: - table.addRow(p) + table.addRow( p ) table.rowbreak = 1 table.maxwidth = 75 table.printTable() topic.endBanner() -def info(args): - from geos_ats import (common_utilities, configuration_record, test_steps, suite_settings, test_case, test_modifier) - - infoLabels = lambda *x: suite_settings.infoLabels(suite_settings.__file__) - infoOwners = lambda *x: suite_settings.infoOwners(suite_settings.__file__) - - menu = common_utilities.InfoTopic("geos_ats info menu") - menu.addTopic("teststep", "Reference on all the TestStep", test_steps.infoTestSteps) - menu.addTopic("testcase", "Reference on the TestCase", test_case.infoTestCase) - menu.addTopic("labels", "List of labels", infoLabels) - menu.addTopic("owners", "List of owners", infoOwners) - menu.addTopic("config", "Reference on config options", configuration_record.infoConfig) - menu.addTopic("actions", "Description of the command line action options", - lambda *x: infoOptions("command line actions", command_line_parsers.action_ptions)) - menu.addTopic("checks", "Description of the command line check options", - lambda *x: infoOptions("command line checks", command_line_parsers.check_options)) - menu.addTopic("modifiers", "List of test modifiers", test_modifier.infoTestModifier) +def info( args ): + from geos_ats import ( common_utilities, configuration_record, test_steps, suite_settings, test_case, + test_modifier ) + + infoLabels = lambda *x: suite_settings.infoLabels( suite_settings.__file__ ) + infoOwners = lambda *x: suite_settings.infoOwners( suite_settings.__file__ ) + + menu = common_utilities.InfoTopic( "geos_ats info menu" ) + menu.addTopic( "teststep", "Reference on all the TestStep", test_steps.infoTestSteps ) + menu.addTopic( "testcase", "Reference on the TestCase", test_case.infoTestCase ) + menu.addTopic( "labels", "List of labels", infoLabels ) + menu.addTopic( "owners", "List of owners", infoOwners ) + menu.addTopic( "config", "Reference on config options", configuration_record.infoConfig ) + menu.addTopic( "actions", "Description of the command line action options", + lambda *x: infoOptions( "command line actions", command_line_parsers.action_ptions ) ) + menu.addTopic( "checks", "Description of the command line check options", + lambda *x: infoOptions( "command line checks", command_line_parsers.check_options ) ) + menu.addTopic( "modifiers", "List of test modifiers", test_modifier.infoTestModifier ) # menu.addTopic("testconfig", "Information on the testconfig.py file", # lambda *x: infoParagraph("testconfig", command_line_parsers.test_config_info)) - menu.process(args) + menu.process( args ) -def report(manager): +def report( manager ): """The report action""" - from geos_ats import (test_case, reporting, configuration_record) + from geos_ats import ( test_case, reporting, configuration_record ) testcases = test_case.TESTS.values() if configuration_record.config.report_wait: - reporter = reporting.ReportWait(testcases) - reporter.report(sys.stdout) + reporter = reporting.ReportWait( testcases ) + reporter.report( sys.stdout ) if configuration_record.config.report_text: - reporter = reporting.ReportText(testcases) - with open(configuration_record.config.report_text_file, "w") as filep: - reporter.report(filep) + reporter = reporting.ReportText( testcases ) + with open( configuration_record.config.report_text_file, "w" ) as filep: + reporter.report( filep ) if configuration_record.config.report_text_echo: - with open(configuration_record.config.report_text_file, "r") as filep: - sys.stdout.write(filep.read()) + with open( configuration_record.config.report_text_file, "r" ) as filep: + sys.stdout.write( filep.read() ) if configuration_record.config.report_html: - reporter = reporting.ReportHTML(testcases) + reporter = reporting.ReportHTML( testcases ) reporter.report() if configuration_record.config.report_ini: - reporter = reporting.ReportIni(testcases) - with open(configuration_record.config.report_ini_file, "w") as filep: - reporter.report(filep) + reporter = reporting.ReportIni( testcases ) + with open( configuration_record.config.report_ini_file, "w" ) as filep: + reporter.report( filep ) if configuration_record.config.report_timing: - reporter = reporting.ReportTiming(testcases) + reporter = reporting.ReportTiming( testcases ) if not configuration_record.config.report_timing_overwrite: try: - with open(configuration_record.config.timing_file, "r") as filep: - reporter.getOldTiming(filep) + with open( configuration_record.config.timing_file, "r" ) as filep: + reporter.getOldTiming( filep ) except IOError as e: - logger.debug(e) - with open(configuration_record.config.timing_file, "w") as filep: - reporter.report(filep) + logger.debug( e ) + with open( configuration_record.config.timing_file, "w" ) as filep: + reporter.report( filep ) -def summary(manager, alog, short=False): +def summary( manager, alog, short=False ): """Periodic summary and final summary""" - from geos_ats import (reporting, configuration_record, test_case) + from geos_ats import ( reporting, configuration_record, test_case ) - if len(manager.testlist) == 0: + if len( manager.testlist ) == 0: return - if hasattr(manager.machine, "getNumberOfProcessors"): - totalNumberOfProcessors = getattr(manager.machine, "getNumberOfProcessors", None)() + if hasattr( manager.machine, "getNumberOfProcessors" ): + totalNumberOfProcessors = getattr( manager.machine, "getNumberOfProcessors", None )() else: totalNumberOfProcessors = 1 - reporter = reporting.ReportTextPeriodic(manager.testlist) - reporter.report(geos_atsStartTime, totalNumberOfProcessors) + reporter = reporting.ReportTextPeriodic( manager.testlist ) + reporter.report( geos_atsStartTime, totalNumberOfProcessors ) if configuration_record.config.report_html and configuration_record.config.report_html_periodic: testcases = test_case.TESTS.values() - reporter = reporting.ReportHTML(testcases) - reporter.report(refresh=30) + reporter = reporting.ReportHTML( testcases ) + reporter.report( refresh=30 ) if configuration_record.config.report_text: testcases = test_case.TESTS.values() - reporter = reporting.ReportText(testcases) - with open(configuration_record.config.report_text_file, "w") as filep: - reporter.report(filep) + reporter = reporting.ReportText( testcases ) + with open( configuration_record.config.report_text_file, "w" ) as filep: + reporter.report( filep ) -def append_geos_ats_summary(manager): +def append_geos_ats_summary( manager ): initial_summary = manager.summary - def new_summary(*xargs, **kwargs): - initial_summary(*xargs, **kwargs) - summary(manager, None) + def new_summary( *xargs, **kwargs ): + initial_summary( *xargs, **kwargs ) + summary( manager, None ) manager.summary = new_summary @@ -369,8 +370,8 @@ def main(): # --------------------------------- # Handle command line arguments # --------------------------------- - originalargv = sys.argv[:] - options = command_line_parsers.parse_command_line_arguments(originalargv) + originalargv = sys.argv[ : ] + options = command_line_parsers.parse_command_line_arguments( originalargv ) # Set logging verbosity verbosity_options = { @@ -379,24 +380,24 @@ def main(): 'warning': logging.WARNING, 'error': logging.ERROR } - logger.setLevel(verbosity_options[options.verbose]) + logger.setLevel( verbosity_options[ options.verbose ] ) # Set key environment variables before importing ats from geos_ats import machines search_path = '' if options.machine_dir is not None: - if os.path.isdir(options.machine_dir): + if os.path.isdir( options.machine_dir ): search_path = options.machine_dir else: - logger.error(f'Target machine dir does not exist: {options.machine_dir}') - logger.error('geos_ats will continue searching in the default path') + logger.error( f'Target machine dir does not exist: {options.machine_dir}' ) + logger.error( 'geos_ats will continue searching in the default path' ) if not search_path: - search_path = os.path.dirname(machines.__file__) - os.environ['MACHINE_DIR'] = search_path + search_path = os.path.dirname( machines.__file__ ) + os.environ[ 'MACHINE_DIR' ] = search_path if options.machine: - os.environ["MACHINE_TYPE"] = options.machine + os.environ[ "MACHINE_TYPE" ] = options.machine # --------------------------------- # Setup ATS @@ -404,12 +405,12 @@ def main(): configOverride = {} testcases = [] configFile = '' - check_working_dir(options.workingDir) - create_log_directory(options) + check_working_dir( options.workingDir ) + create_log_directory( options ) # Check the test configuration from geos_ats import configuration_record - configuration_record.initializeConfig(configFile, configOverride, options) + configuration_record.initializeConfig( configFile, configOverride, options ) config = configuration_record.config config.geos_bin_dir = options.geos_bin_dir @@ -417,26 +418,26 @@ def main(): if 'skip_missing' in r: config.restart_skip_missing = True elif 'exclude' in r: - config.restart_exclude_pattern.append(r[-1]) + config.restart_exclude_pattern.append( r[ -1 ] ) # Check the report location if options.logs: - config.report_html_file = os.path.join(options.logs, 'test_results.html') - config.report_text_file = os.path.join(options.logs, 'test_results.txt') - config.report_ini_file = os.path.join(options.logs, 'test_results.ini') + config.report_html_file = os.path.join( options.logs, 'test_results.html' ) + config.report_text_file = os.path.join( options.logs, 'test_results.txt' ) + config.report_ini_file = os.path.join( options.logs, 'test_results.ini' ) - ats_files = check_ats_targets(options, testcases, configOverride, originalargv) - build_ats_arguments(options, ats_files, originalargv, config) + ats_files = check_ats_targets( options, testcases, configOverride, originalargv ) + build_ats_arguments( options, ats_files, originalargv, config ) # Additional setup tasks - check_timing_file(options, config) - handle_salloc_relaunch(options, originalargv, configOverride) + check_timing_file( options, config ) + handle_salloc_relaunch( options, originalargv, configOverride ) # Print config information - logger.debug("*" * 80) + logger.debug( "*" * 80 ) for notation in config.report_notations: - logger.debug(notation) - logger.debug("*" * 80) + logger.debug( notation ) + logger.debug( "*" * 80 ) # --------------------------------- # Initialize ATS @@ -444,37 +445,37 @@ def main(): geos_atsStartTime = time.time() # Note: the sys.argv is read here by default - import ats # type: ignore[import] + import ats # type: ignore[import] ats.manager.init() - logger.debug('Copying options to the geos_ats config record file') - config.copy_values(ats.manager.machine) + logger.debug( 'Copying options to the geos_ats config record file' ) + config.copy_values( ats.manager.machine ) # Glue global values - ats.AtsTest.glue(action=options.action) - ats.AtsTest.glue(checkoption=options.check) - ats.AtsTest.glue(configFile=configFile) - ats.AtsTest.glue(configOverride=configOverride) - ats.AtsTest.glue(testmode=False) - ats.AtsTest.glue(atsFlags=options.ats) - ats.AtsTest.glue(atsFiles=ats_files) - ats.AtsTest.glue(machine=options.machine) - ats.AtsTest.glue(config=config) - if len(testcases): - ats.AtsTest.glue(testcases=testcases) + ats.AtsTest.glue( action=options.action ) + ats.AtsTest.glue( checkoption=options.check ) + ats.AtsTest.glue( configFile=configFile ) + ats.AtsTest.glue( configOverride=configOverride ) + ats.AtsTest.glue( testmode=False ) + ats.AtsTest.glue( atsFlags=options.ats ) + ats.AtsTest.glue( atsFiles=ats_files ) + ats.AtsTest.glue( machine=options.machine ) + ats.AtsTest.glue( config=config ) + if len( testcases ): + ats.AtsTest.glue( testcases=testcases ) else: - ats.AtsTest.glue(testcases="all") + ats.AtsTest.glue( testcases="all" ) - from geos_ats import (common_utilities, suite_settings, test_case, test_steps, user_utilities) + from geos_ats import ( common_utilities, suite_settings, test_case, test_steps, user_utilities ) # Set ats options - append_geos_ats_summary(ats.manager) - append_test_end_step(ats.manager.machine) + append_geos_ats_summary( ats.manager ) + append_test_end_step( ats.manager.machine ) ats.manager.machine.naptime = 0.2 ats.log.echo = True # Logging - if options.action in ("run", "rerun", "check", "continue"): - write_log_dir_summary(options.logs, originalargv) + if options.action in ( "run", "rerun", "check", "continue" ): + write_log_dir_summary( options.logs, originalargv ) if options.action in test_actions: ats.manager.firstBanner() @@ -486,41 +487,41 @@ def main(): # Make sure all the testcases requested were found if testcases != "all": - if len(testcases): - logger.error(f"ERROR: Unknown testcases {str(testcases)}") - logger.error(f"ATS files: {str(ats_files)}") - sys.exit(1) + if len( testcases ): + logger.error( f"ERROR: Unknown testcases {str(testcases)}" ) + logger.error( f"ATS files: {str(ats_files)}" ) + sys.exit( 1 ) # Report: if options.action in report_actions: - report(ats.manager) + report( ats.manager ) # clean if options.action == "veryclean": - common_utilities.removeLogDirectories(os.getcwd()) - files = [config.report_html_file, config.report_ini_file, config.report_text_file] + common_utilities.removeLogDirectories( os.getcwd() ) + files = [ config.report_html_file, config.report_ini_file, config.report_text_file ] for f in files: - if os.path.exists(f): - os.remove(f) + if os.path.exists( f ): + os.remove( f ) # clean the temporary logfile that is not needed for certain actions. if options.action not in test_actions: if options.logs is not None: - if os.path.exists(options.logs): - shutil.rmtree(options.logs) + if os.path.exists( options.logs ): + shutil.rmtree( options.logs ) # return 0 if all tests passed, 1 otherwise try: if options.failIfTestsFail: - with open(os.path.join(options.logs, "test_results.html"), 'r') as f: - contents = ''.join(f.readlines()).split("DETAILED RESULTS")[1] + with open( os.path.join( options.logs, "test_results.html" ), 'r' ) as f: + contents = ''.join( f.readlines() ).split( "DETAILED RESULTS" )[ 1 ] messages = [ "class=\"red\">FAIL", "class=\"yellow\">SKIPPED", "class=\"reddish\">FAIL", "class=\"yellow\">NOT RUN" ] - result = any([m in contents for m in messages]) + result = any( [ m in contents for m in messages ] ) except IOError as e: - logger.debug(e) + logger.debug( e ) # Other ATS steps not previously included: ats.manager.postprocess() @@ -529,13 +530,13 @@ def main(): ats.manager.finalBanner() # Remove unnecessary log dirs created with clean runs - none_dir = os.path.join(options.workingDir, 'None') - if os.path.exists(none_dir): - shutil.rmtree(none_dir) + none_dir = os.path.join( options.workingDir, 'None' ) + if os.path.exists( none_dir ): + shutil.rmtree( none_dir ) return result if __name__ == "__main__": result = main() - sys.exit(result) + sys.exit( result ) diff --git a/geos_ats_package/geos_ats/reporting.py b/geos_ats_package/geos_ats/reporting.py index 60ff3e4..1f9a3bb 100644 --- a/geos_ats_package/geos_ats/reporting.py +++ b/geos_ats_package/geos_ats/reporting.py @@ -5,12 +5,12 @@ import re from geos_ats.configuration_record import config import sys -import ats # type: ignore[import] +import ats # type: ignore[import] from configparser import ConfigParser import logging # Get the active logger instance -logger = logging.getLogger('geos_ats') +logger = logging.getLogger( 'geos_ats' ) # The following are ALEATS test status values. # The order is important for the ReportGroup: lower values take precendence @@ -31,202 +31,203 @@ UNEXPECTEDPASS = 14 # A tuple of test status values. -STATUS = (FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, FAILCHECKMINOR, TIMEOUT, NOTRUN, FILTERED, RUNNING, - INPROGRESS, PASS, EXPECTEDFAIL, SKIP, BATCH, NOTBUILT) +STATUS = ( FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, FAILCHECKMINOR, TIMEOUT, NOTRUN, FILTERED, RUNNING, + INPROGRESS, PASS, EXPECTEDFAIL, SKIP, BATCH, NOTBUILT ) -STATUS_NOTDONE = (NOTRUN, RUNNING, INPROGRESS, BATCH) +STATUS_NOTDONE = ( NOTRUN, RUNNING, INPROGRESS, BATCH ) -class ReportBase(object): +class ReportBase( object ): """Base class for reporting. The constructor takes in a sequence of testcases (of type test_case), and from each testcase, a ReportTestCase object is created.""" - def __init__(self, testcases): + def __init__( self, testcases ): pass -class ReportTiming(ReportBase): +class ReportTiming( ReportBase ): """Reporting class that is used for outputting test timings""" - def __init__(self, testcases): - self.reportcases = [ReportTestCase(t) for t in testcases] + def __init__( self, testcases ): + self.reportcases = [ ReportTestCase( t ) for t in testcases ] self.timings = {} - def getOldTiming(self, fp): + def getOldTiming( self, fp ): for line in fp: - if not line.startswith('#'): + if not line.startswith( '#' ): tokens = line.split() - self.timings[tokens[0]] = int(tokens[1]) + self.timings[ tokens[ 0 ] ] = int( tokens[ 1 ] ) - def report(self, fp): + def report( self, fp ): for testcase in self.reportcases: - if testcase.status in [PASS, TIMEOUT]: - self.timings[testcase.testcase.name] = int(testcase.testcase.status.totalTime()) + if testcase.status in [ PASS, TIMEOUT ]: + self.timings[ testcase.testcase.name ] = int( testcase.testcase.status.totalTime() ) output = "" - for key in sorted(self.timings): - output += "%s %d\n" % (key, self.timings[key]) - fp.writelines(output) + for key in sorted( self.timings ): + output += "%s %d\n" % ( key, self.timings[ key ] ) + fp.writelines( output ) -class ReportIni(ReportBase): +class ReportIni( ReportBase ): """Minimal reporting class that is used for bits status emails""" - def __init__(self, testcases): - self.reportcases = [ReportTestCase(t) for t in testcases] + def __init__( self, testcases ): + self.reportcases = [ ReportTestCase( t ) for t in testcases ] # A dictionary where the key is a status, and the value is a sequence of ReportTestCases self.reportcaseResults = {} for status in STATUS: - self.reportcaseResults[status] = [t for t in self.reportcases if t.status == status] + self.reportcaseResults[ status ] = [ t for t in self.reportcases if t.status == status ] self.displayName = {} - self.displayName[FAILRUN] = "FAILRUN" - self.displayName[FAILRUNOPTIONAL] = "FAILRUNOPTIONAL" - self.displayName[FAILCHECK] = "FAILCHECK" - self.displayName[FAILCHECKMINOR] = "FAILCHECKMINOR" - self.displayName[TIMEOUT] = "TIMEOUT" - self.displayName[NOTRUN] = "NOTRUN" - self.displayName[INPROGRESS] = "INPROGRESS" - self.displayName[FILTERED] = "FILTERED" - self.displayName[RUNNING] = "RUNNING" - self.displayName[PASS] = "PASSED" - self.displayName[SKIP] = "SKIPPED" - self.displayName[BATCH] = "BATCHED" - self.displayName[NOTBUILT] = "NOTBUILT" - self.displayName[EXPECTEDFAIL] = "EXPECTEDFAIL" - self.displayName[UNEXPECTEDPASS] = "UNEXPECTEDPASS" - - def __getTestCaseName(testcase): + self.displayName[ FAILRUN ] = "FAILRUN" + self.displayName[ FAILRUNOPTIONAL ] = "FAILRUNOPTIONAL" + self.displayName[ FAILCHECK ] = "FAILCHECK" + self.displayName[ FAILCHECKMINOR ] = "FAILCHECKMINOR" + self.displayName[ TIMEOUT ] = "TIMEOUT" + self.displayName[ NOTRUN ] = "NOTRUN" + self.displayName[ INPROGRESS ] = "INPROGRESS" + self.displayName[ FILTERED ] = "FILTERED" + self.displayName[ RUNNING ] = "RUNNING" + self.displayName[ PASS ] = "PASSED" + self.displayName[ SKIP ] = "SKIPPED" + self.displayName[ BATCH ] = "BATCHED" + self.displayName[ NOTBUILT ] = "NOTBUILT" + self.displayName[ EXPECTEDFAIL ] = "EXPECTEDFAIL" + self.displayName[ UNEXPECTEDPASS ] = "UNEXPECTEDPASS" + + def __getTestCaseName( testcase ): return testcase.testcase.name - def report(self, fp): + def report( self, fp ): configParser = ConfigParser() - configParser.add_section("Info") - configParser.set("Info", "Time", time.strftime("%a, %d %b %Y %H:%M:%S")) + configParser.add_section( "Info" ) + configParser.set( "Info", "Time", time.strftime( "%a, %d %b %Y %H:%M:%S" ) ) try: platform = socket.gethostname() except: - logger.debug("Could not get host name") + logger.debug( "Could not get host name" ) platform = "unknown" - configParser.set("Info", "Platform", platform) + configParser.set( "Info", "Platform", platform ) extraNotations = "" for line in config.report_notations: - line_split = line.split(":") - if len(line_split) != 2: - line_split = line.split("=") - if len(line_split) != 2: + line_split = line.split( ":" ) + if len( line_split ) != 2: + line_split = line.split( "=" ) + if len( line_split ) != 2: extraNotations += "\"" + line.strip() + "\"" continue - configParser.set("Info", line_split[0].strip(), line_split[1].strip()) + configParser.set( "Info", line_split[ 0 ].strip(), line_split[ 1 ].strip() ) if extraNotations != "": - configParser.set("Info", "Extra Notations", extraNotations) + configParser.set( "Info", "Extra Notations", extraNotations ) - configParser.add_section("Results") - configParser.add_section("Custodians") - configParser.add_section("Documentation") + configParser.add_section( "Results" ) + configParser.add_section( "Custodians" ) + configParser.add_section( "Documentation" ) undocumentedTests = [] for status in STATUS: testNames = [] - for reportcaseResult in self.reportcaseResults[status]: + for reportcaseResult in self.reportcaseResults[ status ]: testName = reportcaseResult.testcase.name - testNames.append(testName) + testNames.append( testName ) - owner = getowner(testName, reportcaseResult.testcase) + owner = getowner( testName, reportcaseResult.testcase ) if owner is not None: - configParser.set("Custodians", testName, owner) + configParser.set( "Custodians", testName, owner ) if config.report_doc_link: - linkToDocumentation = os.path.join(config.report_doc_dir, testName, testName + ".html") - if os.path.exists(linkToDocumentation): - configParser.set("Documentation", testName, linkToDocumentation) + linkToDocumentation = os.path.join( config.report_doc_dir, testName, testName + ".html" ) + if os.path.exists( linkToDocumentation ): + configParser.set( "Documentation", testName, linkToDocumentation ) else: if not reportcaseResult.testcase.nodoc: - undocumentedTests.append(testName) - linkToDocumentation = getowner(testName, reportcaseResult.testcase) - testNames = sorted(testNames) - configParser.set("Results", self.displayName[status], ";".join(testNames)) - undocumentedTests = sorted(undocumentedTests) - configParser.set("Documentation", "undocumented", ";".join(undocumentedTests)) - configParser.write(fp) + undocumentedTests.append( testName ) + linkToDocumentation = getowner( testName, reportcaseResult.testcase ) + testNames = sorted( testNames ) + configParser.set( "Results", self.displayName[ status ], ";".join( testNames ) ) + undocumentedTests = sorted( undocumentedTests ) + configParser.set( "Documentation", "undocumented", ";".join( undocumentedTests ) ) + configParser.write( fp ) -class ReportText(ReportBase): +class ReportText( ReportBase ): - def __init__(self, testcases): + def __init__( self, testcases ): - ReportBase.__init__(self, testcases) + ReportBase.__init__( self, testcases ) - self.reportcases = [ReportTestCase(t) for t in testcases] + self.reportcases = [ ReportTestCase( t ) for t in testcases ] # A dictionary where the key is a status, and the value is a sequence of ReportTestCases self.reportcaseResults = {} for status in STATUS: - self.reportcaseResults[status] = [t for t in self.reportcases if t.status == status] + self.reportcaseResults[ status ] = [ t for t in self.reportcases if t.status == status ] self.displayName = {} - self.displayName[FAILRUN] = "FAIL RUN" - self.displayName[FAILRUNOPTIONAL] = "FAIL RUN (OPTIONAL STEP)" - self.displayName[FAILCHECK] = "FAIL CHECK" - self.displayName[FAILCHECKMINOR] = "FAIL CHECK (MINOR)" - self.displayName[TIMEOUT] = "TIMEOUT" - self.displayName[NOTRUN] = "NOT RUN" - self.displayName[INPROGRESS] = "INPROGRESS" - self.displayName[FILTERED] = "FILTERED" - self.displayName[RUNNING] = "RUNNING" - self.displayName[PASS] = "PASSED" - self.displayName[SKIP] = "SKIPPED" - self.displayName[BATCH] = "BATCHED" - self.displayName[NOTBUILT] = "NOT BUILT" - self.displayName[EXPECTEDFAIL] = "EXPECTEDFAIL" - self.displayName[UNEXPECTEDPASS] = "UNEXPECTEDPASS" - - def report(self, fp): + self.displayName[ FAILRUN ] = "FAIL RUN" + self.displayName[ FAILRUNOPTIONAL ] = "FAIL RUN (OPTIONAL STEP)" + self.displayName[ FAILCHECK ] = "FAIL CHECK" + self.displayName[ FAILCHECKMINOR ] = "FAIL CHECK (MINOR)" + self.displayName[ TIMEOUT ] = "TIMEOUT" + self.displayName[ NOTRUN ] = "NOT RUN" + self.displayName[ INPROGRESS ] = "INPROGRESS" + self.displayName[ FILTERED ] = "FILTERED" + self.displayName[ RUNNING ] = "RUNNING" + self.displayName[ PASS ] = "PASSED" + self.displayName[ SKIP ] = "SKIPPED" + self.displayName[ BATCH ] = "BATCHED" + self.displayName[ NOTBUILT ] = "NOT BUILT" + self.displayName[ EXPECTEDFAIL ] = "EXPECTEDFAIL" + self.displayName[ UNEXPECTEDPASS ] = "UNEXPECTEDPASS" + + def report( self, fp ): """Write out the text report to the give file pointer""" - self.writeSummary(fp, (FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, FAILCHECKMINOR, TIMEOUT, NOTRUN, - INPROGRESS, FILTERED, PASS, EXPECTEDFAIL, SKIP, BATCH, NOTBUILT)) - self.writeLongest(fp, 5) - self.writeDetails(fp, (FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, FAILCHECKMINOR, TIMEOUT, FILTERED)) + self.writeSummary( fp, ( FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, FAILCHECKMINOR, TIMEOUT, NOTRUN, + INPROGRESS, FILTERED, PASS, EXPECTEDFAIL, SKIP, BATCH, NOTBUILT ) ) + self.writeLongest( fp, 5 ) + self.writeDetails( fp, + ( FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, FAILCHECKMINOR, TIMEOUT, FILTERED ) ) - def writeSummary(self, fp, statuses=STATUS): + def writeSummary( self, fp, statuses=STATUS ): """The summary groups each TestCase by its status.""" - fp.write("=" * 80) + fp.write( "=" * 80 ) from geos_ats import common_utilities for status in statuses: - tests = self.reportcaseResults[status] - num = len(tests) - fp.write(f"\n {self.displayName[status]} : {num}") + tests = self.reportcaseResults[ status ] + num = len( tests ) + fp.write( f"\n {self.displayName[status]} : {num}" ) if num > 0: testlist = [] for test in tests: testname = test.testcase.name - retries = getattr(test.testcase.atsGroup, "retries", 0) + retries = getattr( test.testcase.atsGroup, "retries", 0 ) if retries > 0: testname += '[retry:%d]' % retries - testlist.append(testname) - fp.write(f' ( {" ".join( testlist )} ) ') + testlist.append( testname ) + fp.write( f' ( {" ".join( testlist )} ) ' ) - def writeDetails(self, - fp, - statuses=(FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, FAILCHECKMINOR, INPROGRESS), - columns=("Status", "TestCase", "Elapsed", "Resources", "TestStep", "OutFile")): + def writeDetails( self, + fp, + statuses=( FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, FAILCHECKMINOR, INPROGRESS ), + columns=( "Status", "TestCase", "Elapsed", "Resources", "TestStep", "OutFile" ) ): """This function provides more information about each of the test cases""" from geos_ats import common_utilities - table = common_utilities.TextTable(len(columns)) - table.setHeader(*columns) + table = common_utilities.TextTable( len( columns ) ) + table.setHeader( *columns ) table.rowbreakstyle = "-" printTable = False for status in statuses: - tests = self.reportcaseResults[status] + tests = self.reportcaseResults[ status ] - if len(tests) == 0: + if len( tests ) == 0: continue printTable = True @@ -235,43 +236,43 @@ def writeDetails(self, label = "" pathstr = "" if test.laststep: - paths = testcase.resultPaths(test.laststep) + paths = testcase.resultPaths( test.laststep ) label = test.laststep.label() - pathstr = " ".join([os.path.relpath(x) for x in paths]) + pathstr = " ".join( [ os.path.relpath( x ) for x in paths ] ) row = [] for col in columns: if col == "Status": - statusDisplay = self.displayName[test.status] - retries = getattr(testcase.atsGroup, "retries", 0) + statusDisplay = self.displayName[ test.status ] + retries = getattr( testcase.atsGroup, "retries", 0 ) if retries > 0: statusDisplay += "/retry:%d" % retries - row.append(statusDisplay) + row.append( statusDisplay ) elif col == "Directory": - row.append(os.path.relpath(testcase.path)) + row.append( os.path.relpath( testcase.path ) ) elif col == "TestCase": - row.append(testcase.name) + row.append( testcase.name ) elif col == "TestStep": - row.append(label) + row.append( label ) elif col == "OutFile": - row.append(pathstr) + row.append( pathstr ) elif col == "Elapsed": - row.append(ats.times.hms(test.elapsed)) + row.append( ats.times.hms( test.elapsed ) ) elif col == "Resources": - row.append(ats.times.hms(test.resources)) + row.append( ats.times.hms( test.resources ) ) else: - raise RuntimeError(f"Unknown column {col}") + raise RuntimeError( f"Unknown column {col}" ) - table.addRow(*row) + table.addRow( *row ) table.addRowBreak() - fp.write('\n') + fp.write( '\n' ) if printTable: - table.printTable(fp) - fp.write('\n') + table.printTable( fp ) + fp.write( '\n' ) - def writeLongest(self, fp, num=5): + def writeLongest( self, fp, num=5 ): """The longer running tests are reported""" timing = [] @@ -279,39 +280,39 @@ def writeLongest(self, fp, num=5): for test in self.reportcases: elapsed = test.elapsed if elapsed > 0: - timing.append((elapsed, test)) + timing.append( ( elapsed, test ) ) - timing = sorted(timing, reverse=True) + timing = sorted( timing, reverse=True ) - if len(timing) > 0: - fp.write('\n') - fp.write('\n LONGEST RUNNING TESTS:') - for elapsed, test in timing[:num]: - fp.write(f" {ats.times.hms(elapsed)} {test.testcase.name}") + if len( timing ) > 0: + fp.write( '\n' ) + fp.write( '\n LONGEST RUNNING TESTS:' ) + for elapsed, test in timing[ :num ]: + fp.write( f" {ats.times.hms(elapsed)} {test.testcase.name}" ) -class ReportTextPeriodic(ReportText): +class ReportTextPeriodic( ReportText ): """This class is used during the periodic reports. It is initialized with the actual ATS tests from the ATS manager object. The report inherits from ReportText, and extend that behavior with """ - def __init__(self, atstests): + def __init__( self, atstests ): self.atstest = atstests - testcases = list(set([test.geos_atsTestCase for test in atstests])) - ReportText.__init__(self, testcases) - - def report(self, startTime, totalProcessors=None): - self.writeSummary(sys.stdout, - (FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, FAILCHECKMINOR, TIMEOUT, NOTRUN, - INPROGRESS, FILTERED, RUNNING, PASS, EXPECTEDFAIL, SKIP, BATCH, NOTBUILT)) - self.writeUtilization(sys.stdout, startTime, totalProcessors) - self.writeLongest(sys.stdout) - self.writeDetails(sys.stdout, (FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, FAILCHECKMINOR, RUNNING), - ("Status", "TestCase", "Directory", "Elapsed", "Resources", "TestStep")) - - def writeUtilization(self, fp, startTime, totalProcessors=None): + testcases = list( set( [ test.geos_atsTestCase for test in atstests ] ) ) + ReportText.__init__( self, testcases ) + + def report( self, startTime, totalProcessors=None ): + self.writeSummary( sys.stdout, + ( FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, FAILCHECKMINOR, TIMEOUT, NOTRUN, + INPROGRESS, FILTERED, RUNNING, PASS, EXPECTEDFAIL, SKIP, BATCH, NOTBUILT ) ) + self.writeUtilization( sys.stdout, startTime, totalProcessors ) + self.writeLongest( sys.stdout ) + self.writeDetails( sys.stdout, ( FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, FAILCHECKMINOR, RUNNING ), + ( "Status", "TestCase", "Directory", "Elapsed", "Resources", "TestStep" ) ) + + def writeUtilization( self, fp, startTime, totalProcessors=None ): """Machine utilization is reported""" totalResourcesUsed = 0.0 totaltime = time.time() - startTime @@ -321,27 +322,27 @@ def writeUtilization(self, fp, startTime, totalProcessors=None): totalResourcesUsed += resources if totalResourcesUsed > 0: - fp.write('\n') - fp.write(f"\n TOTAL TIME : {ats.times.hms( totaltime )}") - fp.write(f"\n TOTAL PROCESSOR-TIME : {ats.times.hms(totalResourcesUsed )}") + fp.write( '\n' ) + fp.write( f"\n TOTAL TIME : {ats.times.hms( totaltime )}" ) + fp.write( f"\n TOTAL PROCESSOR-TIME : {ats.times.hms(totalResourcesUsed )}" ) if totalProcessors: availableResources = totalProcessors * totaltime utilization = totalResourcesUsed / availableResources * 100.0 - fp.write(f" AVAIL PROCESSOR-TIME : {ats.times.hms(availableResources )}") - fp.write(f" RESOURCE UTILIZATION : {utilization:5.3g}%") + fp.write( f" AVAIL PROCESSOR-TIME : {ats.times.hms(availableResources )}" ) + fp.write( f" RESOURCE UTILIZATION : {utilization:5.3g}%" ) -class ReportHTML(ReportBase): +class ReportHTML( ReportBase ): """HTML Reporting""" # only launch a web browser once. launchedBrowser = False - def __init__(self, testcases): - ReportBase.__init__(self, testcases) + def __init__( self, testcases ): + ReportBase.__init__( self, testcases ) - self.reportcases = [ReportTestCase(t) for t in testcases] + self.reportcases = [ ReportTestCase( t ) for t in testcases ] # A dictionary keyed by Status. The value is a list of ReportGroup self.groupResults = None @@ -353,125 +354,125 @@ def __init__(self, testcases): self.initializeReportGroups() self.color = {} - self.color[FAILRUN] = "red" - self.color[FAILRUNOPTIONAL] = "yellow" - self.color[FAILCHECK] = "reddish" - self.color[FAILCHECKMINOR] = "reddish" - self.color[TIMEOUT] = "reddish" - self.color[NOTRUN] = "yellow" - self.color[INPROGRESS] = "blue" - self.color[FILTERED] = "blueish" - self.color[RUNNING] = "orange" - self.color[PASS] = "green" - self.color[SKIP] = "yellow" - self.color[BATCH] = "yellow" - self.color[NOTBUILT] = "blueish" - self.color[EXPECTEDFAIL] = "green" - self.color[UNEXPECTEDPASS] = "red" + self.color[ FAILRUN ] = "red" + self.color[ FAILRUNOPTIONAL ] = "yellow" + self.color[ FAILCHECK ] = "reddish" + self.color[ FAILCHECKMINOR ] = "reddish" + self.color[ TIMEOUT ] = "reddish" + self.color[ NOTRUN ] = "yellow" + self.color[ INPROGRESS ] = "blue" + self.color[ FILTERED ] = "blueish" + self.color[ RUNNING ] = "orange" + self.color[ PASS ] = "green" + self.color[ SKIP ] = "yellow" + self.color[ BATCH ] = "yellow" + self.color[ NOTBUILT ] = "blueish" + self.color[ EXPECTEDFAIL ] = "green" + self.color[ UNEXPECTEDPASS ] = "red" self.displayName = {} - self.displayName[FAILRUN] = "FAIL RUN" - self.displayName[FAILRUNOPTIONAL] = "FAIL RUN (OPTIONAL STEP)" - self.displayName[FAILCHECK] = "FAIL CHECK" - self.displayName[FAILCHECKMINOR] = "FAIL CHECK (MINOR)" - self.displayName[TIMEOUT] = "TIMEOUT" - self.displayName[NOTRUN] = "NOT RUN" - self.displayName[INPROGRESS] = "INPROGRESS" - self.displayName[FILTERED] = "FILTERED" - self.displayName[RUNNING] = "RUNNING" - self.displayName[PASS] = "PASSED" - self.displayName[SKIP] = "SKIPPED" - self.displayName[BATCH] = "BATCHED" - self.displayName[NOTBUILT] = "NOTBUILT" - self.displayName[EXPECTEDFAIL] = "EXPECTEDFAIL" - self.displayName[UNEXPECTEDPASS] = "UNEXPECTEDPASS" + self.displayName[ FAILRUN ] = "FAIL RUN" + self.displayName[ FAILRUNOPTIONAL ] = "FAIL RUN (OPTIONAL STEP)" + self.displayName[ FAILCHECK ] = "FAIL CHECK" + self.displayName[ FAILCHECKMINOR ] = "FAIL CHECK (MINOR)" + self.displayName[ TIMEOUT ] = "TIMEOUT" + self.displayName[ NOTRUN ] = "NOT RUN" + self.displayName[ INPROGRESS ] = "INPROGRESS" + self.displayName[ FILTERED ] = "FILTERED" + self.displayName[ RUNNING ] = "RUNNING" + self.displayName[ PASS ] = "PASSED" + self.displayName[ SKIP ] = "SKIPPED" + self.displayName[ BATCH ] = "BATCHED" + self.displayName[ NOTBUILT ] = "NOTBUILT" + self.displayName[ EXPECTEDFAIL ] = "EXPECTEDFAIL" + self.displayName[ UNEXPECTEDPASS ] = "UNEXPECTEDPASS" self.html_filename = config.report_html_file - def initializeReportGroups(self): + def initializeReportGroups( self ): testdir = {} # place testcases into groups for reportcase in self.reportcases: dirname = reportcase.testcase.dirname if dirname not in testdir: - testdir[dirname] = [] - testdir[dirname].append(reportcase) + testdir[ dirname ] = [] + testdir[ dirname ].append( reportcase ) - self.groups = [ReportGroup(key, value) for key, value in testdir.items()] + self.groups = [ ReportGroup( key, value ) for key, value in testdir.items() ] # place groups into a dictionary keyed on the group status self.groupResults = {} for status in STATUS: - self.groupResults[status] = [g for g in self.groups if g.status == status] + self.groupResults[ status ] = [ g for g in self.groups if g.status == status ] - def report(self, refresh=0): + def report( self, refresh=0 ): # potentially regenerate the html documentation for the test suite. # # This doesn't seem to work: # self.generateDocumentation() - sp = open(self.html_filename, 'w') + sp = open( self.html_filename, 'w' ) if refresh: - if not any(g.status in (RUNNING, NOTRUN, INPROGRESS) for g in self.groups): + if not any( g.status in ( RUNNING, NOTRUN, INPROGRESS ) for g in self.groups ): refresh = 0 - self.writeHeader(sp, refresh) - self.writeSummary(sp) + self.writeHeader( sp, refresh ) + self.writeSummary( sp ) if config.report_doc_link: - self.writeDoclink(sp) + self.writeDoclink( sp ) # Set the columns to display if config.report_doc_link: - groupColumns = ("Name", "Custodian", "Status") + groupColumns = ( "Name", "Custodian", "Status" ) else: - groupColumns = ("Name", "Status") + groupColumns = ( "Name", "Status" ) - testcaseColumns = ("Status", "Name", "TestStep", "Age", "Elapsed", "Resources", "Output") + testcaseColumns = ( "Status", "Name", "TestStep", "Age", "Elapsed", "Resources", "Output" ) # write the details - self.writeTable(sp, groupColumns, testcaseColumns) - self.writeFooter(sp) + self.writeTable( sp, groupColumns, testcaseColumns ) + self.writeFooter( sp ) sp.close() # launch the browser, if requested. self.browser() - def generateDocumentation(self): + def generateDocumentation( self ): """Generate the HTML documentation using atddoc""" if not config.report_doc_link: return - testdocfile = os.path.join(config.report_doc_dir, "testdoc.html") - if (os.path.exists(testdocfile) and not config.report_doc_remake): + testdocfile = os.path.join( config.report_doc_dir, "testdoc.html" ) + if ( os.path.exists( testdocfile ) and not config.report_doc_remake ): # Check for any atd files newer than the test html documentation newest = 0 - for root, dirs, files in os.walk(config.report_doc_dir): + for root, dirs, files in os.walk( config.report_doc_dir ): for file in files: - if file.endswith(".atd"): - filetime = os.path.getmtime(os.path.join(root, file)) + if file.endswith( ".atd" ): + filetime = os.path.getmtime( os.path.join( root, file ) ) if filetime > newest: newest = filetime - if os.path.getmtime(testdocfile) > newest: - logger.info(f"HTML documentation found in {os.path.relpath(testdocfile)}. Not regenerating.") + if os.path.getmtime( testdocfile ) > newest: + logger.info( f"HTML documentation found in {os.path.relpath(testdocfile)}. Not regenerating." ) return - logger.info("Generating HTML documentation files (running 'atddoc')...") + logger.info( "Generating HTML documentation files (running 'atddoc')..." ) retcode = True try: - geos_atsdir = os.path.realpath(os.path.dirname(__file__)) - atddoc = os.path.join(geos_atsdir, "atddoc.py") + geos_atsdir = os.path.realpath( os.path.dirname( __file__ ) ) + atddoc = os.path.join( geos_atsdir, "atddoc.py" ) #retcode = subprocess.call( atddoc, cwd=config.report_doc_dir, stdout=subprocess.PIPE) - retcode = subprocess.call(atddoc, cwd=config.report_doc_dir) + retcode = subprocess.call( atddoc, cwd=config.report_doc_dir ) except OSError as e: - logger.debug(e) + logger.debug( e ) if retcode: - logger.info(f" Failed to create HTML documentation in {config.report_doc_dir}") + logger.info( f" Failed to create HTML documentation in {config.report_doc_dir}" ) else: - logger.info(f" HTML documentation created in {config.report_doc_dir}") + logger.info( f" HTML documentation created in {config.report_doc_dir}" ) - def writeRowHeader(self, sp, groupColumns, testcaseColumns): + def writeRowHeader( self, sp, groupColumns, testcaseColumns ): header = f""" """ - sp.write(header) + sp.write( header ) - def writeTable(self, sp, groupColumns, testcaseColumns): - colspan = len(groupColumns) + len(testcaseColumns) + def writeTable( self, sp, groupColumns, testcaseColumns ): + colspan = len( groupColumns ) + len( testcaseColumns ) header = f""" @@ -532,12 +533,12 @@ def writeTable(self, sp, groupColumns, testcaseColumns): rowcount = 0 testgroups = [] for status in STATUS: - testgroups.extend(self.groupResults[status]) + testgroups.extend( self.groupResults[ status ] ) for test in testgroups: - rowspan = len(test.testcases) + rowspan = len( test.testcases ) if rowcount <= 0: - self.writeRowHeader(sp, groupColumns, testcaseColumns) + self.writeRowHeader( sp, groupColumns, testcaseColumns ) rowcount += 30 rowcount -= rowspan @@ -557,7 +558,7 @@ def writeTable(self, sp, groupColumns, testcaseColumns): elif col == "Custodian": if config.report_doc_link: - owner = getowner(test.name, test.testcases[0].testcase) + owner = getowner( test.name, test.testcases[ 0 ].testcase ) if owner is not None: header += f'\n ' else: @@ -567,14 +568,14 @@ def writeTable(self, sp, groupColumns, testcaseColumns): elif col == "Status": header += f'' else: - raise RuntimeError(f"Unknown column {col}") + raise RuntimeError( f"Unknown column {col}" ) for testcase in test.testcases: for col in testcaseColumns: if col == "Status": - statusDisplay = self.displayName[testcase.status] - retries = getattr(testcase.testcase.atsGroup, "retries", 0) + statusDisplay = self.displayName[ testcase.status ] + retries = getattr( testcase.testcase.atsGroup, "retries", 0 ) if retries > 0: statusDisplay += "
retry: %d" % retries header += f'\n' @@ -586,33 +587,34 @@ def writeTable(self, sp, groupColumns, testcaseColumns): if config.report_doc_link: docfound = False # first check for the full problem name, with the domain extension - testhtml = os.path.join(config.report_doc_dir, test.name, testcase.testcase.name + ".html") - if os.path.exists(testhtml): + testhtml = os.path.join( config.report_doc_dir, test.name, + testcase.testcase.name + ".html" ) + if os.path.exists( testhtml ): docfound = True else: # next check for the full problem name without the domain extension - testhtml = os.path.join(config.report_doc_dir, test.name, - testcase.testcase.name + ".html") - if os.path.exists(testhtml): + testhtml = os.path.join( config.report_doc_dir, test.name, + testcase.testcase.name + ".html" ) + if os.path.exists( testhtml ): docfound = True else: # final check for any of the input file names for step in testcase.testcase.steps: - if getattr(step.p, "deck", None): - [inputname, suffix] = getattr(step.p, "deck").rsplit('.', 1) - testhtml = os.path.join(config.report_doc_dir, test.name, - inputname + ".html") - if os.path.exists(testhtml): + if getattr( step.p, "deck", None ): + [ inputname, suffix ] = getattr( step.p, "deck" ).rsplit( '.', 1 ) + testhtml = os.path.join( config.report_doc_dir, test.name, + inputname + ".html" ) + if os.path.exists( testhtml ): # match with the first input file docfound = True break if docfound: - testref = 'href="%s"' % (testhtml) + testref = 'href="%s"' % ( testhtml ) else: if not testcase.testcase.nodoc: testlinksuffix += '
undocumented' - undocumented.append(testcase.testcase.name) + undocumented.append( testcase.testcase.name ) header += f"\n" @@ -629,10 +631,10 @@ def writeTable(self, sp, groupColumns, testcaseColumns): if testcase.diffage: difftime = testcase.diffage - days = int(difftime) / 86400 + days = int( difftime ) / 86400 if days > 0: difftime -= days * 86400 - hours = int(difftime) / 3600 + hours = int( difftime ) / 3600 if days == 0: # "New" diff file - don't color header += f'\n' @@ -661,17 +663,17 @@ def writeTable(self, sp, groupColumns, testcaseColumns): header += "\n" else: - raise RuntimeError(f"Unknown column {col}") + raise RuntimeError( f"Unknown column {col}" ) header += '\n' @@ -680,17 +682,17 @@ def writeTable(self, sp, groupColumns, testcaseColumns): if config.report_doc_link: header += '\n

Undocumented test problems:

' header += '\n\n" - sp.write(header) + sp.write( header ) - def writeHeader(self, sp, refresh): - gentime = time.strftime("%a, %d %b %Y %H:%M:%S") + def writeHeader( self, sp, refresh ): + gentime = time.strftime( "%a, %d %b %Y %H:%M:%S" ) header = """ @@ -763,13 +765,13 @@ def writeHeader(self, sp, refresh): try: platform = socket.gethostname() except: - logger.debug("Could not get host name") + logger.debug( "Could not get host name" ) platform = "unknown" if os.name == "nt": - username = os.getenv("USERNAME") + username = os.getenv( "USERNAME" ) else: - username = os.getenv("USER") + username = os.getenv( "USER" ) header += f"""

@@ -789,9 +791,9 @@ def writeHeader(self, sp, refresh):

""" - sp.write(header) + sp.write( header ) - def writeSummary(self, sp): + def writeSummary( self, sp ): summary = """
  {owner} {self.displayName[test.status]}{statusDisplay}{testcase.testcase.name}{testlinksuffix}{hours}h" seen = {} - for stepnum, step in enumerate(testcase.testcase.steps): - paths = testcase.testcase.resultPaths(step) + for stepnum, step in enumerate( testcase.testcase.steps ): + paths = testcase.testcase.resultPaths( step ) for p in paths: # if p has already been accounted for, doesn't exist, or is an empty file, don't print it. - if (((p in seen) or not os.path.exists(p)) or (os.stat(p)[6] == 0)): + if ( ( ( p in seen ) or not os.path.exists( p ) ) or ( os.stat( p )[ 6 ] == 0 ) ): continue header += f"\n{os.path.basename(p)}
" - seen[p] = 1 + seen[ p ] = 1 header += "\n
@@ -804,8 +806,8 @@ def writeSummary(self, sp): haveRetry = False for status in STATUS: - cases = self.groupResults[status] - num = len(cases) + cases = self.groupResults[ status ] + num = len( cases ) summary += f""" @@ -820,7 +822,7 @@ def writeSummary(self, sp): caseref = case.name retries = 0 for test in case.testcases: - retries += getattr(test.testcase.atsGroup, "retries", 0) + retries += getattr( test.testcase.atsGroup, "retries", 0 ) if retries > 0: haveRetry = True casename += '*' @@ -836,19 +838,19 @@ def writeSummary(self, sp): if haveRetry: summary += '\n* indicates that test was retried at least once.' - sp.write(summary) + sp.write( summary ) # Write link to documentation for html - def writeDoclink(self, sp): + def writeDoclink( self, sp ): doc = """

Test problem names with a hyperlink have been documented, the HTML version of which can be viewed by clicking on the link. """ - testdoc = os.path.join(config.report_doc_dir, 'testdoc.html') - testsumm = os.path.join(config.report_doc_dir, 'testdoc-summary.txt') - if os.path.exists(testdoc) and os.path.exists(testsumm): + testdoc = os.path.join( config.report_doc_dir, 'testdoc.html' ) + testsumm = os.path.join( config.report_doc_dir, 'testdoc-summary.txt' ) + if os.path.exists( testdoc ) and os.path.exists( testsumm ): doc += f"""
Or, you can click here for the @@ -858,16 +860,16 @@ def writeDoclink(self, sp): """ doc += '\n

' - sp.write(doc) + sp.write( doc ) - def writeFooter(self, sp): + def writeFooter( self, sp ): footer = """ """ - sp.write(footer) + sp.write( footer ) - def browser(self): + def browser( self ): if ReportHTML.launchedBrowser: return @@ -876,43 +878,43 @@ def browser(self): ReportHTML.launchedBrowser = True command = config.browser_command.split() - command.append("file:%s" % config.report_html_file) - subprocess.Popen(command) + command.append( "file:%s" % config.report_html_file ) + subprocess.Popen( command ) -class ReportWait(ReportBase): +class ReportWait( ReportBase ): """This class is used while with the report_wait config option""" - def __init__(self, testcases): - ReportBase.__init__(self, testcases) + def __init__( self, testcases ): + ReportBase.__init__( self, testcases ) self.testcases = testcases - def report(self, fp): + def report( self, fp ): """Write out the text report to the give file pointer""" import time start = time.time() - sleeptime = 60 # interval to check (seconds) + sleeptime = 60 # interval to check (seconds) while True: notdone = [] for t in self.testcases: t.testReport() - report = ReportTestCase(t) + report = ReportTestCase( t ) if report.status in STATUS_NOTDONE: - notdone.append(t) + notdone.append( t ) if notdone: - rr = ReportText(self.testcases) - rr.writeSummary(sys.stdout, - (FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, FAILCHECKMINOR, TIMEOUT, NOTRUN, - INPROGRESS, FILTERED, PASS, EXPECTEDFAIL, SKIP, BATCH, NOTBUILT)) - time.sleep(sleeptime) + rr = ReportText( self.testcases ) + rr.writeSummary( sys.stdout, + ( FAILRUN, UNEXPECTEDPASS, FAILRUNOPTIONAL, FAILCHECK, FAILCHECKMINOR, TIMEOUT, NOTRUN, + INPROGRESS, FILTERED, PASS, EXPECTEDFAIL, SKIP, BATCH, NOTBUILT ) ) + time.sleep( sleeptime ) else: break -class ReportTestCase(object): +class ReportTestCase( object ): """This class represents the outcome from a TestCase. It hides differences between off-line reports and the periodic reports (when the actual ATS test object is known). In addition to @@ -920,10 +922,10 @@ class ReportTestCase(object): that was run, age of the test, the total elapsed time and total resources used.""" - def __init__(self, testcase): + def __init__( self, testcase ): - self.testcase = testcase # test_case - self.status = None # One of the STATUS values (e.g. FAILRUN, PASS, etc.) + self.testcase = testcase # test_case + self.status = None # One of the STATUS values (e.g. FAILRUN, PASS, etc.) self.laststep = None self.diffage = None self.elapsed = 0.0 @@ -937,14 +939,14 @@ def __init__(self, testcase): if teststatus is None: self.status = NOTRUN return - elif teststatus in (FILTERED, SKIP): + elif teststatus in ( FILTERED, SKIP ): self.status = teststatus return else: - for stepnum, step in enumerate(testcase.steps): + for stepnum, step in enumerate( testcase.steps ): # Get the outcome and related information from the TestStep. - outcome, np, startTime, endTime = self._getStepInfo(step) + outcome, np, startTime, endTime = self._getStepInfo( step ) if outcome == "PASS": # So far so good, move on to the next step @@ -958,14 +960,14 @@ def __init__(self, testcase): self.resources += np * dt outcome = "EXPECTEDFAIL" self.status = EXPECTEDFAIL - break # don't continue past an expected failure + break # don't continue past an expected failure if outcome == "UNEX": dt = endTime - startTime self.elapsed += dt self.resources += np * dt outcome = "UNEXPECTEDPASS" self.status = UNEXPECTEDPASS - break # don't continue past an unexpected pass + break # don't continue past an unexpected pass elif outcome == "SKIP": self.status = SKIP break @@ -1021,9 +1023,9 @@ def __init__(self, testcase): else: self.status = FAILRUN try: - with open(step.p.stdout, 'r') as fp: + with open( step.p.stdout, 'r' ) as fp: for line in fp: - if re.search(config.report_notbuilt_regexp, line): + if re.search( config.report_notbuilt_regexp, line ): self.status = NOTBUILT break except: @@ -1037,76 +1039,76 @@ def __init__(self, testcase): # Don't set the laststep, but use it to get the endTime self.status = PASS laststep = step - laststatus = teststatus.findStep(laststep) - assert (laststatus) - self.diffage = now - laststatus["endTime"] + laststatus = teststatus.findStep( laststep ) + assert ( laststatus ) + self.diffage = now - laststatus[ "endTime" ] assert self.status in STATUS - def _getStepInfo(self, teststep): + def _getStepInfo( self, teststep ): """This function hides the differences between the TestStatus files and the information you can get from the ats test object. It returns (status, np, startTime, endTime )""" - atsTest = getattr(teststep, "atsTest", None) + atsTest = getattr( teststep, "atsTest", None ) endTime = None startTime = None if atsTest is not None: - status = str(atsTest.status) - startTime = getattr(atsTest, "startTime", None) - endTime = getattr(atsTest, "endTime", None) + status = str( atsTest.status ) + startTime = getattr( atsTest, "startTime", None ) + endTime = getattr( atsTest, "endTime", None ) if status == "PASS" and atsTest.expectedResult == ats.FAILED: status = "FAIL" if status == "FAIL" and atsTest.expectedResult == ats.FAILED: status = "UNEX" else: - stepstatus = self.testcase.status.findStep(teststep) + stepstatus = self.testcase.status.findStep( teststep ) if stepstatus is None: status = "INIT" else: - status = stepstatus["result"] - startTime = stepstatus["startTime"] - endTime = stepstatus["endTime"] + status = stepstatus[ "result" ] + startTime = stepstatus[ "startTime" ] + endTime = stepstatus[ "endTime" ] - np = getattr(teststep.p, "np", 1) + np = getattr( teststep.p, "np", 1 ) - if status in ("SKIP", "FILT", "INIT", "PASS", "FAIL", "TIME", "EXEC", "BACH", "EXPT", "UNEX"): - return (status, np, startTime, endTime) + if status in ( "SKIP", "FILT", "INIT", "PASS", "FAIL", "TIME", "EXEC", "BACH", "EXPT", "UNEX" ): + return ( status, np, startTime, endTime ) else: - return ("SKIP", np, startTime, endTime) + return ( "SKIP", np, startTime, endTime ) -class ReportGroup(object): +class ReportGroup( object ): """A class to represent a group of TestCases. Currently, the only grouping done is at the directory level: every testcase in a directory belongs to the same ReportGroup.""" - def __init__(self, groupName, testcases): + def __init__( self, groupName, testcases ): self.name = groupName self.testcases = testcases self.status = NOTRUN if self.testcases: - self.status = min([case.status for case in self.testcases]) + self.status = min( [ case.status for case in self.testcases ] ) assert self.status in STATUS - def __cmp__(self, other): + def __cmp__( self, other ): return self.name == other.name -def getowner(dirname, testcase=None): +def getowner( dirname, testcase=None ): owner = "" if not config.report_doc_link: try: - atdfile = os.path.join(config.report_doc_dir, dirname, dirname + ".atd") - with open(atdfile, "r") as fp: + atdfile = os.path.join( config.report_doc_dir, dirname, dirname + ".atd" ) + with open( atdfile, "r" ) as fp: for line in fp: - match = re.search("CUSTODIAN:: +(.*)$", line) + match = re.search( "CUSTODIAN:: +(.*)$", line ) if not match: - owner = match.group(1) + owner = match.group( 1 ) break except IOError as e: - logger.debug(e) - if owner == "" and testcase and ("owner" in testcase.dictionary): - return testcase.dictionary["owner"] + logger.debug( e ) + if owner == "" and testcase and ( "owner" in testcase.dictionary ): + return testcase.dictionary[ "owner" ] return owner diff --git a/geos_ats_package/geos_ats/rules.py b/geos_ats_package/geos_ats/rules.py index 542e547..741b5e2 100644 --- a/geos_ats_package/geos_ats/rules.py +++ b/geos_ats_package/geos_ats/rules.py @@ -8,33 +8,33 @@ import shutil import logging -logger = logging.getLogger('geos_ats') +logger = logging.getLogger( 'geos_ats' ) -def switch(booleans, i): - booleans[i] = not booleans[i] +def switch( booleans, i ): + booleans[ i ] = not booleans[ i ] -def DeclareCompoundRuleClass(name, RuleA, RuleB): +def DeclareCompoundRuleClass( name, RuleA, RuleB ): """ Declares a class of name name that is a new rule that is the combination of 2 base rules. """ - tmp = type(name, (RuleA, RuleB), {}) + tmp = type( name, ( RuleA, RuleB ), {} ) tmp.numToggles = RuleA.numToggles + RuleB.numToggles tmp.numCombinations = RuleA.numCombinations * RuleB.numCombinations # Define the initializer for the new class - def newInit(self, toggles): - RuleA.__init__(self, toggles, 0, RuleA.numToggles) - RuleB.__init__(self, toggles, RuleA.numToggles) + def newInit( self, toggles ): + RuleA.__init__( self, toggles, 0, RuleA.numToggles ) + RuleB.__init__( self, toggles, RuleA.numToggles ) tmp.__init__ = newInit - globals()[name] = tmp + globals()[ name ] = tmp return tmp -def GenRules(RuleType): +def GenRules( RuleType ): """ Generator that produces a rule for each possible combination of toggles""" nt = RuleType.numToggles @@ -52,127 +52,127 @@ def GenRules(RuleType): Note that the resulting rule can be uniquely ID'd by the sum of the toggle array. """ - for i in range(nc): - toggles = [i & pow(2, x) for x in range(nt)] - tmp = RuleType(toggles) + for i in range( nc ): + toggles = [ i & pow( 2, x ) for x in range( nt ) ] + tmp = RuleType( toggles ) tmp.refresh() yield tmp -class Rule(object): +class Rule( object ): """ Base class for the rules""" - def __init__(self, nToggles, nCombinations, toggles): + def __init__( self, nToggles, nCombinations, toggles ): self.numToggles = nToggles self.numCombinations = nCombinations self.toggles = toggles self.repStrings = {} """ Assumes toggles is set in a way consistent with what is done in GenRules""" - self.id = sum(self.toggles) - self.repStrings["@@POS@@"] = str(self.id) + self.id = sum( self.toggles ) + self.repStrings[ "@@POS@@" ] = str( self.id ) - def GetPosition(self): + def GetPosition( self ): return self.id * 1.0 - def refresh(self): + def refresh( self ): pass - def replaceString(self, string): + def replaceString( self, string ): tmp = string for s in self.repStrings: - tmp = tmp.replace(s, self.repStrings[s]) + tmp = tmp.replace( s, self.repStrings[ s ] ) return tmp - def sedFile(self, fIn, fOut): - inFile = open(fIn) - outFile = open(fOut, 'w') + def sedFile( self, fIn, fOut ): + inFile = open( fIn ) + outFile = open( fOut, 'w' ) for line in inFile: - outFile.write(self.replaceString(line)) + outFile.write( self.replaceString( line ) ) inFile.close() outFile.close() - def checkTimehist(self): + def checkTimehist( self ): # timehist - logger.error('checkTimehist method not defined') + logger.error( 'checkTimehist method not defined' ) -class SetupRules(Rule): +class SetupRules( Rule ): numToggles = 2 - numCombinations = pow(2, numToggles) + numCombinations = pow( 2, numToggles ) - def __init__(self, toggles, minToggle=0, maxToggle=None): + def __init__( self, toggles, minToggle=0, maxToggle=None ): self.setupMin = minToggle self.setupMax = maxToggle - Rule.__init__(self, SetupRules.numToggles, SetupRules.numCombinations, toggles) + Rule.__init__( self, SetupRules.numToggles, SetupRules.numCombinations, toggles ) - def refresh(self): - mtoggles = self.toggles[self.setupMin:self.setupMax] + def refresh( self ): + mtoggles = self.toggles[ self.setupMin:self.setupMax ] - underscoredName = mtoggles[0] - self.isTenthCycle = mtoggles[1] + underscoredName = mtoggles[ 0 ] + self.isTenthCycle = mtoggles[ 1 ] self.baseName = "foo%i" % self.id - self.baseName = "%s%s" % (self.baseName, "_001" if underscoredName else "") - self.repStrings["@@BASE@@"] = self.baseName + self.baseName = "%s%s" % ( self.baseName, "_001" if underscoredName else "" ) + self.repStrings[ "@@BASE@@" ] = self.baseName self.inputDeck = "%s.in" % self.baseName - self.repStrings["@@DECK@@"] = self.inputDeck + self.repStrings[ "@@DECK@@" ] = self.inputDeck self.restartBaseName = "%s_001" % self.baseName - self.restartName = "%s_%s" % (self.restartBaseName, "00010" if self.isTenthCycle else "00000") - self.repStrings["@@RF@@"] = self.restartName + self.restartName = "%s_%s" % ( self.restartBaseName, "00010" if self.isTenthCycle else "00000" ) + self.repStrings[ "@@RF@@" ] = self.restartName - super(SetupRules, self).refresh() + super( SetupRules, self ).refresh() - def GetInputDeckName(self): + def GetInputDeckName( self ): return self.inputDeck - def GetInitialRestartName(self): + def GetInitialRestartName( self ): return self.restartName - def GetBaseName(self): + def GetBaseName( self ): return self.baseName -class CommandLineRules(Rule): +class CommandLineRules( Rule ): numToggles = 2 - numCombinations = pow(2, numToggles) + numCombinations = pow( 2, numToggles ) - def __init__(self, toggles, minToggle=0, maxToggle=None): + def __init__( self, toggles, minToggle=0, maxToggle=None ): self.clMin = minToggle self.clMax = maxToggle - Rule.__init__(self, CommandLineRules.numToggles, CommandLineRules.numCombinations, toggles) + Rule.__init__( self, CommandLineRules.numToggles, CommandLineRules.numCombinations, toggles ) - def refresh(self): - mtoggles = self.toggles[self.clMin:self.clMax] - self.probDefined = mtoggles[0] # use the -prob flag - self.restartDefined = mtoggles[1] # use the -rf flag + def refresh( self ): + mtoggles = self.toggles[ self.clMin:self.clMax ] + self.probDefined = mtoggles[ 0 ] # use the -prob flag + self.restartDefined = mtoggles[ 1 ] # use the -rf flag # self.prob = "-prob %s" % "@@BASE@@" if self.probDefined else "" # self.rf = "-rf %s" % "@@RF@@" if self.restartDefined else "" self.prob = "@@BASE@@" if self.probDefined else "" self.rf = "@@RF@@" if self.restartDefined else "" - self.repStrings["@@CL_PROB@@"] = self.prob - self.repStrings["@@CL_RF@@"] = self.rf + self.repStrings[ "@@CL_PROB@@" ] = self.prob + self.repStrings[ "@@CL_RF@@" ] = self.rf - super(CommandLineRules, self).refresh() + super( CommandLineRules, self ).refresh() def main(): - generator = GenRules(SetupRules) + generator = GenRules( SetupRules ) for rule in generator: - vals = (rule.GetInputDeckName(), rule.GetInitialRestartName(), rule.GetPosition()) - logger.debug(rule.replaceString("InputDeck: %s\tRestartFile: %s\tPos: %f" % vals)) + vals = ( rule.GetInputDeckName(), rule.GetInitialRestartName(), rule.GetPosition() ) + logger.debug( rule.replaceString( "InputDeck: %s\tRestartFile: %s\tPos: %f" % vals ) ) - DeclareCompoundRuleClass("SetupCommand", SetupRules, CommandLineRules) - logger.debug(SetupCommand.numCombinations) - generator = GenRules(SetupCommand) - logger.debug("compound:") + DeclareCompoundRuleClass( "SetupCommand", SetupRules, CommandLineRules ) + logger.debug( SetupCommand.numCombinations ) + generator = GenRules( SetupCommand ) + logger.debug( "compound:" ) for rule in generator: - vals = (rule.GetInputDeckName(), rule.GetInitialRestartName(), rule.GetPosition(), rule.prob, rule.rf) - logger.debug(rule.replaceString("InputDeck: %s\tRestartFile: %s\tPos: %f\t%s\t%s" % vals)) + vals = ( rule.GetInputDeckName(), rule.GetInitialRestartName(), rule.GetPosition(), rule.prob, rule.rf ) + logger.debug( rule.replaceString( "InputDeck: %s\tRestartFile: %s\tPos: %f\t%s\t%s" % vals ) ) return @@ -181,44 +181,44 @@ def main(): # argument to check results of pdldiff script # parser.add_option("-p", "--pdldiff", type = "string", dest = "pdldiff" ) - (options, args) = parser.parse_args() + ( options, args ) = parser.parse_args() # assert options.gnuplot - assert len(args) == 4 + assert len( args ) == 4 - base = args[0] - sourceDeck = args[1] - atsFile = args[2] - outdir = args[3] - assert os.path.exists(sourceDeck) - assert os.path.exists(atsFile) + base = args[ 0 ] + sourceDeck = args[ 1 ] + atsFile = args[ 2 ] + outdir = args[ 3 ] + assert os.path.exists( sourceDeck ) + assert os.path.exists( atsFile ) - if os.path.exists(outdir): + if os.path.exists( outdir ): try: - shutil.rmtree(outdir) + shutil.rmtree( outdir ) except: - logger.debug(f"Could not remove directory: {outdir}") + logger.debug( f"Could not remove directory: {outdir}" ) # make a directory try: - os.mkdir(outdir) + os.mkdir( outdir ) # copy in the input deck and other necessary files for running the problem - shutil.copy(sourceDeck, os.path.join(outdir, "%s.ain" % base)) - shutil.copy("leos1.05.h5", outdir) + shutil.copy( sourceDeck, os.path.join( outdir, "%s.ain" % base ) ) + shutil.copy( "leos1.05.h5", outdir ) except: - logger.debug(f"Could not create directory: {outdir}") + logger.debug( f"Could not create directory: {outdir}" ) # copy in the ats file template, replacing appropriate text as we go - outp = open(os.path.join(outdir, "%s.ats" % base), 'w') - inp = open(atsFile, 'r') + outp = open( os.path.join( outdir, "%s.ats" % base ), 'w' ) + inp = open( atsFile, 'r' ) for line in inp: - line = line.replace("BASE", base) - outp.write(line) + line = line.replace( "BASE", base ) + outp.write( line ) # sub = subprocess.call(['sed', 's/BASE/%s/'%base,atsFile],stdout=outp) inp.close() outp.close() - sys.exit(0) + sys.exit( 0 ) if __name__ == "__main__": diff --git a/geos_ats_package/geos_ats/scheduler.py b/geos_ats_package/geos_ats/scheduler.py index 108deab..51faa2f 100644 --- a/geos_ats_package/geos_ats/scheduler.py +++ b/geos_ats_package/geos_ats/scheduler.py @@ -4,45 +4,47 @@ import time from geos_ats.configuration_record import config from geos_ats.common_utilities import Log -from ats.log import log # type: ignore[import] -from ats.atsut import PASSED, FAILED, CREATED, EXPECTED, TIMEDOUT # type: ignore[import] -from ats.schedulers import StandardScheduler # type: ignore[import] +from ats.log import log # type: ignore[import] +from ats.atsut import PASSED, FAILED, CREATED, EXPECTED, TIMEDOUT # type: ignore[import] +from ats.schedulers import StandardScheduler # type: ignore[import] -class GeosAtsScheduler(StandardScheduler): +class GeosAtsScheduler( StandardScheduler ): """Custom scheduler for GeosATS""" name = "GeosATS Scheduler" - def testEnded(self, test): + def testEnded( self, test ): """Manage scheduling and reporting tasks for a test that ended. Log result for every test but only show certain ones on the terminal. Prune group list if a group is finished. """ - echo = self.verbose or (test.status not in (PASSED, EXPECTED)) + echo = self.verbose or ( test.status not in ( PASSED, EXPECTED ) ) g = test.group - n = len(g) + n = len( g ) msg = f"{test.status} #{test.serialNumber} {test.name} {test.message}" if n > 1: msg += f" Group {g.number} #{test.groupSerialNumber} of {n}" - log(msg, echo=echo) + log( msg, echo=echo ) - self.schedule(msg, time.asctime()) - self.removeBlock(test) + self.schedule( msg, time.asctime() ) + self.removeBlock( test ) if g.isFinished(): g.recordOutput() - if not hasattr(g, "retries"): + if not hasattr( g, "retries" ): g.retries = 0 - if test.status in [FAILED, TIMEDOUT] and g.retries < config.max_retry: - with open(test.geos_atsTestCase.errname) as f: + if test.status in [ FAILED, TIMEDOUT ] and g.retries < config.max_retry: + with open( test.geos_atsTestCase.errname ) as f: erroutput = f.read() - if re.search(config.retry_err_regexp, erroutput): + if re.search( config.retry_err_regexp, erroutput ): f.close() - os.rename(test.geos_atsTestCase.errname, "%s.%d" % (test.geos_atsTestCase.errname, g.retries)) - os.rename(test.geos_atsTestCase.outname, "%s.%d" % (test.geos_atsTestCase.outname, g.retries)) + os.rename( test.geos_atsTestCase.errname, + "%s.%d" % ( test.geos_atsTestCase.errname, g.retries ) ) + os.rename( test.geos_atsTestCase.outname, + "%s.%d" % ( test.geos_atsTestCase.outname, g.retries ) ) g.retries += 1 for t in g: t.status = CREATED - Log(f"# retry test={test.geos_atsTestCase.name} ({g.retries}/{config.max_retry})") + Log( f"# retry test={test.geos_atsTestCase.name} ({g.retries}/{config.max_retry})" ) return - self.groups.remove(g) + self.groups.remove( g ) diff --git a/geos_ats_package/geos_ats/suite_settings.py b/geos_ats_package/geos_ats/suite_settings.py index 824079f..c9a2e5e 100644 --- a/geos_ats_package/geos_ats/suite_settings.py +++ b/geos_ats_package/geos_ats/suite_settings.py @@ -3,39 +3,39 @@ testLabels = [ "geos", - "auto", # label used when the tests were automatically converted. Will be deprecated. + "auto", # label used when the tests were automatically converted. Will be deprecated. ] -testOwners = [("corbett5", "Ben Corbett")] -logger = logging.getLogger('geos_ats') +testOwners = [ ( "corbett5", "Ben Corbett" ) ] +logger = logging.getLogger( 'geos_ats' ) -def infoOwners(filename): - topic = common_utilities.InfoTopic("owners") +def infoOwners( filename ): + topic = common_utilities.InfoTopic( "owners" ) topic.startBanner() - owners = sorted(testOwners) + owners = sorted( testOwners ) - table = common_utilities.TextTable(2) + table = common_utilities.TextTable( 2 ) for o in owners: - table.addRow(o[0], o[1]) + table.addRow( o[ 0 ], o[ 1 ] ) table.printTable() - logger.info(f"The list can be found in: {filename}") + logger.info( f"The list can be found in: {filename}" ) topic.endBanner() -def infoLabels(filename): +def infoLabels( filename ): - topic = common_utilities.InfoTopic("labels") + topic = common_utilities.InfoTopic( "labels" ) topic.startBanner() - labels = sorted(testLabels[:]) + labels = sorted( testLabels[ : ] ) - logger.info("Test labels:") + logger.info( "Test labels:" ) for o in labels: - logger.info(f" {o}") + logger.info( f" {o}" ) - logger.info(f"The list can be found in: {filename}") + logger.info( f"The list can be found in: {filename}" ) topic.endBanner() diff --git a/geos_ats_package/geos_ats/test_builder.py b/geos_ats_package/geos_ats/test_builder.py index fe2fb45..83f9ca7 100644 --- a/geos_ats_package/geos_ats/test_builder.py +++ b/geos_ats_package/geos_ats/test_builder.py @@ -11,39 +11,39 @@ from .test_case import TestCase -@dataclass(frozen=True) +@dataclass( frozen=True ) class RestartcheckParameters: atol: float rtol: float - def as_dict(self): - return asdict(self) + def as_dict( self ): + return asdict( self ) -@dataclass(frozen=True) +@dataclass( frozen=True ) class CurveCheckParameters: filename: str - tolerance: Iterable[float] - curves: List[List[str]] - script_instructions: Iterable[Iterable[str]] = None + tolerance: Iterable[ float ] + curves: List[ List[ str ] ] + script_instructions: Iterable[ Iterable[ str ] ] = None time_units: str = "seconds" - def as_dict(self): - return asdict(self) + def as_dict( self ): + return asdict( self ) -@dataclass(frozen=True) +@dataclass( frozen=True ) class TestDeck: name: str description: str - partitions: Iterable[Tuple[int, int, int]] + partitions: Iterable[ Tuple[ int, int, int ] ] restart_step: int check_step: int restartcheck_params: RestartcheckParameters = None curvecheck_params: CurveCheckParameters = None -def collect_block_names(fname): +def collect_block_names( fname ): """ Collect block names in an xml file @@ -54,35 +54,35 @@ def collect_block_names(fname): dict: Pairs of top-level block names and lists of child block names """ pwd = os.getcwd() - actual_dir, actual_fname = os.path.split(os.path.realpath(fname)) - os.chdir(actual_dir) + actual_dir, actual_fname = os.path.split( os.path.realpath( fname ) ) + os.chdir( actual_dir ) # Collect the block names in this file results = {} - parser = etree.XMLParser(remove_comments=True) - tree = etree.parse(actual_fname, parser=parser) + parser = etree.XMLParser( remove_comments=True ) + tree = etree.parse( actual_fname, parser=parser ) root = tree.getroot() for child in root.getchildren(): - results[child.tag] = [grandchild.tag for grandchild in child.getchildren()] + results[ child.tag ] = [ grandchild.tag for grandchild in child.getchildren() ] # Collect block names in included files - for included_root in root.findall('Included'): - for included_file in included_root.findall('File'): - f = included_file.get('name') - child_results = collect_block_names(f) + for included_root in root.findall( 'Included' ): + for included_file in included_root.findall( 'File' ): + f = included_file.get( 'name' ) + child_results = collect_block_names( f ) for k, v in child_results.items(): if k in results: - results[k].extend(v) + results[ k ].extend( v ) else: - results[k] = v - os.chdir(pwd) + results[ k ] = v + os.chdir( pwd ) return results -def generate_geos_tests(decks: Iterable[TestDeck]): +def generate_geos_tests( decks: Iterable[ TestDeck ] ): """ """ - for ii, deck in enumerate(decks): + for ii, deck in enumerate( decks ): restartcheck_params = None curvecheck_params = None @@ -97,51 +97,51 @@ def generate_geos_tests(decks: Iterable[TestDeck]): nx, ny, nz = partition N = nx * ny * nz - testcase_name = "{}_{:02d}".format(deck.name, N) - base_name = "0to{:d}".format(deck.check_step) - xml_file = "{}.xml".format(deck.name) - xml_blocks = collect_block_names(xml_file) + testcase_name = "{}_{:02d}".format( deck.name, N ) + base_name = "0to{:d}".format( deck.check_step ) + xml_file = "{}.xml".format( deck.name ) + xml_blocks = collect_block_names( xml_file ) checks = [] if curvecheck_params: - checks.append('curve') + checks.append( 'curve' ) steps = [ - geos(deck=xml_file, - name=base_name, - np=N, - ngpu=N, - x_partitions=nx, - y_partitions=ny, - z_partitions=nz, - restartcheck_params=restartcheck_params, - curvecheck_params=curvecheck_params) + geos( deck=xml_file, + name=base_name, + np=N, + ngpu=N, + x_partitions=nx, + y_partitions=ny, + z_partitions=nz, + restartcheck_params=restartcheck_params, + curvecheck_params=curvecheck_params ) ] if deck.restart_step > 0: - checks.append('restart') + checks.append( 'restart' ) steps.append( - geos(deck=xml_file, - name="{:d}to{:d}".format(deck.restart_step, deck.check_step), - np=N, - ngpu=N, - x_partitions=nx, - y_partitions=ny, - z_partitions=nz, - restart_file=os.path.join(testcase_name, - "{}_restart_{:09d}".format(base_name, deck.restart_step)), - baseline_pattern=f"{base_name}_restart_[0-9]+\.root", - allow_rebaseline=False, - restartcheck_params=restartcheck_params)) - - AtsTest.stick(level=ii) - AtsTest.stick(checks=','.join(checks)) - AtsTest.stick(solvers=','.join(xml_blocks.get('Solvers', []))) - AtsTest.stick(outputs=','.join(xml_blocks.get('Outputs', []))) - AtsTest.stick(constitutive_models=','.join(xml_blocks.get('Constitutive', []))) - TestCase(name=testcase_name, - desc=deck.description, - label="auto", - owner="GEOS team", - independent=True, - steps=steps) + geos( deck=xml_file, + name="{:d}to{:d}".format( deck.restart_step, deck.check_step ), + np=N, + ngpu=N, + x_partitions=nx, + y_partitions=ny, + z_partitions=nz, + restart_file=os.path.join( testcase_name, + "{}_restart_{:09d}".format( base_name, deck.restart_step ) ), + baseline_pattern=f"{base_name}_restart_[0-9]+\.root", + allow_rebaseline=False, + restartcheck_params=restartcheck_params ) ) + + AtsTest.stick( level=ii ) + AtsTest.stick( checks=','.join( checks ) ) + AtsTest.stick( solvers=','.join( xml_blocks.get( 'Solvers', [] ) ) ) + AtsTest.stick( outputs=','.join( xml_blocks.get( 'Outputs', [] ) ) ) + AtsTest.stick( constitutive_models=','.join( xml_blocks.get( 'Constitutive', [] ) ) ) + TestCase( name=testcase_name, + desc=deck.description, + label="auto", + owner="GEOS team", + independent=True, + steps=steps ) diff --git a/geos_ats_package/geos_ats/test_case.py b/geos_ats_package/geos_ats/test_case.py index fa317d3..e3141b0 100644 --- a/geos_ats_package/geos_ats/test_case.py +++ b/geos_ats_package/geos_ats/test_case.py @@ -1,4 +1,4 @@ -import ats # type: ignore[import] +import ats # type: ignore[import] import os import sys import shutil @@ -17,98 +17,98 @@ TESTS = {} BASELINE_PATH = "baselines" -logger = logging.getLogger('geos_ats') +logger = logging.getLogger( 'geos_ats' ) -class Batch(object): +class Batch( object ): """A class to represent batch options""" - def __init__(self, enabled=True, duration="1h", ppn=0, altname=None): + def __init__( self, enabled=True, duration="1h", ppn=0, altname=None ): - if enabled not in (True, False): - Error("enabled must be a boolean") + if enabled not in ( True, False ): + Error( "enabled must be a boolean" ) self.enabled = enabled self.duration = duration try: - dur = ats.Duration(duration) + dur = ats.Duration( duration ) self.durationSeconds = dur.value except ats.AtsError as e: - logger.error(e) - Error("bad time specification: %s" % duration) + logger.error( e ) + Error( "bad time specification: %s" % duration ) - self.ppn = ppn # processor per node - self.altname = altname # alternate name to use when launcing the batch job + self.ppn = ppn # processor per node + self.altname = altname # alternate name to use when launcing the batch job -class TestCase(object): +class TestCase( object ): """Encapsulates one test case, which may include many steps""" - def __init__(self, name, desc, label=None, labels=None, steps=[], **kw): + def __init__( self, name, desc, label=None, labels=None, steps=[], **kw ): try: - self.initialize(name, desc, label, labels, steps, **kw) + self.initialize( name, desc, label, labels, steps, **kw ) except Exception as e: # make sure error messages get logged, then get out of here. - logging.error(e) - Log(str(e)) - raise Exception(e) + logging.error( e ) + Log( str( e ) ) + raise Exception( e ) - def initialize(self, name, desc, label=None, labels=None, steps=[], batch=Batch(enabled=False), **kw): + def initialize( self, name, desc, label=None, labels=None, steps=[], batch=Batch( enabled=False ), **kw ): self.name = name self.desc = desc self.batch = batch - action = ats.tests.AtsTest.getOptions().get("action") + action = ats.tests.AtsTest.getOptions().get( "action" ) - if kw.get("output_directory", False): - self.path = os.path.abspath(kw.get("output_directory")) + if kw.get( "output_directory", False ): + self.path = os.path.abspath( kw.get( "output_directory" ) ) else: - self.path = os.path.join(os.getcwd(), self.name) + self.path = os.path.join( os.getcwd(), self.name ) - self.dirname = os.path.basename(self.path) + self.dirname = os.path.basename( self.path ) try: - os.makedirs(self.path) + os.makedirs( self.path ) except OSError as e: - if e.errno == errno.EEXIST and os.path.isdir(self.path): + if e.errno == errno.EEXIST and os.path.isdir( self.path ): pass else: - logger.debug(e) + logger.debug( e ) raise Exception() self.atsGroup = None self.dictionary = {} - self.dictionary.update(kw) - self.nodoc = self.dictionary.get("nodoc", False) - self.statusFile = os.path.abspath("TestStatus_%s" % self.name) + self.dictionary.update( kw ) + self.nodoc = self.dictionary.get( "nodoc", False ) + self.statusFile = os.path.abspath( "TestStatus_%s" % self.name ) self.status = None - self.outname = os.path.join(self.path, "%s.data" % self.name) - self.errname = os.path.join(self.path, "%s.err" % self.name) - self.dictionary["name"] = self.name - self.dictionary["output_directory"] = self.path - self.dictionary["baseline_dir"] = os.path.join(os.getcwd(), BASELINE_PATH, self.dirname) - self.dictionary["testcase_out"] = self.outname - self.dictionary["testcase_err"] = self.errname - self.dictionary["testcase_name"] = self.name + self.outname = os.path.join( self.path, "%s.data" % self.name ) + self.errname = os.path.join( self.path, "%s.err" % self.name ) + self.dictionary[ "name" ] = self.name + self.dictionary[ "output_directory" ] = self.path + self.dictionary[ "baseline_dir" ] = os.path.join( os.getcwd(), BASELINE_PATH, self.dirname ) + self.dictionary[ "testcase_out" ] = self.outname + self.dictionary[ "testcase_err" ] = self.errname + self.dictionary[ "testcase_name" ] = self.name # check for test cases, testcases can either be the string # "all" or a list of full test names. - testcases = ats.tests.AtsTest.getOptions().get("testcases") + testcases = ats.tests.AtsTest.getOptions().get( "testcases" ) if testcases == "all": pass elif self.name in testcases: - testcases.remove(self.name) + testcases.remove( self.name ) pass else: return if self.name in TESTS: - Error("Name already in use: %s" % self.name) + Error( "Name already in use: %s" % self.name ) - TESTS[self.name] = self + TESTS[ self.name ] = self # check for independent if config.override_np > 0: @@ -117,48 +117,48 @@ def initialize(self, name, desc, label=None, labels=None, steps=[], batch=Batch( # number of processors. self.independent = False else: - self.independent = self.dictionary.get("independent", False) - if self.independent not in (True, False): - Error("independent must be either True or False: %s" % str(self.independent)) + self.independent = self.dictionary.get( "independent", False ) + if self.independent not in ( True, False ): + Error( "independent must be either True or False: %s" % str( self.independent ) ) # check for depends - self.depends = self.dictionary.get("depends", None) + self.depends = self.dictionary.get( "depends", None ) if self.depends == self.name: # This check avoid testcases depending on themselves. self.depends = None - self.handleLabels(label, labels) + self.handleLabels( label, labels ) # complete the steps. # 1. update the steps with data from the dictionary # 2. substeps are inserted into the list of steps (the steps are flattened) for step in steps: - step.update(self.dictionary) + step.update( self.dictionary ) self.steps = [] for step in steps: - step.insertStep(self.steps) + step.insertStep( self.steps ) # test modifier - modifier = test_modifier.Factory(config.testmodifier) - newSteps = modifier.modifySteps(self.steps, self.dictionary) + modifier = test_modifier.Factory( config.testmodifier ) + newSteps = modifier.modifySteps( self.steps, self.dictionary ) if newSteps: # insert the modified steps, including any extra steps that may have # been added by the modifier. self.steps = [] for step in newSteps: - step.insertStep(self.steps) + step.insertStep( self.steps ) for extraStep in step.extraSteps: - extraStep.insertStep(newSteps) + extraStep.insertStep( newSteps ) self.steps = newSteps else: - Log("# SKIP test=%s : testmodifier=%s" % (self.name, config.testmodifier)) + Log( "# SKIP test=%s : testmodifier=%s" % ( self.name, config.testmodifier ) ) self.status = reporting.SKIP return # Check for explicit skip flag - if action in ("run", "rerun", "continue"): - if self.dictionary.get("skip", None): + if action in ( "run", "rerun", "continue" ): + if self.dictionary.get( "skip", None ): self.status = reporting.SKIP return @@ -166,7 +166,7 @@ def initialize(self, name, desc, label=None, labels=None, steps=[], batch=Batch( npMax = self.findMaxNumberOfProcessors() if config.filter_maxprocessors != -1: if npMax > config.filter_maxprocessors: - Log("# FILTER test=%s : max processors(%d > %d)" % (self.name, npMax, config.filter_maxprocessors)) + Log( "# FILTER test=%s : max processors(%d > %d)" % ( self.name, npMax, config.filter_maxprocessors ) ) self.status = reporting.FILTERED return @@ -174,73 +174,74 @@ def initialize(self, name, desc, label=None, labels=None, steps=[], batch=Batch( ngpuMax = self.findMaxNumberOfGPUs() # filter based on not enough resources - if action in ("run", "rerun", "continue"): + if action in ( "run", "rerun", "continue" ): tests = [ - not ats.tests.AtsTest.getOptions().get("testmode"), not self.batch.enabled, - hasattr(ats.manager.machine, "getNumberOfProcessors") + not ats.tests.AtsTest.getOptions().get( "testmode" ), not self.batch.enabled, + hasattr( ats.manager.machine, "getNumberOfProcessors" ) ] - if all(tests): + if all( tests ): - totalNumberOfProcessors = getattr(ats.manager.machine, "getNumberOfProcessors")() + totalNumberOfProcessors = getattr( ats.manager.machine, "getNumberOfProcessors" )() if npMax > totalNumberOfProcessors: - Log("# SKIP test=%s : not enough processors to run (%d > %d)" % - (self.name, npMax, totalNumberOfProcessors)) + Log( "# SKIP test=%s : not enough processors to run (%d > %d)" % + ( self.name, npMax, totalNumberOfProcessors ) ) self.status = reporting.SKIP return # If the machine doesn't specify a number of GPUs then it has none. - totalNumberOfGPUs = getattr(ats.manager.machine, "getNumberOfGPUS", lambda: 1e90)() + totalNumberOfGPUs = getattr( ats.manager.machine, "getNumberOfGPUS", lambda: 1e90 )() if ngpuMax > totalNumberOfGPUs: - Log("# SKIP test=%s : not enough gpus to run (%d > %d)" % (self.name, ngpuMax, totalNumberOfGPUs)) + Log( "# SKIP test=%s : not enough gpus to run (%d > %d)" % + ( self.name, ngpuMax, totalNumberOfGPUs ) ) self.status = reporting.SKIP return # filtering test steps based on action - if action in ("run", "rerun", "continue"): - checkoption = ats.tests.AtsTest.getOptions().get("checkoption") + if action in ( "run", "rerun", "continue" ): + checkoption = ats.tests.AtsTest.getOptions().get( "checkoption" ) if checkoption == "none": - self.steps = [step for step in self.steps if not step.isCheck()] + self.steps = [ step for step in self.steps if not step.isCheck() ] elif action == "check": - self.steps = [step for step in self.steps if step.isCheck()] + self.steps = [ step for step in self.steps if step.isCheck() ] # move all the delayed steps to the end reorderedSteps = [] for step in self.steps: if not step.isDelayed(): - reorderedSteps.append(step) + reorderedSteps.append( step ) for step in self.steps: if step.isDelayed(): - reorderedSteps.append(step) + reorderedSteps.append( step ) self.steps = reorderedSteps # filter based on previous results: - if action in ("run", "check", "continue"): + if action in ( "run", "check", "continue" ): # read the status file - self.status = test_caseStatus(self) + self.status = test_caseStatus( self ) # if previously passed then skip if self.status.isPassed(): - Log("# SKIP test=%s (previously passed)" % (self.name)) + Log( "# SKIP test=%s (previously passed)" % ( self.name ) ) # don't set status here, as we want the report to reflect the pass return if action == "continue": if self.status.isFailed(): - Log("# SKIP test=%s (previously failed)" % (self.name)) + Log( "# SKIP test=%s (previously failed)" % ( self.name ) ) # don't set status here, as we want the report to reflect the pass return # Perform the action: - if action in ("run", "continue"): - Log("# run test=%s" % (self.name)) + if action in ( "run", "continue" ): + Log( "# run test=%s" % ( self.name ) ) self.testCreate() elif action == "rerun": - Log("# rerun test=%s" % (self.name)) + Log( "# rerun test=%s" % ( self.name ) ) self.testCreate() elif action == "check": - Log("# check test=%s" % (self.name)) + Log( "# check test=%s" % ( self.name ) ) self.testCreate() elif action == "commands": @@ -248,14 +249,14 @@ def initialize(self, name, desc, label=None, labels=None, steps=[], batch=Batch( elif action == "reset": if self.testReset(): - Log("# reset test=%s" % (self.name)) + Log( "# reset test=%s" % ( self.name ) ) elif action == "clean": - Log("# clean test=%s" % (self.name)) + Log( "# clean test=%s" % ( self.name ) ) self.testClean() elif action == "veryclean": - Log("# veryclean test=%s" % (self.name)) + Log( "# veryclean test=%s" % ( self.name ) ) self.testVeryClean() elif action == "rebaseline": @@ -267,162 +268,162 @@ def initialize(self, name, desc, label=None, labels=None, steps=[], batch=Batch( elif action == "list": self.testList() - elif action in ("report"): + elif action in ( "report" ): self.testReport() else: - Error("Unknown action?? %s" % action) + Error( "Unknown action?? %s" % action ) - def resultPaths(self, step=None): + def resultPaths( self, step=None ): """Return the paths to output files for the testcase. Used in reporting""" - paths = [self.outname, self.errname] + paths = [ self.outname, self.errname ] if step: for x in step.resultPaths(): - fullpath = os.path.join(self.path, x) - if os.path.exists(fullpath): - paths.append(fullpath) + fullpath = os.path.join( self.path, x ) + if os.path.exists( fullpath ): + paths.append( fullpath ) return paths - def testReset(self): - self.status = test_caseStatus(self) + def testReset( self ): + self.status = test_caseStatus( self ) ret = self.status.resetFailed() self.status.writeStatusFile() return ret - def testClean(self): - if os.path.exists(self.statusFile): - os.remove(self.statusFile) - if os.path.exists(self.outname): - os.remove(self.outname) - if os.path.exists(self.errname): - os.remove(self.errname) + def testClean( self ): + if os.path.exists( self.statusFile ): + os.remove( self.statusFile ) + if os.path.exists( self.outname ): + os.remove( self.outname ) + if os.path.exists( self.errname ): + os.remove( self.errname ) for step in self.steps: step.clean() - def testVeryClean(self): + def testVeryClean( self ): - def _remove(path): - delpaths = glob.glob(path) + def _remove( path ): + delpaths = glob.glob( path ) for p in delpaths: - if os.path.exists(p): + if os.path.exists( p ): try: - if os.path.isdir(p): - shutil.rmtree(p) + if os.path.isdir( p ): + shutil.rmtree( p ) else: - os.remove(p) + os.remove( p ) except OSError: - pass # so that two simultaneous clean operations don't fail + pass # so that two simultaneous clean operations don't fail # clean self.testClean() # remove log directories - removeLogDirectories(os.getcwd()) + removeLogDirectories( os.getcwd() ) # remove extra files - if len(self.steps) > 0: - _remove(config.report_html_file) - _remove(config.report_text_file) - _remove(self.path) - _remove("*.core") - _remove("core") - _remove("core.*") - _remove("vgcore.*") - _remove("*.btr") - _remove("TestLogs*") - _remove("*.ini") - - def findMaxNumberOfProcessors(self): + if len( self.steps ) > 0: + _remove( config.report_html_file ) + _remove( config.report_text_file ) + _remove( self.path ) + _remove( "*.core" ) + _remove( "core" ) + _remove( "core.*" ) + _remove( "vgcore.*" ) + _remove( "*.btr" ) + _remove( "TestLogs*" ) + _remove( "*.ini" ) + + def findMaxNumberOfProcessors( self ): npMax = 1 for step in self.steps: - np = getattr(step.p, "np", 1) - npMax = max(np, npMax) + np = getattr( step.p, "np", 1 ) + npMax = max( np, npMax ) return npMax - def findMaxNumberOfGPUs(self): + def findMaxNumberOfGPUs( self ): gpuMax = 0 for step in self.steps: - ngpu = getattr(step.p, "ngpu", 0) * getattr(step.p, "np", 1) - gpuMax = max(ngpu, gpuMax) + ngpu = getattr( step.p, "ngpu", 0 ) * getattr( step.p, "np", 1 ) + gpuMax = max( ngpu, gpuMax ) return gpuMax - def testCreate(self): + def testCreate( self ): atsTest = None - keep = ats.tests.AtsTest.getOptions().get("keep") + keep = ats.tests.AtsTest.getOptions().get( "keep" ) # remove outname - if os.path.exists(self.outname): - os.remove(self.outname) - if os.path.exists(self.errname): - os.remove(self.errname) + if os.path.exists( self.outname ): + os.remove( self.outname ) + if os.path.exists( self.errname ): + os.remove( self.errname ) # create the status file if self.status is None: - self.status = test_caseStatus(self) + self.status = test_caseStatus( self ) maxnp = 1 - for stepnum, step in enumerate(self.steps): - np = getattr(step.p, "np", 1) - maxnp = max(np, maxnp) + for stepnum, step in enumerate( self.steps ): + np = getattr( step.p, "np", 1 ) + maxnp = max( np, maxnp ) if config.priority == "processors": priority = maxnp elif config.priority == "timing": - priority = max(globalTestTimings.get(self.name, 1) * maxnp, 1) + priority = max( globalTestTimings.get( self.name, 1 ) * maxnp, 1 ) else: priority = 1 # start a group - ats.tests.AtsTest.newGroup(priority=priority) + ats.tests.AtsTest.newGroup( priority=priority ) # keep a reference to the ats test group self.atsGroup = ats.tests.AtsTest.group # if depends if self.depends: - priorTestCase = TESTS.get(self.depends, None) + priorTestCase = TESTS.get( self.depends, None ) if priorTestCase is None: - Log("Warning: Test %s depends on testcase %s, which is not scheduled to run" % - (self.name, self.depends)) + Log( "Warning: Test %s depends on testcase %s, which is not scheduled to run" % + ( self.name, self.depends ) ) else: if priorTestCase.steps: - atsTest = getattr(priorTestCase.steps[-1], "atsTest", None) + atsTest = getattr( priorTestCase.steps[ -1 ], "atsTest", None ) - for stepnum, step in enumerate(self.steps): + for stepnum, step in enumerate( self.steps ): - np = getattr(step.p, "np", 1) - ngpu = getattr(step.p, "ngpu", 0) + np = getattr( step.p, "np", 1 ) + ngpu = getattr( step.p, "ngpu", 0 ) executable = step.executable() args = step.makeArgs() # set the label - label = "%s/%s_%d_%s" % (self.dirname, self.name, stepnum + 1, step.label()) + label = "%s/%s_%d_%s" % ( self.dirname, self.name, stepnum + 1, step.label() ) # call either 'test' or 'testif' if atsTest is None: - func = lambda *a, **k: test(*a, **k) + func = lambda *a, **k: test( *a, **k ) else: - func = lambda *a, **k: testif(atsTest, *a, **k) + func = lambda *a, **k: testif( atsTest, *a, **k ) # timelimit kw = {} if self.batch.enabled: - kw["timelimit"] = self.batch.duration + kw[ "timelimit" ] = self.batch.duration - if (step.timelimit() and not config.override_timelimit): - kw["timelimit"] = step.timelimit() + if ( step.timelimit() and not config.override_timelimit ): + kw[ "timelimit" ] = step.timelimit() else: - kw["timelimit"] = config.default_timelimit - - atsTest = func(executable=executable, - clas=args, - np=np, - ngpu=ngpu, - label=label, - serial=(not step.useMPI() and not config.script_launch), - independent=self.independent, - batch=self.batch.enabled, - **kw) + kw[ "timelimit" ] = config.default_timelimit + + atsTest = func( executable=executable, + clas=args, + np=np, + ngpu=ngpu, + label=label, + serial=( not step.useMPI() and not config.script_launch ), + independent=self.independent, + batch=self.batch.enabled, + **kw ) # ats test gets a reference to the TestStep and the TestCase atsTest.geos_atsTestCase = self @@ -432,14 +433,14 @@ def testCreate(self): step.atsTest = atsTest # Add the step the test status object - self.status.addStep(atsTest) + self.status.addStep( atsTest ) # set the expected result if step.expectedResult() == "FAIL" or step.expectedResult() is False: atsTest.expectedResult = ats.FAILED # The ATS does not permit tests to depend on failed tests. # therefore we need to break here - self.steps = self.steps[:stepnum + 1] + self.steps = self.steps[ :stepnum + 1 ] break # end the group @@ -448,42 +449,43 @@ def testCreate(self): self.status.resetFailed() self.status.writeStatusFile() - def commandLine(self, step): + def commandLine( self, step ): args = [] executable = step.executable() commandArgs = step.makeArgs() - assert isinstance(commandArgs, list) + assert isinstance( commandArgs, list ) for a in commandArgs: if " " in a: - args.append('"%s"' % a) + args.append( '"%s"' % a ) else: - args.append(a) + args.append( a ) - argsstr = " ".join(args) + argsstr = " ".join( args ) return executable + " " + argsstr - def testCommands(self): - Log("\n# commands test=%s" % (self.name)) + def testCommands( self ): + Log( "\n# commands test=%s" % ( self.name ) ) for step in self.steps: - np = getattr(step.p, "np", 1) + np = getattr( step.p, "np", 1 ) usempi = step.useMPI() - stdout = getattr(step.p, "stdout", None) - commandline = self.commandLine(step).replace('%%', '%') + stdout = getattr( step.p, "stdout", None ) + commandline = self.commandLine( step ).replace( '%%', '%' ) if stdout: - Log("np=%d %s > %s" % (np, commandline, stdout)) + Log( "np=%d %s > %s" % ( np, commandline, stdout ) ) else: - Log("np=%d %s" % (np, commandline)) + Log( "np=%d %s" % ( np, commandline ) ) - def testRebaseline(self): + def testRebaseline( self ): rebaseline = True if config.rebaseline_ask: while 1: if config.rebaseline_undo: - logger.info(f"Are you sure you want to undo the rebaseline for TestCase '{self.name}'?", flush=True) + logger.info( f"Are you sure you want to undo the rebaseline for TestCase '{self.name}'?", + flush=True ) else: - logger.info(f"Are you sure you want to rebaseline TestCase '{self.name}'?", flush=True) + logger.info( f"Are you sure you want to rebaseline TestCase '{self.name}'?", flush=True ) - x = input('[y/n] ') + x = input( '[y/n] ' ) x = x.strip() if x == "y": break @@ -491,31 +493,31 @@ def testRebaseline(self): rebaseline = False break else: - Log("\n# rebaseline test=%s" % (self.name)) + Log( "\n# rebaseline test=%s" % ( self.name ) ) if rebaseline: for step in self.steps: step.rebaseline() - def testRebaselineFailed(self): + def testRebaselineFailed( self ): config.rebaseline_ask = False - self.status = test_caseStatus(self) + self.status = test_caseStatus( self ) if self.status.isFailed(): self.testRebaseline() - def testList(self): - Log("# test=%s : labels=%s" % (self.name.ljust(32), " ".join(self.labels))) + def testList( self ): + Log( "# test=%s : labels=%s" % ( self.name.ljust( 32 ), " ".join( self.labels ) ) ) - def testReport(self): - self.status = test_caseStatus(self) + def testReport( self ): + self.status = test_caseStatus( self ) - def handleLabels(self, label, labels): + def handleLabels( self, label, labels ): """set the labels, and verify they are known to the system, the avoid typos""" if labels is not None and label is not None: - Error("specify only one of 'label' or 'labels'") + Error( "specify only one of 'label' or 'labels'" ) if label is not None: - self.labels = [label] + self.labels = [ label ] elif labels is not None: self.labels = labels else: @@ -523,66 +525,66 @@ def handleLabels(self, label, labels): for x in self.labels: if x not in testLabels: - Error(f"unknown label {x}. run 'geos_ats -i labels' for a list") + Error( f"unknown label {x}. run 'geos_ats -i labels' for a list" ) -class test_caseStatus(object): +class test_caseStatus( object ): - def __init__(self, testCase): + def __init__( self, testCase ): self.testCase = testCase self.statusFile = self.testCase.statusFile self.readStatusFile() - def readStatusFile(self): - if os.path.exists(self.statusFile): - f = open(self.statusFile, "r") - self.status = [eval(x.strip()) for x in f.readlines()] + def readStatusFile( self ): + if os.path.exists( self.statusFile ): + f = open( self.statusFile, "r" ) + self.status = [ eval( x.strip() ) for x in f.readlines() ] f.close() else: self.status = [] - def writeStatusFile(self): + def writeStatusFile( self ): assert self.status is not None - with open(self.statusFile, "w") as f: - f.writelines([str(s) + '\n' for s in self.status]) + with open( self.statusFile, "w" ) as f: + f.writelines( [ str( s ) + '\n' for s in self.status ] ) - def testKey(self, step): - np = getattr(step.p, "np", 1) - key = str((np, step.label(), step.executable(), step.makeArgsForStatusKey())) + def testKey( self, step ): + np = getattr( step.p, "np", 1 ) + key = str( ( np, step.label(), step.executable(), step.makeArgsForStatusKey() ) ) return key - def testData(self, test): - key = self.testKey(test.geos_atsTestStep) + def testData( self, test ): + key = self.testKey( test.geos_atsTestStep ) result = test.status if result == ats.PASSED and test.expectedResult == ats.FAILED: result = ats.FAILED - endTime = getattr(test, "endTime", None) - startTime = getattr(test, "startTime", None) + endTime = getattr( test, "endTime", None ) + startTime = getattr( test, "startTime", None ) data = {} - data["key"] = key - data["result"] = str(result) - data["startTime"] = startTime - data["endTime"] = endTime + data[ "key" ] = key + data[ "result" ] = str( result ) + data[ "startTime" ] = startTime + data[ "endTime" ] = endTime return key, data - def findStep(self, step): - key = self.testKey(step) + def findStep( self, step ): + key = self.testKey( step ) for s in self.status: - if key in s["key"]: + if key in s[ "key" ]: return s return None - def isPassed(self): + def isPassed( self ): for step in self.testCase.steps: - status = self.findStep(step) + status = self.findStep( step ) if status: - if status["result"] == "EXPT": + if status[ "result" ] == "EXPT": # do not continue after an expected fail return True - elif status["result"] == "PASS": + elif status[ "result" ] == "PASS": continue else: return False @@ -590,16 +592,16 @@ def isPassed(self): return False return True - def isFailed(self): + def isFailed( self ): for step in self.testCase.steps: - status = self.findStep(step) + status = self.findStep( step ) if status: - if status["result"] == "EXPT": + if status[ "result" ] == "EXPT": # do not continue after an expected fail return False - elif status["result"] == "PASS": + elif status[ "result" ] == "PASS": continue - elif status["result"] == "FAIL": + elif status[ "result" ] == "FAIL": return True else: return False @@ -607,76 +609,76 @@ def isFailed(self): return False return False - def resetFailed(self): + def resetFailed( self ): ret = False for step in self.testCase.steps: - status = self.findStep(step) + status = self.findStep( step ) if status: - if status["result"] == "EXPT": + if status[ "result" ] == "EXPT": # do not continue after an expected fail - status["result"] = "INIT" + status[ "result" ] = "INIT" ret = True - elif status["result"] == "FAIL": - status["result"] = "INIT" + elif status[ "result" ] == "FAIL": + status[ "result" ] = "INIT" ret = True else: continue return ret - def totalTime(self): + def totalTime( self ): total = 0.0 for step in self.testCase.steps: - status = self.findStep(step) + status = self.findStep( step ) if status: - steptime = status["endTime"] - status["startTime"] + steptime = status[ "endTime" ] - status[ "startTime" ] assert steptime >= 0 total += steptime return total - def addStep(self, test): - key, data = self.testData(test) + def addStep( self, test ): + key, data = self.testData( test ) found = False for s in self.status: - if key == s["key"]: + if key == s[ "key" ]: found = True break if not found: - self.status.append(data) + self.status.append( data ) - def noteEnd(self, test): + def noteEnd( self, test ): """Update the TestStatus file for this test case""" # update the status - key, data = self.testData(test) + key, data = self.testData( test ) self.readStatusFile() found = False - for i, s in enumerate(self.status): - if key in s["key"]: - self.status[i] = data + for i, s in enumerate( self.status ): + if key in s[ "key" ]: + self.status[ i ] = data found = True break if not found: - logger.warning(f"NOT FOUND: {key} {self.statusFile}") + logger.warning( f"NOT FOUND: {key} {self.statusFile}" ) assert found self.writeStatusFile() # append to stdout/stderr file - for stream in ("outname", "errname"): - sourceFile = getattr(test, stream) - dataFile = getattr(self.testCase, stream) + for stream in ( "outname", "errname" ): + sourceFile = getattr( test, stream ) + dataFile = getattr( self.testCase, stream ) - if not os.path.exists(sourceFile): + if not os.path.exists( sourceFile ): continue # Append to the TestCase files - f1 = open(dataFile, "a") - f2 = open(sourceFile, "r") - f1.write(":" * 20 + "\n") - f1.write(self.testCase.commandLine(test.geos_atsTestStep) + "\n") - f1.write(":" * 20 + "\n") - f1.write(f2.read()) + f1 = open( dataFile, "a" ) + f2 = open( sourceFile, "r" ) + f1.write( ":" * 20 + "\n" ) + f1.write( self.testCase.commandLine( test.geos_atsTestStep ) + "\n" ) + f1.write( ":" * 20 + "\n" ) + f1.write( f2.read() ) f1.close() f2.close() @@ -687,37 +689,37 @@ def noteEnd(self, test): destFile = test.geos_atsTestStep.saveErr() if destFile: - destFile = os.path.join(self.testCase.path, destFile) - shutil.copy(sourceFile, destFile) + destFile = os.path.join( self.testCase.path, destFile ) + shutil.copy( sourceFile, destFile ) # If this is the last step (and it passed), clean the temporary files if config.clean_on_pass: - lastStep = (test.geos_atsTestStep is self.testCase.steps[-1]) + lastStep = ( test.geos_atsTestStep is self.testCase.steps[ -1 ] ) if lastStep and self.isPassed(): for step in self.testCase.steps: step.clean() -def infoTestCase(*args): +def infoTestCase( *args ): """This function is used to print documentation about the testcase""" - topic = InfoTopic("testcase") + topic = InfoTopic( "testcase" ) topic.startBanner() - logger.info("Required parameters") - table = TextTable(3) - table.addRow("name", "required", "The name of the test problem") - table.addRow("desc", "required", "A brief description") - table.addRow("label", "required", "A string or sequence of strings to tag the TestCase. See info topic 'labels'") - table.addRow("owner", "optional", - "A string or sequence of strings of test owners for this TestCase. See info topic 'owners'") + logger.info( "Required parameters" ) + table = TextTable( 3 ) + table.addRow( "name", "required", "The name of the test problem" ) + table.addRow( "desc", "required", "A brief description" ) + table.addRow( "label", "required", "A string or sequence of strings to tag the TestCase. See info topic 'labels'" ) + table.addRow( "owner", "optional", + "A string or sequence of strings of test owners for this TestCase. See info topic 'owners'" ) table.addRow( "batch", "optional", "A Batch object. Batch(enabled=True, duration='1h', ppn=0, altname=None)." " ppn is short for processors per node (0 means to use the global default)." " altname will be used for the batch job's name if supplied, otherwise the full name of the test case is used." ), - table.addRow("depends", "optional", "The name of a testcase that this testcase depends") - table.addRow("steps", "required", "A sequence of TestSteps objects. See info topic 'teststeps'") + table.addRow( "depends", "optional", "The name of a testcase that this testcase depends" ) + table.addRow( "steps", "required", "A sequence of TestSteps objects. See info topic 'teststeps'" ) table.printTable() @@ -725,5 +727,5 @@ def infoTestCase(*args): # Make available to the tests -ats.manager.define(TestCase=TestCase) -ats.manager.define(Batch=Batch) +ats.manager.define( TestCase=TestCase ) +ats.manager.define( Batch=Batch ) diff --git a/geos_ats_package/geos_ats/test_modifier.py b/geos_ats_package/geos_ats/test_modifier.py index c6d7cb2..11187b2 100644 --- a/geos_ats_package/geos_ats/test_modifier.py +++ b/geos_ats_package/geos_ats/test_modifier.py @@ -1,73 +1,73 @@ -import ats # type: ignore[import] +import ats # type: ignore[import] from geos_ats import common_utilities from geos_ats.configuration_record import config from geos_ats import test_steps import os import logging -logger = logging.getLogger('geos_ats') +logger = logging.getLogger( 'geos_ats' ) -class TestModifier(object): +class TestModifier( object ): """Base class for test modifiers. It modifies the steps of a test_case to create a new test""" - def modifySteps(self, originalSteps, dictionary): + def modifySteps( self, originalSteps, dictionary ): """Overload this function to generate a new sequence of steps from the existing sequence of steps""" return originalSteps -class TestModifierDefault(TestModifier): +class TestModifierDefault( TestModifier ): label = "default" doc = "Default test modifier: Add a step to check stdout for info and critical warning" - def modifySteps(self, steps, dictionary): + def modifySteps( self, steps, dictionary ): return steps -def Factory(name): +def Factory( name ): """Function that returns the correct TestModifier based on the name""" if not name: return TestModifierDefault() for k, v in globals().items(): - if not isinstance(v, type): + if not isinstance( v, type ): continue if v == TestModifier: continue try: - if issubclass(v, TestModifier): + if issubclass( v, TestModifier ): if v.label == name: return v() except TypeError as e: - logger.debug(e) + logger.debug( e ) - common_utilities.Error("Unknown test modifier: %s" % name) + common_utilities.Error( "Unknown test modifier: %s" % name ) -def infoTestModifier(*args): +def infoTestModifier( *args ): modifiers = [] for k, v in globals().items(): - if not isinstance(v, type): + if not isinstance( v, type ): continue if v == TestModifier: continue try: - if issubclass(v, TestModifier): - modifiers.append(k) + if issubclass( v, TestModifier ): + modifiers.append( k ) except TypeError as e: - logger.debug(e) + logger.debug( e ) - modifiers = sorted(modifiers) + modifiers = sorted( modifiers ) - topic = common_utilities.InfoTopic("test modifiers") + topic = common_utilities.InfoTopic( "test modifiers" ) topic.startBanner() - table = common_utilities.TextTable(2) + table = common_utilities.TextTable( 2 ) for m in modifiers: - mclass = globals()[m] - doc = getattr(mclass, "doc", None) - label = getattr(mclass, "label", None) - table.addRow(label, doc) + mclass = globals()[ m ] + doc = getattr( mclass, "doc", None ) + label = getattr( mclass, "label", None ) + table.addRow( label, doc ) table.printTable() topic.endBanner() diff --git a/geos_ats_package/geos_ats/test_steps.py b/geos_ats_package/geos_ats/test_steps.py index 129e162..57d39ac 100644 --- a/geos_ats_package/geos_ats/test_steps.py +++ b/geos_ats_package/geos_ats/test_steps.py @@ -1,5 +1,5 @@ import os -import ats # type: ignore[import] +import ats # type: ignore[import] import glob import shutil import sys @@ -11,10 +11,10 @@ from geos_ats.common_utilities import Error, Log from geos_ats.configuration_record import config -logger = logging.getLogger('geos_ats') +logger = logging.getLogger( 'geos_ats' ) -def getGeosProblemName(deck, name): +def getGeosProblemName( deck, name ): """ Given an input deck and a name return the prefix Geos will attatch to it's output files. @@ -22,15 +22,15 @@ def getGeosProblemName(deck, name): NAME [in]: The name given to Geos on the command line. """ if name is None: - if deck.endswith(".xml"): - return os.path.basename(deck)[:-4] + if deck.endswith( ".xml" ): + return os.path.basename( deck )[ :-4 ] else: - return os.path.basename(deck) + return os.path.basename( deck ) else: return name -def findMaxMatchingFile(file_path): +def findMaxMatchingFile( file_path ): """ Given a path FILE_PATH where the base name of FILE_PATH is treated as a regular expression find and return the path of the greatest matching file/folder or None if no match is found. @@ -43,31 +43,31 @@ def findMaxMatchingFile(file_path): "test/plot_*.hdf5" will return the file with the greatest name in the ./test directory that begins with "plot_" and ends with ".hdf5". """ - file_directory, pattern = os.path.split(file_path) + file_directory, pattern = os.path.split( file_path ) if file_directory == "": file_directory = "." - if not os.path.isdir(file_directory): + if not os.path.isdir( file_directory ): return None max_match = '' - pattern = re.compile(pattern) - for file in os.listdir(file_directory): - if pattern.match(file) is not None: - max_match = max(file, max_match) + pattern = re.compile( pattern ) + for file in os.listdir( file_directory ): + if pattern.match( file ) is not None: + max_match = max( file, max_match ) if not max_match: return None - return os.path.join(file_directory, max_match) + return os.path.join( file_directory, max_match ) -class TestParam(object): +class TestParam( object ): """ A class that describes a parameter of a test step. """ - def __init__(self, name, doc, default=None): + def __init__( self, name, doc, default=None ): self.name = name self.doc = doc self.default = default @@ -78,7 +78,7 @@ def __init__(self, name, doc, default=None): ################################################################################ -class TestStepBase(object): +class TestStepBase( object ): """ The base clase for a test step. """ @@ -87,65 +87,65 @@ class TestStepBase(object): TestParam( "clean", "additional files to remove during the clean action." " clean may be a string or a list of strings. The strings may contain" - " wildcard characters."), + " wildcard characters." ), TestParam( "timelimit", "maximum time the step is allowed to run before it is considerend a TIMEOUT." - " Specified as a string such as: 1h30m, 60m, etc.", "None"), - TestParam("stdout", "If set, the stdout will be placed in the named file, in the TestCase directory", None), - TestParam("stderr", "If set, the stderr will be placed in the named file, in the TestCase directory", None), - TestParam("expectedResult", "'PASS' or 'FAIL'", "'PASS'"), - TestParam("delayed", "Whether execution of the step will be delayed", "False"), - TestParam("minor", "Whether failure of this step is minor issue", "False"), + " Specified as a string such as: 1h30m, 60m, etc.", "None" ), + TestParam( "stdout", "If set, the stdout will be placed in the named file, in the TestCase directory", None ), + TestParam( "stderr", "If set, the stderr will be placed in the named file, in the TestCase directory", None ), + TestParam( "expectedResult", "'PASS' or 'FAIL'", "'PASS'" ), + TestParam( "delayed", "Whether execution of the step will be delayed", "False" ), + TestParam( "minor", "Whether failure of this step is minor issue", "False" ), ) commonParams = { "name": - TestParam("name", "Used to give other params default values.", "The name of the TestCase"), + TestParam( "name", "Used to give other params default values.", "The name of the TestCase" ), "deck": - TestParam("deck", "Name of the input file. Setting deck to False means no deck is used.", ".in"), + TestParam( "deck", "Name of the input file. Setting deck to False means no deck is used.", ".in" ), "np": - TestParam("np", "The number of processors to run on.", 1), + TestParam( "np", "The number of processors to run on.", 1 ), "ngpu": - TestParam("ngpu", "The number of gpus to run on when available.", 0), + TestParam( "ngpu", "The number of gpus to run on when available.", 0 ), "check": TestParam( "check", "True or False. determines whether the default checksteps will " - "be automatically be added after this step.", "True"), + "be automatically be added after this step.", "True" ), "baseline_dir": - TestParam("baseline_dir", "subdirectory of config.testbaseline_dir where the test " - "baselines are located.", ""), + TestParam( "baseline_dir", "subdirectory of config.testbaseline_dir where the test " + "baselines are located.", "" ), "output_directory": - TestParam("output_directory", "subdirectory where the test log, params, rin, and " - "timehistory files are located.", ""), + TestParam( "output_directory", "subdirectory where the test log, params, rin, and " + "timehistory files are located.", "" ), "rebaseline": TestParam( "rebaseline", "additional files to rebaseline during the rebaseline action." - " rebaseline may be a string or a list of strings."), + " rebaseline may be a string or a list of strings." ), "timehistfile": - TestParam("timehistfile", "name of the file containing all the" - " timehist curves.", "testmode..ul"), + TestParam( "timehistfile", "name of the file containing all the" + " timehist curves.", "testmode..ul" ), "basetimehistfile": - TestParam("basetimehistfile", "location to the baseline timehistfile", - "//"), + TestParam( "basetimehistfile", "location to the baseline timehistfile", + "//" ), "allow_rebaseline": TestParam( "allow_rebaseline", "True if the second file should be re-baselined during a rebaseline action." - " False if the second file should not be rebaselined.", "True"), + " False if the second file should not be rebaselined.", "True" ), "testcase_name": - TestParam("testcase_name", "The name of the testcase"), + TestParam( "testcase_name", "The name of the testcase" ), "testcase_out": - TestParam("testcase_out", "The file where stdout for the testcase is accumulated"), + TestParam( "testcase_out", "The file where stdout for the testcase is accumulated" ), } # namespace to place the params. - class Params(object): + class Params( object ): pass - def __init__(self): + def __init__( self ): self.p = TestStepBase.Params() self.extraSteps = [] - def setParams(self, dictionary, paramlist): + def setParams( self, dictionary, paramlist ): """ Given a list of parameters PARAMLIST and a DICTIONARY set the parameters in PARAMLIST that are also in DICTIONARY but do not yet have a value. @@ -155,38 +155,38 @@ def setParams(self, dictionary, paramlist): """ for p in paramlist: pname = p.name - if getattr(self.p, pname, None) is None: - setattr(self.p, pname, dictionary.get(pname, None)) + if getattr( self.p, pname, None ) is None: + setattr( self.p, pname, dictionary.get( pname, None ) ) - def requireParam(self, param): + def requireParam( self, param ): """ Require that the given parameter is defined and not None. PARAM [in]: The name of the parameter to check. """ - if not hasattr(self.p, param): - Error("%s must be given" % param) - if getattr(self.p, param) is None: - Error("%s must not be None" % param) + if not hasattr( self.p, param ): + Error( "%s must be given" % param ) + if getattr( self.p, param ) is None: + Error( "%s must not be None" % param ) - def insertStep(self, steps): + def insertStep( self, steps ): """ Insert into the list of steps STEPS. STEPS [in/out]: The list of steps to insert into. """ - steps.append(self) + steps.append( self ) - def makeArgs(self): + def makeArgs( self ): """ Return the command line arguments for this step. """ - raise Error("Must implement this") + raise Error( "Must implement this" ) - def makeArgsForStatusKey(self): + def makeArgsForStatusKey( self ): return self.makeArgs() - def setStdout(self, dictionary): + def setStdout( self, dictionary ): """ Generate a unique stdout file using DICTIONARY. @@ -197,26 +197,26 @@ def setStdout(self, dictionary): self.p.stdout = stepname + "." + self.label() + ".out" if self.p.stdout in dictionary: - Log("Non-unique name for stdout file: %s" % self.p.stdout) + Log( "Non-unique name for stdout file: %s" % self.p.stdout ) else: - dictionary[self.p.stdout] = 1 + dictionary[ self.p.stdout ] = 1 - def update(self, dictionary): + def update( self, dictionary ): """ Update parameters using DICTIONARY. All parameters which already have values are not updated. Called by the owning TestCase to pass along it's arguments. DICTIONARY [in]: The dictionary used to update the parameters. """ - raise Error("Must implement this") + raise Error( "Must implement this" ) - def clean(self): + def clean( self ): """ Remove files generated by this test step. """ - self._clean([]) + self._clean( [] ) - def _clean(self, paths, noclean=[]): + def _clean( self, paths, noclean=[] ): """ Delete files/folders in PATHS and self.p.clean as well as stdout and stderr but not in NOCLEAN. Paths to delete can have wildcard characters '*'. @@ -224,21 +224,21 @@ def _clean(self, paths, noclean=[]): PATHS [in]: Paths to remove, can have wildcard characters. NOCLEAN [in]: Paths to ignore, can not have wildcard characters. """ - self._remove(paths, noclean) + self._remove( paths, noclean ) - if hasattr(self.p, "clean"): + if hasattr( self.p, "clean" ): if self.p.clean is not None: - self._remove(self.p.clean, noclean) - if hasattr(self.p, "stdout"): + self._remove( self.p.clean, noclean ) + if hasattr( self.p, "stdout" ): if self.p.stdout is not None: - self._remove(self.p.stdout, noclean) - self._remove("%s.*" % self.p.stdout, noclean) - if hasattr(self.p, "stderr"): + self._remove( self.p.stdout, noclean ) + self._remove( "%s.*" % self.p.stdout, noclean ) + if hasattr( self.p, "stderr" ): if self.p.stderr is not None: - self._remove(self.p.stderr, noclean) - self._remove("%s.*" % self.p.stderr, noclean) + self._remove( self.p.stderr, noclean ) + self._remove( "%s.*" % self.p.stderr, noclean ) - def _remove(self, paths, noclean): + def _remove( self, paths, noclean ): """ Delete files/folders in PATHS but not in NOCLEAN. Paths to delete can have wildcard characters '*'. @@ -246,100 +246,100 @@ def _remove(self, paths, noclean): PATHS [in]: Paths to remove, can have wildcard characters. NOCLEAN [in]: Paths to ignore, can not have wildcard characters. """ - if isinstance(paths, str): - paths = [paths] + if isinstance( paths, str ): + paths = [ paths ] for path in paths: if self.getTestMode(): - Log("clean: %s" % path) + Log( "clean: %s" % path ) else: - delpaths = glob.glob(path) + delpaths = glob.glob( path ) for p in delpaths: if p in noclean: continue try: - if os.path.isdir(p): - shutil.rmtree(p) + if os.path.isdir( p ): + shutil.rmtree( p ) else: - os.remove(p) + os.remove( p ) except OSError as e: - logger.debug(e) # so that two simultaneous clean operations don't fail + logger.debug( e ) # so that two simultaneous clean operations don't fail - def getCheckOption(self): - return ats.tests.AtsTest.getOptions().get("checkoption") + def getCheckOption( self ): + return ats.tests.AtsTest.getOptions().get( "checkoption" ) - def getTestMode(self): - return ats.tests.AtsTest.getOptions().get("testmode") + def getTestMode( self ): + return ats.tests.AtsTest.getOptions().get( "testmode" ) - def isCheck(self): + def isCheck( self ): """ Return True iff this is a check step. """ return False - def isDelayed(self): + def isDelayed( self ): """ Return True iff this step and all substeps should be moved to the end of the test case. """ return self.p.delayed - def isMinor(self): + def isMinor( self ): """ Return True iff failure of this step is a minor issue. """ return self.p.minor - def saveOut(self): + def saveOut( self ): return self.p.stdout - def saveErr(self): + def saveErr( self ): return self.p.stderr - def useMPI(self): + def useMPI( self ): """ Return True iff this step uses MPI. """ return False - def resultPaths(self): + def resultPaths( self ): """ Return a list of paths generated by this step. """ return [] - def timelimit(self): - return getattr(self.p, "timelimit", None) + def timelimit( self ): + return getattr( self.p, "timelimit", None ) - def expectedResult(self): - return getattr(self.p, "expectedResult", "PASS") + def expectedResult( self ): + return getattr( self.p, "expectedResult", "PASS" ) - def handleCommonParams(self): + def handleCommonParams( self ): """ Handle all the common parameters. """ - if hasattr(self.p, "np"): + if hasattr( self.p, "np" ): if self.p.np is None: self.p.np = 1 - if hasattr(self.p, "ngpu"): + if hasattr( self.p, "ngpu" ): if self.p.ngpu is None: self.p.ngpu = 0 - if hasattr(self.p, "check"): + if hasattr( self.p, "check" ): if self.p.check is None: self.p.check = True - if hasattr(self.p, "allow_rebaseline"): + if hasattr( self.p, "allow_rebaseline" ): if self.p.allow_rebaseline is None: self.p.allow_rebaseline = True - def executable(self): + def executable( self ): """ Return the path of the executable used to execute this step. """ - raise Error("Must implement this") + raise Error( "Must implement this" ) - def rebaseline(self): + def rebaseline( self ): """ Rebaseline this test step. """ @@ -349,23 +349,23 @@ def rebaseline(self): ################################################################################ # CheckTestStepBase ################################################################################ -class CheckTestStepBase(TestStepBase): +class CheckTestStepBase( TestStepBase ): """ Base class for check test steps. """ - checkParams = (TestParam( + checkParams = ( TestParam( "enabled", "True or False. determines whether this step is enabled. Often times used to turn off automatic check steps", - "True"), ) + "True" ), ) - def isCheck(self): + def isCheck( self ): return True - def handleCommonParams(self): - TestStepBase.handleCommonParams(self) + def handleCommonParams( self ): + TestStepBase.handleCommonParams( self ) - if hasattr(self.p, "enabled"): + if hasattr( self.p, "enabled" ): if self.p.enabled is None: self.p.enabled = True @@ -373,7 +373,7 @@ def handleCommonParams(self): ################################################################################ # geos ################################################################################ -class geos(TestStepBase): +class geos( TestStepBase ): """ Class for the Geos test step. """ @@ -384,20 +384,20 @@ class geos(TestStepBase): command = "geosx [-i ] [-r ] [-x ] [-y ] [-z ] [-s ] [-n ] [-o ] [ --suppress-pinned ] " params = TestStepBase.defaultParams + ( - TestStepBase.commonParams["name"], TestStepBase.commonParams["deck"], TestStepBase.commonParams["np"], - TestStepBase.commonParams["ngpu"], TestStepBase.commonParams["check"], - TestStepBase.commonParams["baseline_dir"], TestStepBase.commonParams["output_directory"], - TestParam("restart_file", "The name of the restart file."), - TestParam("x_partitions", "The number of partitions in the x direction."), - TestParam("y_partitions", "The number of partitions in the y direction."), - TestParam("z_partitions", - "The number of partitions in the z direction."), TestParam("schema_level", "The schema level."), - TestParam("suppress-pinned", "Option to suppress use of pinned memory for MPI buffers."), - TestParam("trace_data_migration", "Trace host-device data migration.")) + TestStepBase.commonParams[ "name" ], TestStepBase.commonParams[ "deck" ], TestStepBase.commonParams[ "np" ], + TestStepBase.commonParams[ "ngpu" ], TestStepBase.commonParams[ "check" ], + TestStepBase.commonParams[ "baseline_dir" ], TestStepBase.commonParams[ "output_directory" ], + TestParam( "restart_file", "The name of the restart file." ), + TestParam( "x_partitions", "The number of partitions in the x direction." ), + TestParam( "y_partitions", "The number of partitions in the y direction." ), + TestParam( "z_partitions", + "The number of partitions in the z direction." ), TestParam( "schema_level", "The schema level." ), + TestParam( "suppress-pinned", "Option to suppress use of pinned memory for MPI buffers." ), + TestParam( "trace_data_migration", "Trace host-device data migration." ) ) - checkstepnames = ["restartcheck"] + checkstepnames = [ "restartcheck" ] - def __init__(self, restartcheck_params=None, curvecheck_params=None, **kw): + def __init__( self, restartcheck_params=None, curvecheck_params=None, **kw ): """ Initializes the parameters of this test step, and creates the appropriate check steps. @@ -406,112 +406,112 @@ def __init__(self, restartcheck_params=None, curvecheck_params=None, **kw): KEYWORDS [in]: Dictionary that is used to set the parameters of this step and also all check steps. """ - TestStepBase.__init__(self) - self.setParams(kw, self.params) + TestStepBase.__init__( self ) + self.setParams( kw, self.params ) checkOption = self.getCheckOption() self.checksteps = [] - if checkOption in ["all", "curvecheck"]: + if checkOption in [ "all", "curvecheck" ]: if curvecheck_params is not None: - self.checksteps.append(curvecheck(curvecheck_params, **kw)) + self.checksteps.append( curvecheck( curvecheck_params, **kw ) ) - if checkOption in ["all", "restartcheck"]: + if checkOption in [ "all", "restartcheck" ]: if restartcheck_params is not None: - self.checksteps.append(restartcheck(restartcheck_params, **kw)) + self.checksteps.append( restartcheck( restartcheck_params, **kw ) ) - def label(self): + def label( self ): return "geos" - def useMPI(self): + def useMPI( self ): return True - def executable(self): + def executable( self ): # python = os.path.join(binDir, "..", "lib", "PYGEOS", "bin", "python3") # pygeosDir = os.path.join(binDir, "..", "..", "src", "pygeos") # return python + " -m mpi4py " + os.path.join( pygeosDir, "reentrantTest.py" ) # return python + " -m mpi4py " + os.path.join( pygeosDir, "test.py" ) # return config.geos_bin_dir - return os.path.join(config.geos_bin_dir, 'geosx') + return os.path.join( config.geos_bin_dir, 'geosx' ) - def update(self, dictionary): - self.setParams(dictionary, self.params) + def update( self, dictionary ): + self.setParams( dictionary, self.params ) - self.requireParam("deck") - self.requireParam("name") - self.requireParam("baseline_dir") - self.requireParam("output_directory") + self.requireParam( "deck" ) + self.requireParam( "name" ) + self.requireParam( "baseline_dir" ) + self.requireParam( "output_directory" ) self.handleCommonParams() - self.setStdout(dictionary) + self.setStdout( dictionary ) # update all the checksteps if self.p.check: for step in self.checksteps: - step.update(dictionary) + step.update( dictionary ) - def insertStep(self, steps): + def insertStep( self, steps ): # the step - steps.append(self) + steps.append( self ) # the post conditions if self.p.check: for step in self.checksteps: - step.insertStep(steps) + step.insertStep( steps ) - def makeArgs(self): + def makeArgs( self ): args = [] if self.p.deck: - args += ["-i", self.p.deck] + args += [ "-i", self.p.deck ] if self.p.restart_file: - args += ["-r", self.p.restart_file] + args += [ "-r", self.p.restart_file ] if self.p.x_partitions: - args += ["-x", self.p.x_partitions] + args += [ "-x", self.p.x_partitions ] if self.p.y_partitions: - args += ["-y", self.p.y_partitions] + args += [ "-y", self.p.y_partitions ] if self.p.z_partitions: - args += ["-z", self.p.z_partitions] + args += [ "-z", self.p.z_partitions ] if self.p.schema_level: - args += ["-s", self.p.schema_level] + args += [ "-s", self.p.schema_level ] if self.p.name: - args += ["-n", self.p.name] + args += [ "-n", self.p.name ] if self.p.output_directory: - args += ["-o", self.p.output_directory] + args += [ "-o", self.p.output_directory ] # if self.p.ngpu == 0: if self.p.ngpu >= 0: - args += ["--suppress-pinned"] + args += [ "--suppress-pinned" ] if self.p.trace_data_migration: - args += ["--trace-data-migration"] + args += [ "--trace-data-migration" ] - return list(map(str, args)) + return list( map( str, args ) ) - def resultPaths(self): + def resultPaths( self ): paths = [] - name = getGeosProblemName(self.p.deck, self.p.name) - paths += [os.path.join(self.p.output_directory, "%s_restart_*") % name] - paths += [os.path.join(self.p.output_directory, "silo*")] - paths += [os.path.join(self.p.output_directory, "%s_bp_*" % name)] + name = getGeosProblemName( self.p.deck, self.p.name ) + paths += [ os.path.join( self.p.output_directory, "%s_restart_*" ) % name ] + paths += [ os.path.join( self.p.output_directory, "silo*" ) ] + paths += [ os.path.join( self.p.output_directory, "%s_bp_*" % name ) ] return paths - def clean(self): - self._clean(self.resultPaths()) + def clean( self ): + self._clean( self.resultPaths() ) ################################################################################ # restartcheck ################################################################################ -class restartcheck(CheckTestStepBase): +class restartcheck( CheckTestStepBase ): """ Class for the restart check test step. """ @@ -521,127 +521,129 @@ class restartcheck(CheckTestStepBase): command = """restartcheck [-r RELATIVE] [-a ABSOLUTE] [-o OUTPUT] [-e EXCLUDE [EXCLUDE ...]] [-w] file_pattern baseline_pattern""" params = TestStepBase.defaultParams + CheckTestStepBase.checkParams + ( - TestStepBase.commonParams["deck"], TestStepBase.commonParams["name"], TestStepBase.commonParams["np"], - TestStepBase.commonParams["allow_rebaseline"], TestStepBase.commonParams["baseline_dir"], - TestStepBase.commonParams["output_directory"], - TestParam("file_pattern", "Regex pattern to match file written out by geos."), - TestParam("baseline_pattern", "Regex pattern to match file to compare against."), - TestParam("rtol", - "Relative tolerance, default is 0.0."), TestParam("atol", "Absolute tolerance, default is 0.0."), - TestParam("exclude", "Regular expressions matching groups to exclude from the check, default is None."), - TestParam("warnings_are_errors", "Treat warnings as errors, default is True."), - TestParam("suppress_output", "Whether to write output to stdout, default is True."), - TestParam("skip_missing", "Whether to skip missing values in target or baseline files, default is False.")) - - def __init__(self, restartcheck_params, **kw): + TestStepBase.commonParams[ "deck" ], TestStepBase.commonParams[ "name" ], TestStepBase.commonParams[ "np" ], + TestStepBase.commonParams[ "allow_rebaseline" ], TestStepBase.commonParams[ "baseline_dir" ], + TestStepBase.commonParams[ "output_directory" ], + TestParam( "file_pattern", "Regex pattern to match file written out by geos." ), + TestParam( "baseline_pattern", "Regex pattern to match file to compare against." ), + TestParam( "rtol", + "Relative tolerance, default is 0.0." ), TestParam( "atol", "Absolute tolerance, default is 0.0." ), + TestParam( "exclude", "Regular expressions matching groups to exclude from the check, default is None." ), + TestParam( "warnings_are_errors", "Treat warnings as errors, default is True." ), + TestParam( "suppress_output", "Whether to write output to stdout, default is True." ), + TestParam( "skip_missing", "Whether to skip missing values in target or baseline files, default is False." ) ) + + def __init__( self, restartcheck_params, **kw ): """ Set parameters with RESTARTCHECK_PARAMS and then with KEYWORDS. """ - CheckTestStepBase.__init__(self) + CheckTestStepBase.__init__( self ) self.p.warnings_are_errors = True if restartcheck_params is not None: - self.setParams(restartcheck_params, self.params) - self.setParams(kw, self.params) + self.setParams( restartcheck_params, self.params ) + self.setParams( kw, self.params ) - def label(self): + def label( self ): return "restartcheck" - def useMPI(self): + def useMPI( self ): return True - def executable(self): + def executable( self ): if self.getTestMode(): return "python -m mpi4py" else: return sys.executable + " -m mpi4py" - def update(self, dictionary): - self.setParams(dictionary, self.params) + def update( self, dictionary ): + self.setParams( dictionary, self.params ) self.handleCommonParams() - self.requireParam("deck") - self.requireParam("baseline_dir") - self.requireParam("output_directory") + self.requireParam( "deck" ) + self.requireParam( "baseline_dir" ) + self.requireParam( "output_directory" ) if self.p.file_pattern is None: - self.p.file_pattern = getGeosProblemName(self.p.deck, self.p.name) + r"_restart_[0-9]+\.root" + self.p.file_pattern = getGeosProblemName( self.p.deck, self.p.name ) + r"_restart_[0-9]+\.root" if self.p.baseline_pattern is None: self.p.baseline_pattern = self.p.file_pattern - self.restart_file_regex = os.path.join(self.p.output_directory, self.p.file_pattern) - self.restart_baseline_regex = os.path.join(self.p.baseline_dir, self.p.baseline_pattern) + self.restart_file_regex = os.path.join( self.p.output_directory, self.p.file_pattern ) + self.restart_baseline_regex = os.path.join( self.p.baseline_dir, self.p.baseline_pattern ) if self.p.allow_rebaseline is None: self.p.allow_rebaseline = True - def insertStep(self, steps): + def insertStep( self, steps ): if config.restartcheck_enabled and self.p.enabled: - steps.append(self) + steps.append( self ) - def makeArgs(self): - cur_dir = os.path.dirname(os.path.realpath(__file__)) - script_location = os.path.join(cur_dir, "helpers", "restart_check.py") - args = [script_location] + def makeArgs( self ): + cur_dir = os.path.dirname( os.path.realpath( __file__ ) ) + script_location = os.path.join( cur_dir, "helpers", "restart_check.py" ) + args = [ script_location ] if self.p.atol is not None: - args += ["-a", self.p.atol] + args += [ "-a", self.p.atol ] if self.p.rtol is not None: - args += ["-r", self.p.rtol] + args += [ "-r", self.p.rtol ] if self.p.warnings_are_errors: - args += ["-w"] + args += [ "-w" ] if self.p.suppress_output: - args += ["-s"] - if (self.p.skip_missing or config.restart_skip_missing): - args += ["-m"] + args += [ "-s" ] + if ( self.p.skip_missing or config.restart_skip_missing ): + args += [ "-m" ] exclude_values = config.restart_exclude_pattern if self.p.exclude is not None: - exclude_values.extend(self.p.exclude) + exclude_values.extend( self.p.exclude ) for v in exclude_values: - args += ["-e", v] + args += [ "-e", v ] - args += [self.restart_file_regex, self.restart_baseline_regex] - return list(map(str, args)) + args += [ self.restart_file_regex, self.restart_baseline_regex ] + return list( map( str, args ) ) - def rebaseline(self): + def rebaseline( self ): if not self.p.allow_rebaseline: - Log("Rebaseline not allowed for restartcheck of %s." % self.p.name) + Log( "Rebaseline not allowed for restartcheck of %s." % self.p.name ) return - root_file_path = findMaxMatchingFile(self.restart_file_regex) + root_file_path = findMaxMatchingFile( self.restart_file_regex ) if root_file_path is None: - raise IOError("File not found matching the pattern %s in directory %s." % - (self.restart_file_regex, os.getcwd())) + raise IOError( "File not found matching the pattern %s in directory %s." % + ( self.restart_file_regex, os.getcwd() ) ) - baseline_dir = os.path.dirname(self.restart_baseline_regex) - root_baseline_path = findMaxMatchingFile(self.restart_baseline_regex) + baseline_dir = os.path.dirname( self.restart_baseline_regex ) + root_baseline_path = findMaxMatchingFile( self.restart_baseline_regex ) if root_baseline_path is not None: # Delete the baseline root file. - os.remove(root_baseline_path) + os.remove( root_baseline_path ) # Delete the directory holding the baseline data files. - data_dir_path = os.path.splitext(root_baseline_path)[0] - shutil.rmtree(data_dir_path) + data_dir_path = os.path.splitext( root_baseline_path )[ 0 ] + shutil.rmtree( data_dir_path ) else: - os.makedirs(baseline_dir, exist_ok=True) + os.makedirs( baseline_dir, exist_ok=True ) # Copy the root file into the baseline directory. - shutil.copy2(root_file_path, os.path.join(baseline_dir, os.path.basename(root_file_path))) + shutil.copy2( root_file_path, os.path.join( baseline_dir, os.path.basename( root_file_path ) ) ) # Copy the directory holding the data files into the baseline directory. - data_dir_path = os.path.splitext(root_file_path)[0] - shutil.copytree(data_dir_path, os.path.join(baseline_dir, os.path.basename(data_dir_path))) + data_dir_path = os.path.splitext( root_file_path )[ 0 ] + shutil.copytree( data_dir_path, os.path.join( baseline_dir, os.path.basename( data_dir_path ) ) ) - def resultPaths(self): - return [os.path.join(self.p.output_directory, "%s.restartcheck" % os.path.splitext(self.p.file_pattern)[0])] + def resultPaths( self ): + return [ + os.path.join( self.p.output_directory, "%s.restartcheck" % os.path.splitext( self.p.file_pattern )[ 0 ] ) + ] - def clean(self): - self._clean(self.resultPaths()) + def clean( self ): + self._clean( self.resultPaths() ) ################################################################################ # curvecheck ################################################################################ -class curvecheck(CheckTestStepBase): +class curvecheck( CheckTestStepBase ): """ Class for the curve check test step. """ @@ -651,186 +653,186 @@ class curvecheck(CheckTestStepBase): command = """curve_check.py [-h] [-c CURVE [CURVE ...]] [-t TOLERANCE] [-w] [-o OUTPUT] [-n N_COLUMN] [-u {milliseconds,seconds,minutes,hours,days,years}] filename baseline""" params = TestStepBase.defaultParams + CheckTestStepBase.checkParams + ( - TestStepBase.commonParams["deck"], TestStepBase.commonParams["name"], TestStepBase.commonParams["np"], - TestStepBase.commonParams["allow_rebaseline"], TestStepBase.commonParams["baseline_dir"], - TestStepBase.commonParams["output_directory"], - TestParam("filename", "Name of the target curve file written by GEOS."), - TestParam("curves", "A list of parameter, setname value pairs."), + TestStepBase.commonParams[ "deck" ], TestStepBase.commonParams[ "name" ], TestStepBase.commonParams[ "np" ], + TestStepBase.commonParams[ "allow_rebaseline" ], TestStepBase.commonParams[ "baseline_dir" ], + TestStepBase.commonParams[ "output_directory" ], + TestParam( "filename", "Name of the target curve file written by GEOS." ), + TestParam( "curves", "A list of parameter, setname value pairs." ), TestParam( "tolerance", "Curve check tolerance (||x-y||/N), can be specified as a single value or a list of floats corresponding to the curves." - ), TestParam("warnings_are_errors", "Treat warnings as errors, default is True."), - TestParam("script_instructions", "A list of (path, function, value, setname) entries"), - TestParam("time_units", "Time units to use for plots.")) + ), TestParam( "warnings_are_errors", "Treat warnings as errors, default is True." ), + TestParam( "script_instructions", "A list of (path, function, value, setname) entries" ), + TestParam( "time_units", "Time units to use for plots." ) ) - def __init__(self, curvecheck_params, **kw): + def __init__( self, curvecheck_params, **kw ): """ Set parameters with CURVECHECK_PARAMS and then with KEYWORDS. """ - CheckTestStepBase.__init__(self) + CheckTestStepBase.__init__( self ) self.p.warnings_are_errors = True if curvecheck_params is not None: c = curvecheck_params.copy() - Nc = len(c.get('curves', [])) + Nc = len( c.get( 'curves', [] ) ) # Note: ats seems to store list/tuple parameters incorrectly # Convert these to strings - for k in ['curves', 'script_instructions']: + for k in [ 'curves', 'script_instructions' ]: if k in c: - if isinstance(c[k], (list, tuple)): - c[k] = ';'.join([','.join(c) for c in c[k]]) + if isinstance( c[ k ], ( list, tuple ) ): + c[ k ] = ';'.join( [ ','.join( c ) for c in c[ k ] ] ) # Check whether tolerance was specified as a single float, list # and then convert into a comma-delimited string - tol = c.get('tolerance', 0.0) - if isinstance(tol, (float, int)): - tol = [tol] * Nc - c['tolerance'] = ','.join([str(x) for x in tol]) + tol = c.get( 'tolerance', 0.0 ) + if isinstance( tol, ( float, int ) ): + tol = [ tol ] * Nc + c[ 'tolerance' ] = ','.join( [ str( x ) for x in tol ] ) - self.setParams(c, self.params) - self.setParams(kw, self.params) + self.setParams( c, self.params ) + self.setParams( kw, self.params ) - def label(self): + def label( self ): return "curvecheck" - def useMPI(self): + def useMPI( self ): return True - def executable(self): + def executable( self ): if self.getTestMode(): return "python" else: return sys.executable - def update(self, dictionary): - self.setParams(dictionary, self.params) + def update( self, dictionary ): + self.setParams( dictionary, self.params ) self.handleCommonParams() - self.requireParam("deck") - self.requireParam("baseline_dir") - self.requireParam("output_directory") + self.requireParam( "deck" ) + self.requireParam( "baseline_dir" ) + self.requireParam( "output_directory" ) - self.baseline_file = os.path.join(self.p.baseline_dir, self.p.filename) - self.target_file = os.path.join(self.p.output_directory, self.p.filename) - self.figure_root = os.path.join(self.p.output_directory, 'curve_check') + self.baseline_file = os.path.join( self.p.baseline_dir, self.p.filename ) + self.target_file = os.path.join( self.p.output_directory, self.p.filename ) + self.figure_root = os.path.join( self.p.output_directory, 'curve_check' ) if self.p.allow_rebaseline is None: self.p.allow_rebaseline = True - def insertStep(self, steps): + def insertStep( self, steps ): if config.restartcheck_enabled and self.p.enabled: - steps.append(self) + steps.append( self ) - def makeArgs(self): - cur_dir = os.path.dirname(os.path.realpath(__file__)) - script_location = os.path.join(cur_dir, "helpers", "curve_check.py") - args = [script_location] + def makeArgs( self ): + cur_dir = os.path.dirname( os.path.realpath( __file__ ) ) + script_location = os.path.join( cur_dir, "helpers", "curve_check.py" ) + args = [ script_location ] if self.p.curves is not None: - for c in self.p.curves.split(';'): - args += ["-c"] - args += c.split(',') + for c in self.p.curves.split( ';' ): + args += [ "-c" ] + args += c.split( ',' ) if self.p.tolerance is not None: - for t in self.p.tolerance.split(','): - args += ["-t", t] + for t in self.p.tolerance.split( ',' ): + args += [ "-t", t ] if self.p.time_units is not None: - args += ["-u", self.p.time_units] + args += [ "-u", self.p.time_units ] if self.p.script_instructions is not None: - for c in self.p.script_instructions.split(';'): - args += ["-s"] - args += c.split(',') + for c in self.p.script_instructions.split( ';' ): + args += [ "-s" ] + args += c.split( ',' ) if self.p.warnings_are_errors: - args += ["-w"] + args += [ "-w" ] - args += ['-o', self.figure_root] - args += [self.target_file, self.baseline_file] - return list(map(str, args)) + args += [ '-o', self.figure_root ] + args += [ self.target_file, self.baseline_file ] + return list( map( str, args ) ) - def rebaseline(self): + def rebaseline( self ): if not self.p.allow_rebaseline: - Log("Rebaseline not allowed for curvecheck of %s." % self.p.name) + Log( "Rebaseline not allowed for curvecheck of %s." % self.p.name ) return - baseline_dir = os.path.split(self.baseline_file)[0] - os.makedirs(baseline_dir, exist_ok=True) - shutil.copyfile(self.target_file, self.baseline_file) + baseline_dir = os.path.split( self.baseline_file )[ 0 ] + os.makedirs( baseline_dir, exist_ok=True ) + shutil.copyfile( self.target_file, self.baseline_file ) - def resultPaths(self): - figure_pattern = os.path.join(self.figure_root, '*.png') - figure_list = sorted(glob.glob(figure_pattern)) - return [self.target_file] + figure_list + def resultPaths( self ): + figure_pattern = os.path.join( self.figure_root, '*.png' ) + figure_list = sorted( glob.glob( figure_pattern ) ) + return [ self.target_file ] + figure_list - def clean(self): - self._clean(self.resultPaths()) + def clean( self ): + self._clean( self.resultPaths() ) -def infoTestStepParams(params, maxwidth=None): +def infoTestStepParams( params, maxwidth=None ): if maxwidth is None: - maxwidth = max(10, max([len(p.name) for p in params])) + maxwidth = max( 10, max( [ len( p.name ) for p in params ] ) ) for p in params: paramdoc = p.doc if p.default is not None: - paramdoc += " (default = %s)" % (p.default) - paramdoc = textwrap.wrap(paramdoc, width=100 - maxwidth) - logger.debug(" %*s:" % (maxwidth, p.name), paramdoc[0].strip()) - for line in paramdoc[1:]: - logger.debug(" %*s %s" % (maxwidth, "", line.strip())) + paramdoc += " (default = %s)" % ( p.default ) + paramdoc = textwrap.wrap( paramdoc, width=100 - maxwidth ) + logger.debug( " %*s:" % ( maxwidth, p.name ), paramdoc[ 0 ].strip() ) + for line in paramdoc[ 1: ]: + logger.debug( " %*s %s" % ( maxwidth, "", line.strip() ) ) -def infoTestStep(stepname): - topic = common_utilities.InfoTopic(stepname) +def infoTestStep( stepname ): + topic = common_utilities.InfoTopic( stepname ) topic.startBanner() - logger.debug(f"TestStep: {stepname}") - stepclass = globals()[stepname] - if not hasattr(stepclass, "doc"): + logger.debug( f"TestStep: {stepname}" ) + stepclass = globals()[ stepname ] + if not hasattr( stepclass, "doc" ): return - logger.debug("Description:") - doc = textwrap.dedent(stepclass.doc) - doc = textwrap.wrap(doc, width=100) + logger.debug( "Description:" ) + doc = textwrap.dedent( stepclass.doc ) + doc = textwrap.wrap( doc, width=100 ) for line in doc: - logger.debug(" ", line.strip()) + logger.debug( " ", line.strip() ) - logger.debug("Command:") - doc = textwrap.dedent(stepclass.command) - doc = textwrap.wrap(doc, width=100) - logger.debug(f" {doc[0].strip()}") - for line in doc[1:]: - logger.debug(f'\\\n {" " * len(stepname)} {line}') + logger.debug( "Command:" ) + doc = textwrap.dedent( stepclass.command ) + doc = textwrap.wrap( doc, width=100 ) + logger.debug( f" {doc[0].strip()}" ) + for line in doc[ 1: ]: + logger.debug( f'\\\n {" " * len(stepname)} {line}' ) # compute max param width: - allparams = [p.name for p in stepclass.params] - if hasattr(stepclass, "checkstepnames"): + allparams = [ p.name for p in stepclass.params ] + if hasattr( stepclass, "checkstepnames" ): for checkstep in stepclass.checkstepnames: - checkclass = globals()[checkstep] - if not hasattr(checkclass, "doc"): + checkclass = globals()[ checkstep ] + if not hasattr( checkclass, "doc" ): continue - allparams.extend([p.name for p in checkclass.params]) - maxwidth = max(10, max([len(p) for p in allparams])) + allparams.extend( [ p.name for p in checkclass.params ] ) + maxwidth = max( 10, max( [ len( p ) for p in allparams ] ) ) - logger.debug("Parameters:") - infoTestStepParams(stepclass.params, maxwidth) + logger.debug( "Parameters:" ) + infoTestStepParams( stepclass.params, maxwidth ) - paramset = set([p.name for p in stepclass.params]) + paramset = set( [ p.name for p in stepclass.params ] ) - if hasattr(stepclass, "checkstepnames"): + if hasattr( stepclass, "checkstepnames" ): for checkstep in stepclass.checkstepnames: - logger.debug(f"CheckStep: {checkstep}") + logger.debug( f"CheckStep: {checkstep}" ) checkparams = [] - checkclass = globals()[checkstep] - if not hasattr(checkclass, "doc"): + checkclass = globals()[ checkstep ] + if not hasattr( checkclass, "doc" ): continue for p in checkclass.params: if p.name not in paramset: - checkparams.append(p) + checkparams.append( p ) - infoTestStepParams(checkparams, maxwidth) + infoTestStepParams( checkparams, maxwidth ) topic.endBanner() -def infoTestSteps(*args): +def infoTestSteps( *args ): """This function is used to print documentation about the teststeps to stdout""" # get the list of step classes @@ -838,39 +840,39 @@ def infoTestSteps(*args): checkstepnames = [] for k, v in globals().items(): - if not isinstance(v, type): + if not isinstance( v, type ): continue - if v in (CheckTestStepBase, TestStepBase): + if v in ( CheckTestStepBase, TestStepBase ): continue try: - if issubclass(v, CheckTestStepBase): - checkstepnames.append(k) - elif issubclass(v, TestStepBase): - steps.append(k) + if issubclass( v, CheckTestStepBase ): + checkstepnames.append( k ) + elif issubclass( v, TestStepBase ): + steps.append( k ) except TypeError as e: - logger.debug(e) + logger.debug( e ) - steps = sorted(steps) - checkstepnames = sorted(checkstepnames) + steps = sorted( steps ) + checkstepnames = sorted( checkstepnames ) steps = steps + checkstepnames def all(): for s in steps: - infoTestStep(s) + infoTestStep( s ) - topic = common_utilities.InfoTopic("teststep") - topic.addTopic("all", "full info on all the teststeps", all) + topic = common_utilities.InfoTopic( "teststep" ) + topic.addTopic( "all", "full info on all the teststeps", all ) for s in steps: - stepclass = globals()[s] - doc = getattr(stepclass, "doc", None) - topic.addTopic(s, textwrap.dedent(doc).strip(), lambda ss=s: infoTestStep(ss)) + stepclass = globals()[ s ] + doc = getattr( stepclass, "doc", None ) + topic.addTopic( s, textwrap.dedent( doc ).strip(), lambda ss=s: infoTestStep( ss ) ) - topic.process(args) + topic.process( args ) # Register test step definitions -ats.manager.define(geos=geos) -ats.manager.define(restartcheck=restartcheck) -ats.manager.define(config=config) +ats.manager.define( geos=geos ) +ats.manager.define( restartcheck=restartcheck ) +ats.manager.define( config=config ) diff --git a/geos_ats_package/geos_ats/user_utilities.py b/geos_ats_package/geos_ats/user_utilities.py index 775e3a1..4fb304a 100644 --- a/geos_ats_package/geos_ats/user_utilities.py +++ b/geos_ats_package/geos_ats/user_utilities.py @@ -1,4 +1,4 @@ -import ats # type: ignore[import] +import ats # type: ignore[import] import os @@ -6,32 +6,32 @@ # Common functions available to the tests # (each must be registered via ats.manager.define() ################################################################################ -def which(program): +def which( program ): - def is_exe(fpath): - return os.path.isfile(fpath) and os.access(fpath, os.X_OK) + def is_exe( fpath ): + return os.path.isfile( fpath ) and os.access( fpath, os.X_OK ) - fpath, fname = os.path.split(program) + fpath, fname = os.path.split( program ) if fpath: - if is_exe(program): + if is_exe( program ): return program else: - for path in os.environ["PATH"].split(os.pathsep): - path = path.strip('"') - exe_file = os.path.join(path, program) - if is_exe(exe_file): + for path in os.environ[ "PATH" ].split( os.pathsep ): + path = path.strip( '"' ) + exe_file = os.path.join( path, program ) + if is_exe( exe_file ): return exe_file return None -def getEnviron(name): +def getEnviron( name ): import os try: - return os.environ[name] + return os.environ[ name ] except: return None -ats.manager.define(which=which) -ats.manager.define(getEnviron=getEnviron) +ats.manager.define( which=which ) +ats.manager.define( getEnviron=getEnviron ) diff --git a/geosx_mesh_doctor/checks/check_fractures.py b/geosx_mesh_doctor/checks/check_fractures.py index b2c241b..9f14aed 100644 --- a/geosx_mesh_doctor/checks/check_fractures.py +++ b/geosx_mesh_doctor/checks/check_fractures.py @@ -18,20 +18,16 @@ vtkCell, ) from vtkmodules.vtkCommonCore import ( - vtkPoints, -) + vtkPoints, ) from vtkmodules.vtkIOXML import ( - vtkXMLMultiBlockDataReader, -) + vtkXMLMultiBlockDataReader, ) from vtkmodules.util.numpy_support import ( - vtk_to_numpy, -) + vtk_to_numpy, ) from vtk_utils import ( - vtk_iter, -) + vtk_iter, ) -@dataclass(frozen=True) +@dataclass( frozen=True ) class Options: tolerance: float matrix_name: str @@ -39,140 +35,139 @@ class Options: collocated_nodes_field_name: str -@dataclass(frozen=True) +@dataclass( frozen=True ) class Result: # First index is the local index of the fracture mesh. # Second is the local index of the matrix mesh. # Third is the global index in the matrix mesh. - errors: Sequence[tuple[int, int, int]] + errors: Sequence[ tuple[ int, int, int ] ] -def __read_multiblock(vtk_input_file: str, matrix_name: str, fracture_name: str) -> Tuple[vtkUnstructuredGrid, vtkUnstructuredGrid]: +def __read_multiblock( vtk_input_file: str, matrix_name: str, + fracture_name: str ) -> Tuple[ vtkUnstructuredGrid, vtkUnstructuredGrid ]: reader = vtkXMLMultiBlockDataReader() - reader.SetFileName(vtk_input_file) + reader.SetFileName( vtk_input_file ) reader.Update() multi_block = reader.GetOutput() - for b in range(multi_block.GetNumberOfBlocks()): - block_name: str = multi_block.GetMetaData(b).Get(multi_block.NAME()) + for b in range( multi_block.GetNumberOfBlocks() ): + block_name: str = multi_block.GetMetaData( b ).Get( multi_block.NAME() ) if block_name == matrix_name: - matrix: vtkUnstructuredGrid = multi_block.GetBlock(b) + matrix: vtkUnstructuredGrid = multi_block.GetBlock( b ) if block_name == fracture_name: - fracture: vtkUnstructuredGrid = multi_block.GetBlock(b) + fracture: vtkUnstructuredGrid = multi_block.GetBlock( b ) assert matrix and fracture return matrix, fracture -def format_collocated_nodes(fracture_mesh: vtkUnstructuredGrid) -> Sequence[Iterable[int]]: +def format_collocated_nodes( fracture_mesh: vtkUnstructuredGrid ) -> Sequence[ Iterable[ int ] ]: """ Extract the collocated nodes information from the mesh and formats it in a python way. :param fracture_mesh: The mesh of the fracture (with 2d cells). :return: An iterable over all the buckets of collocated nodes. """ - collocated_nodes: numpy.ndarray = vtk_to_numpy(fracture_mesh.GetPointData().GetArray("collocated_nodes")) - if len(collocated_nodes.shape) == 1: - collocated_nodes: numpy.ndarray = collocated_nodes.reshape((collocated_nodes.shape[0], 1)) - generator = (tuple(sorted(bucket[bucket > -1])) for bucket in collocated_nodes) - return tuple(generator) + collocated_nodes: numpy.ndarray = vtk_to_numpy( fracture_mesh.GetPointData().GetArray( "collocated_nodes" ) ) + if len( collocated_nodes.shape ) == 1: + collocated_nodes: numpy.ndarray = collocated_nodes.reshape( ( collocated_nodes.shape[ 0 ], 1 ) ) + generator = ( tuple( sorted( bucket[ bucket > -1 ] ) ) for bucket in collocated_nodes ) + return tuple( generator ) -def __check_collocated_nodes_positions(matrix_points: Sequence[Tuple[float, float, float]], - fracture_points: Sequence[Tuple[float, float, float]], - g2l: Sequence[int], - collocated_nodes: Iterable[Iterable[int]]) -> Collection[Tuple[int, Iterable[int], Iterable[Tuple[float, float, float]]]]: +def __check_collocated_nodes_positions( + matrix_points: Sequence[ Tuple[ float, float, float ] ], fracture_points: Sequence[ Tuple[ float, float, float ] ], + g2l: Sequence[ int ], collocated_nodes: Iterable[ Iterable[ int ] ] +) -> Collection[ Tuple[ int, Iterable[ int ], Iterable[ Tuple[ float, float, float ] ] ] ]: issues = [] - for li, bucket in enumerate(collocated_nodes): - matrix_nodes = (fracture_points[li], ) + tuple(map(lambda gi: matrix_points[g2l[gi]], bucket)) - m = numpy.array(matrix_nodes) - rank: int = numpy.linalg.matrix_rank(m) + for li, bucket in enumerate( collocated_nodes ): + matrix_nodes = ( fracture_points[ li ], ) + tuple( map( lambda gi: matrix_points[ g2l[ gi ] ], bucket ) ) + m = numpy.array( matrix_nodes ) + rank: int = numpy.linalg.matrix_rank( m ) if rank > 1: - issues.append((li, bucket, tuple(map(lambda gi: matrix_points[g2l[gi]], bucket)))) + issues.append( ( li, bucket, tuple( map( lambda gi: matrix_points[ g2l[ gi ] ], bucket ) ) ) ) return issues -def my_iter(ccc): - car, cdr = ccc[0], ccc[1:] +def my_iter( ccc ): + car, cdr = ccc[ 0 ], ccc[ 1: ] for i in car: if cdr: - for j in my_iter(cdr): + for j in my_iter( cdr ): yield i, *j else: - yield (i, ) + yield ( i, ) -def __check_neighbors(matrix: vtkUnstructuredGrid, - fracture: vtkUnstructuredGrid, - g2l: Sequence[int], - collocated_nodes: Sequence[Iterable[int]]): - fracture_nodes: Set[int] = set() +def __check_neighbors( matrix: vtkUnstructuredGrid, fracture: vtkUnstructuredGrid, g2l: Sequence[ int ], + collocated_nodes: Sequence[ Iterable[ int ] ] ): + fracture_nodes: Set[ int ] = set() for bucket in collocated_nodes: for gi in bucket: - fracture_nodes.add(g2l[gi]) + fracture_nodes.add( g2l[ gi ] ) # For each face of each cell, # if all the points of the face are "made" of collocated nodes, # then this is a fracture face. - fracture_faces: Set[FrozenSet[int]] = set() - for c in range(matrix.GetNumberOfCells()): - cell: vtkCell = matrix.GetCell(c) - for f in range(cell.GetNumberOfFaces()): - face: vtkCell = cell.GetFace(f) - point_ids = frozenset(vtk_iter(face.GetPointIds())) + fracture_faces: Set[ FrozenSet[ int ] ] = set() + for c in range( matrix.GetNumberOfCells() ): + cell: vtkCell = matrix.GetCell( c ) + for f in range( cell.GetNumberOfFaces() ): + face: vtkCell = cell.GetFace( f ) + point_ids = frozenset( vtk_iter( face.GetPointIds() ) ) if point_ids <= fracture_nodes: - fracture_faces.add(point_ids) + fracture_faces.add( point_ids ) # Finding the cells - for c in tqdm(range(fracture.GetNumberOfCells()), desc="Finding neighbor cell pairs"): - cell: vtkCell = fracture.GetCell(c) - cns: Set[FrozenSet[int]] = set() # subset of collocated_nodes - point_ids = frozenset(vtk_iter(cell.GetPointIds())) + for c in tqdm( range( fracture.GetNumberOfCells() ), desc="Finding neighbor cell pairs" ): + cell: vtkCell = fracture.GetCell( c ) + cns: Set[ FrozenSet[ int ] ] = set() # subset of collocated_nodes + point_ids = frozenset( vtk_iter( cell.GetPointIds() ) ) for point_id in point_ids: - bucket = collocated_nodes[point_id] - local_bucket = frozenset(map(g2l.__getitem__, bucket)) - cns.add(local_bucket) + bucket = collocated_nodes[ point_id ] + local_bucket = frozenset( map( g2l.__getitem__, bucket ) ) + cns.add( local_bucket ) found = 0 - tmp = tuple(map(tuple, cns)) - for node_combinations in my_iter(tmp): - f = frozenset(node_combinations) + tmp = tuple( map( tuple, cns ) ) + for node_combinations in my_iter( tmp ): + f = frozenset( node_combinations ) if f in fracture_faces: found += 1 if found != 2: - logging.warning(f"Something went wrong since we should have found 2 fractures faces (we found {found}) for collocated nodes {cns}.") + logging.warning( + f"Something went wrong since we should have found 2 fractures faces (we found {found}) for collocated nodes {cns}." + ) -def __check(vtk_input_file: str, options: Options) -> Result: - matrix, fracture = __read_multiblock(vtk_input_file, options.matrix_name, options.fracture_name) +def __check( vtk_input_file: str, options: Options ) -> Result: + matrix, fracture = __read_multiblock( vtk_input_file, options.matrix_name, options.fracture_name ) matrix_points: vtkPoints = matrix.GetPoints() fracture_points: vtkPoints = fracture.GetPoints() - collocated_nodes: Sequence[Iterable[int]] = format_collocated_nodes(fracture) + collocated_nodes: Sequence[ Iterable[ int ] ] = format_collocated_nodes( fracture ) assert matrix.GetPointData().GetGlobalIds() and matrix.GetCellData().GetGlobalIds() and \ fracture.GetPointData().GetGlobalIds() and fracture.GetCellData().GetGlobalIds() - point_ids = vtk_to_numpy(matrix.GetPointData().GetGlobalIds()) - g2l = numpy.ones(len(point_ids), dtype=int) * -1 - for loc, glo in enumerate(point_ids): - g2l[glo] = loc + point_ids = vtk_to_numpy( matrix.GetPointData().GetGlobalIds() ) + g2l = numpy.ones( len( point_ids ), dtype=int ) * -1 + for loc, glo in enumerate( point_ids ): + g2l[ glo ] = loc g2l.flags.writeable = False - issues = __check_collocated_nodes_positions(vtk_to_numpy(matrix.GetPoints().GetData()), - vtk_to_numpy(fracture.GetPoints().GetData()), - g2l, - collocated_nodes) - assert len(issues) == 0 + issues = __check_collocated_nodes_positions( vtk_to_numpy( matrix.GetPoints().GetData() ), + vtk_to_numpy( fracture.GetPoints().GetData() ), g2l, collocated_nodes ) + assert len( issues ) == 0 - __check_neighbors(matrix, fracture, g2l, collocated_nodes) + __check_neighbors( matrix, fracture, g2l, collocated_nodes ) errors = [] - for i, duplicates in enumerate(collocated_nodes): - for duplicate in filter(lambda i: i > -1, duplicates): - p0 = matrix_points.GetPoint(g2l[duplicate]) - p1 = fracture_points.GetPoint(i) - if numpy.linalg.norm(numpy.array(p1) - numpy.array(p0)) > options.tolerance: - errors.append((i, g2l[duplicate], duplicate)) - return Result(errors=errors) + for i, duplicates in enumerate( collocated_nodes ): + for duplicate in filter( lambda i: i > -1, duplicates ): + p0 = matrix_points.GetPoint( g2l[ duplicate ] ) + p1 = fracture_points.GetPoint( i ) + if numpy.linalg.norm( numpy.array( p1 ) - numpy.array( p0 ) ) > options.tolerance: + errors.append( ( i, g2l[ duplicate ], duplicate ) ) + return Result( errors=errors ) -def check(vtk_input_file: str, options: Options) -> Result: +def check( vtk_input_file: str, options: Options ) -> Result: try: - return __check(vtk_input_file, options) + return __check( vtk_input_file, options ) except BaseException as e: - logging.error(e) - return Result(errors=()) + logging.error( e ) + return Result( errors=() ) diff --git a/geosx_mesh_doctor/checks/collocated_nodes.py b/geosx_mesh_doctor/checks/collocated_nodes.py index 7a5273e..d64cd6c 100644 --- a/geosx_mesh_doctor/checks/collocated_nodes.py +++ b/geosx_mesh_doctor/checks/collocated_nodes.py @@ -17,32 +17,32 @@ from . import vtk_utils -@dataclass(frozen=True) +@dataclass( frozen=True ) class Options: tolerance: float -@dataclass(frozen=True) +@dataclass( frozen=True ) class Result: - nodes_buckets: Iterable[Iterable[int]] # Each bucket contains the duplicated node indices. - wrong_support_elements: Collection[int] # Element indices with support node indices appearing more than once. + nodes_buckets: Iterable[ Iterable[ int ] ] # Each bucket contains the duplicated node indices. + wrong_support_elements: Collection[ int ] # Element indices with support node indices appearing more than once. -def __check(mesh, options: Options) -> Result: +def __check( mesh, options: Options ) -> Result: points = mesh.GetPoints() locator = vtkIncrementalOctreePointLocator() - locator.SetTolerance(options.tolerance) + locator.SetTolerance( options.tolerance ) output = vtkPoints() - locator.InitPointInsertion(output, points.GetBounds()) + locator.InitPointInsertion( output, points.GetBounds() ) # original ids to/from filtered ids. - filtered_to_original = numpy.ones(points.GetNumberOfPoints(), dtype=int) * -1 + filtered_to_original = numpy.ones( points.GetNumberOfPoints(), dtype=int ) * -1 - rejected_points = defaultdict(list) - point_id = reference(0) - for i in range(points.GetNumberOfPoints()): - is_inserted = locator.InsertUniquePoint(points.GetPoint(i), point_id) + rejected_points = defaultdict( list ) + point_id = reference( 0 ) + for i in range( points.GetNumberOfPoints() ): + is_inserted = locator.InsertUniquePoint( points.GetPoint( i ), point_id ) if not is_inserted: # If it's not inserted, `point_id` contains the node that was already at that location. # But in that case, `point_id` is the new numbering in the destination points array. @@ -50,29 +50,28 @@ def __check(mesh, options: Options) -> Result: logging.debug( f"Point {i} at {points.GetPoint(i)} has been rejected, point {filtered_to_original[point_id.get()]} is already inserted." ) - rejected_points[point_id.get()].append(i) + rejected_points[ point_id.get() ].append( i ) else: # If it's inserted, `point_id` contains the new index in the destination array. # We store this information to be able to connect the source and destination arrays. # original_to_filtered[i] = point_id.get() - filtered_to_original[point_id.get()] = i + filtered_to_original[ point_id.get() ] = i tmp = [] for n, ns in rejected_points.items(): - tmp.append((n, *ns)) + tmp.append( ( n, *ns ) ) # Checking that the support node indices appear only once per element. wrong_support_elements = [] - for c in range(mesh.GetNumberOfCells()): - cell = mesh.GetCell(c) + for c in range( mesh.GetNumberOfCells() ): + cell = mesh.GetCell( c ) num_points_per_cell = cell.GetNumberOfPoints() - if len({cell.GetPointId(i) for i in range(num_points_per_cell)}) != num_points_per_cell: - wrong_support_elements.append(c) + if len( { cell.GetPointId( i ) for i in range( num_points_per_cell ) } ) != num_points_per_cell: + wrong_support_elements.append( c ) - return Result(nodes_buckets=tmp, - wrong_support_elements=wrong_support_elements) + return Result( nodes_buckets=tmp, wrong_support_elements=wrong_support_elements ) -def check(vtk_input_file: str, options: Options) -> Result: - mesh = vtk_utils.read_mesh(vtk_input_file) - return __check(mesh, options) +def check( vtk_input_file: str, options: Options ) -> Result: + mesh = vtk_utils.read_mesh( vtk_input_file ) + return __check( mesh, options ) diff --git a/geosx_mesh_doctor/checks/element_volumes.py b/geosx_mesh_doctor/checks/element_volumes.py index 4dfd917..4d45453 100644 --- a/geosx_mesh_doctor/checks/element_volumes.py +++ b/geosx_mesh_doctor/checks/element_volumes.py @@ -14,24 +14,22 @@ vtkMeshQuality, ) from vtkmodules.util.numpy_support import ( - vtk_to_numpy, -) - + vtk_to_numpy, ) from . import vtk_utils -@dataclass(frozen=True) +@dataclass( frozen=True ) class Options: min_volume: float -@dataclass(frozen=True) +@dataclass( frozen=True ) class Result: - element_volumes: List[Tuple[int, float]] + element_volumes: List[ Tuple[ int, float ] ] -def __check(mesh, options: Options) -> Result: +def __check( mesh, options: Options ) -> Result: cs = vtkCellSizeFilter() cs.ComputeAreaOff() @@ -39,44 +37,46 @@ def __check(mesh, options: Options) -> Result: cs.ComputeSumOff() cs.ComputeVertexCountOff() cs.ComputeVolumeOn() - volume_array_name = "__MESH_DOCTOR_VOLUME-" + str(uuid.uuid4()) # Making the name unique - cs.SetVolumeArrayName(volume_array_name) + volume_array_name = "__MESH_DOCTOR_VOLUME-" + str( uuid.uuid4() ) # Making the name unique + cs.SetVolumeArrayName( volume_array_name ) - cs.SetInputData(mesh) + cs.SetInputData( mesh ) cs.Update() mq = vtkMeshQuality() - SUPPORTED_TYPES = [VTK_HEXAHEDRON, VTK_TETRA] + SUPPORTED_TYPES = [ VTK_HEXAHEDRON, VTK_TETRA ] mq.SetTetQualityMeasureToVolume() mq.SetHexQualityMeasureToVolume() - if hasattr(mq, "SetPyramidQualityMeasureToVolume"): # This feature is quite recent + if hasattr( mq, "SetPyramidQualityMeasureToVolume" ): # This feature is quite recent mq.SetPyramidQualityMeasureToVolume() - SUPPORTED_TYPES.append(VTK_PYRAMID) + SUPPORTED_TYPES.append( VTK_PYRAMID ) mq.SetWedgeQualityMeasureToVolume() - SUPPORTED_TYPES.append(VTK_WEDGE) + SUPPORTED_TYPES.append( VTK_WEDGE ) else: - logging.warning("Your \"pyvtk\" version does not bring pyramid nor wedge support with vtkMeshQuality. Using the fallback solution.") + logging.warning( + "Your \"pyvtk\" version does not bring pyramid nor wedge support with vtkMeshQuality. Using the fallback solution." + ) - mq.SetInputData(mesh) + mq.SetInputData( mesh ) mq.Update() - volume = cs.GetOutput().GetCellData().GetArray(volume_array_name) - quality = mq.GetOutput().GetCellData().GetArray("Quality") # Name is imposed by vtk. + volume = cs.GetOutput().GetCellData().GetArray( volume_array_name ) + quality = mq.GetOutput().GetCellData().GetArray( "Quality" ) # Name is imposed by vtk. assert volume is not None assert quality is not None - volume = vtk_to_numpy(volume) - quality = vtk_to_numpy(quality) - small_volumes: List[Tuple[int, float]] = [] - for i, pack in enumerate(zip(volume, quality)): + volume = vtk_to_numpy( volume ) + quality = vtk_to_numpy( quality ) + small_volumes: List[ Tuple[ int, float ] ] = [] + for i, pack in enumerate( zip( volume, quality ) ): v, q = pack - vol = q if mesh.GetCellType(i) in SUPPORTED_TYPES else v + vol = q if mesh.GetCellType( i ) in SUPPORTED_TYPES else v if vol < options.min_volume: - small_volumes.append((i, vol)) - return Result(element_volumes=small_volumes) + small_volumes.append( ( i, vol ) ) + return Result( element_volumes=small_volumes ) -def check(vtk_input_file: str, options: Options) -> Result: - mesh = vtk_utils.read_mesh(vtk_input_file) - return __check(mesh, options) +def check( vtk_input_file: str, options: Options ) -> Result: + mesh = vtk_utils.read_mesh( vtk_input_file ) + return __check( mesh, options ) diff --git a/geosx_mesh_doctor/checks/fix_elements_orderings.py b/geosx_mesh_doctor/checks/fix_elements_orderings.py index 61dd034..eb603b4 100644 --- a/geosx_mesh_doctor/checks/fix_elements_orderings.py +++ b/geosx_mesh_doctor/checks/fix_elements_orderings.py @@ -8,8 +8,7 @@ ) from vtkmodules.vtkCommonCore import ( - vtkIdList, -) + vtkIdList, ) from . import vtk_utils from .vtk_utils import ( @@ -18,49 +17,49 @@ ) -@dataclass(frozen=True) +@dataclass( frozen=True ) class Options: vtk_output: VtkOutput - cell_type_to_ordering: Dict[int, List[int]] + cell_type_to_ordering: Dict[ int, List[ int ] ] -@dataclass(frozen=True) +@dataclass( frozen=True ) class Result: output: str - unchanged_cell_types: FrozenSet[int] + unchanged_cell_types: FrozenSet[ int ] -def __check(mesh, options: Options) -> Result: +def __check( mesh, options: Options ) -> Result: # The vtk cell type is an int and will be the key of the following mapping, # that will point to the relevant permutation. - cell_type_to_ordering: Dict[int, List[int]] = options.cell_type_to_ordering - unchanged_cell_types: Set[int] = set() # For logging purpose + cell_type_to_ordering: Dict[ int, List[ int ] ] = options.cell_type_to_ordering + unchanged_cell_types: Set[ int ] = set() # For logging purpose # Preparing the output mesh by first keeping the same instance type. output_mesh = mesh.NewInstance() - output_mesh.CopyStructure(mesh) - output_mesh.CopyAttributes(mesh) + output_mesh.CopyStructure( mesh ) + output_mesh.CopyAttributes( mesh ) # `output_mesh` now contains a full copy of the input mesh. # We'll now modify the support nodes orderings in place if needed. cells = output_mesh.GetCells() - for cell_idx in range(output_mesh.GetNumberOfCells()): - cell_type: int = output_mesh.GetCell(cell_idx).GetCellType() - new_ordering = cell_type_to_ordering.get(cell_type) + for cell_idx in range( output_mesh.GetNumberOfCells() ): + cell_type: int = output_mesh.GetCell( cell_idx ).GetCellType() + new_ordering = cell_type_to_ordering.get( cell_type ) if new_ordering: support_point_ids = vtkIdList() - cells.GetCellAtId(cell_idx, support_point_ids) + cells.GetCellAtId( cell_idx, support_point_ids ) new_support_point_ids = [] - for i, v in enumerate(new_ordering): - new_support_point_ids.append(support_point_ids.GetId(new_ordering[i])) - cells.ReplaceCellAtId(cell_idx, to_vtk_id_list(new_support_point_ids)) + for i, v in enumerate( new_ordering ): + new_support_point_ids.append( support_point_ids.GetId( new_ordering[ i ] ) ) + cells.ReplaceCellAtId( cell_idx, to_vtk_id_list( new_support_point_ids ) ) else: - unchanged_cell_types.add(cell_type) - is_written_error = vtk_utils.write_mesh(output_mesh, options.vtk_output) - return Result(output=options.vtk_output.output if not is_written_error else "", - unchanged_cell_types=frozenset(unchanged_cell_types)) + unchanged_cell_types.add( cell_type ) + is_written_error = vtk_utils.write_mesh( output_mesh, options.vtk_output ) + return Result( output=options.vtk_output.output if not is_written_error else "", + unchanged_cell_types=frozenset( unchanged_cell_types ) ) -def check(vtk_input_file: str, options: Options) -> Result: - mesh = vtk_utils.read_mesh(vtk_input_file) - return __check(mesh, options) +def check( vtk_input_file: str, options: Options ) -> Result: + mesh = vtk_utils.read_mesh( vtk_input_file ) + return __check( mesh, options ) diff --git a/geosx_mesh_doctor/checks/generate_cube.py b/geosx_mesh_doctor/checks/generate_cube.py index f8625f5..c9a6d65 100644 --- a/geosx_mesh_doctor/checks/generate_cube.py +++ b/geosx_mesh_doctor/checks/generate_cube.py @@ -5,8 +5,7 @@ import numpy from vtkmodules.vtkCommonCore import ( - vtkPoints, -) + vtkPoints, ) from vtkmodules.vtkCommonDataModel import ( VTK_HEXAHEDRON, vtkCellArray, @@ -15,51 +14,49 @@ vtkUnstructuredGrid, ) from vtkmodules.util.numpy_support import ( - numpy_to_vtk, -) + numpy_to_vtk, ) from . import vtk_utils from .vtk_utils import ( - VtkOutput, -) + VtkOutput, ) from .generate_global_ids import __build_global_ids -@dataclass(frozen=True) +@dataclass( frozen=True ) class Result: info: str -@dataclass(frozen=True) +@dataclass( frozen=True ) class FieldInfo: name: str dimension: int support: str -@dataclass(frozen=True) +@dataclass( frozen=True ) class Options: vtk_output: VtkOutput generate_cells_global_ids: bool generate_points_global_ids: bool - xs: Sequence[float] - ys: Sequence[float] - zs: Sequence[float] - nxs: Sequence[int] - nys: Sequence[int] - nzs: Sequence[int] - fields: Iterable[FieldInfo] + xs: Sequence[ float ] + ys: Sequence[ float ] + zs: Sequence[ float ] + nxs: Sequence[ int ] + nys: Sequence[ int ] + nzs: Sequence[ int ] + fields: Iterable[ FieldInfo ] -@dataclass(frozen=True) +@dataclass( frozen=True ) class XYZ: x: numpy.ndarray y: numpy.ndarray z: numpy.ndarray -def build_rectilinear_blocks_mesh(xyzs: Iterable[XYZ]) -> vtkUnstructuredGrid: +def build_rectilinear_blocks_mesh( xyzs: Iterable[ XYZ ] ) -> vtkUnstructuredGrid: """ Builds an unstructured vtk grid from the `xyzs` blocks. Kind of InternalMeshGenerator. :param xyzs: The blocks. @@ -68,44 +65,44 @@ def build_rectilinear_blocks_mesh(xyzs: Iterable[XYZ]) -> vtkUnstructuredGrid: rgs = [] for xyz in xyzs: rg = vtkRectilinearGrid() - rg.SetDimensions(len(xyz.x), len(xyz.y), len(xyz.z)) - rg.SetXCoordinates(numpy_to_vtk(xyz.x)) - rg.SetYCoordinates(numpy_to_vtk(xyz.y)) - rg.SetZCoordinates(numpy_to_vtk(xyz.z)) - rgs.append(rg) + rg.SetDimensions( len( xyz.x ), len( xyz.y ), len( xyz.z ) ) + rg.SetXCoordinates( numpy_to_vtk( xyz.x ) ) + rg.SetYCoordinates( numpy_to_vtk( xyz.y ) ) + rg.SetZCoordinates( numpy_to_vtk( xyz.z ) ) + rgs.append( rg ) - num_points = sum(map(lambda r: r.GetNumberOfPoints(), rgs)) - num_cells = sum(map(lambda r: r.GetNumberOfCells(), rgs)) + num_points = sum( map( lambda r: r.GetNumberOfPoints(), rgs ) ) + num_cells = sum( map( lambda r: r.GetNumberOfCells(), rgs ) ) points = vtkPoints() - points.Allocate(num_points) + points.Allocate( num_points ) for rg in rgs: - for i in range(rg.GetNumberOfPoints()): - points.InsertNextPoint(rg.GetPoint(i)) + for i in range( rg.GetNumberOfPoints() ): + points.InsertNextPoint( rg.GetPoint( i ) ) - cell_types = [VTK_HEXAHEDRON] * num_cells + cell_types = [ VTK_HEXAHEDRON ] * num_cells cells = vtkCellArray() - cells.AllocateExact(num_cells, num_cells * 8) + cells.AllocateExact( num_cells, num_cells * 8 ) - m = (0, 1, 3, 2, 4, 5, 7, 6) # VTK_VOXEL and VTK_HEXAHEDRON do not share the same ordering. + m = ( 0, 1, 3, 2, 4, 5, 7, 6 ) # VTK_VOXEL and VTK_HEXAHEDRON do not share the same ordering. offset = 0 for rg in rgs: - for i in range(rg.GetNumberOfCells()): - c = rg.GetCell(i) + for i in range( rg.GetNumberOfCells() ): + c = rg.GetCell( i ) new_cell = vtkHexahedron() - for j in range(8): - new_cell.GetPointIds().SetId(j, offset + c.GetPointId(m[j])) - cells.InsertNextCell(new_cell) + for j in range( 8 ): + new_cell.GetPointIds().SetId( j, offset + c.GetPointId( m[ j ] ) ) + cells.InsertNextCell( new_cell ) offset += rg.GetNumberOfPoints() mesh = vtkUnstructuredGrid() - mesh.SetPoints(points) - mesh.SetCells(cell_types, cells) + mesh.SetPoints( points ) + mesh.SetCells( cell_types, cells ) return mesh -def __add_fields(mesh: vtkUnstructuredGrid, fields: Iterable[FieldInfo]) -> vtkUnstructuredGrid: +def __add_fields( mesh: vtkUnstructuredGrid, fields: Iterable[ FieldInfo ] ) -> vtkUnstructuredGrid: for field_info in fields: if field_info.support == "CELLS": data = mesh.GetCellData() @@ -113,48 +110,50 @@ def __add_fields(mesh: vtkUnstructuredGrid, fields: Iterable[FieldInfo]) -> vtkU elif field_info.support == "POINTS": data = mesh.GetPointData() n = mesh.GetNumberOfPoints() - array = numpy.ones((n, field_info.dimension), dtype=float) - vtk_array = numpy_to_vtk(array) - vtk_array.SetName(field_info.name) - data.AddArray(vtk_array) + array = numpy.ones( ( n, field_info.dimension ), dtype=float ) + vtk_array = numpy_to_vtk( array ) + vtk_array.SetName( field_info.name ) + data.AddArray( vtk_array ) return mesh -def __build(options: Options): - def build_coordinates(positions, num_elements): +def __build( options: Options ): + + def build_coordinates( positions, num_elements ): result = [] - it = zip(zip(positions, positions[1:]), num_elements) + it = zip( zip( positions, positions[ 1: ] ), num_elements ) try: - coords, n = next(it) + coords, n = next( it ) while True: start, stop = coords end_point = False - tmp = numpy.linspace(start=start, stop=stop, num=n+end_point, endpoint=end_point) - coords, n = next(it) - result.append(tmp) + tmp = numpy.linspace( start=start, stop=stop, num=n + end_point, endpoint=end_point ) + coords, n = next( it ) + result.append( tmp ) except StopIteration: end_point = True - tmp = numpy.linspace(start=start, stop=stop, num=n+end_point, endpoint=end_point) - result.append(tmp) - return numpy.concatenate(result) - x = build_coordinates(options.xs, options.nxs) - y = build_coordinates(options.ys, options.nys) - z = build_coordinates(options.zs, options.nzs) - cube = build_rectilinear_blocks_mesh((XYZ(x, y, z),)) - cube = __add_fields(cube, options.fields) - __build_global_ids(cube, options.generate_cells_global_ids, options.generate_points_global_ids) + tmp = numpy.linspace( start=start, stop=stop, num=n + end_point, endpoint=end_point ) + result.append( tmp ) + return numpy.concatenate( result ) + + x = build_coordinates( options.xs, options.nxs ) + y = build_coordinates( options.ys, options.nys ) + z = build_coordinates( options.zs, options.nzs ) + cube = build_rectilinear_blocks_mesh( ( XYZ( x, y, z ), ) ) + cube = __add_fields( cube, options.fields ) + __build_global_ids( cube, options.generate_cells_global_ids, options.generate_points_global_ids ) return cube -def __check(options: Options) -> Result: - output_mesh = __build(options) - vtk_utils.write_mesh(output_mesh, options.vtk_output) - return Result(info=f"Mesh was written to {options.vtk_output.output}") +def __check( options: Options ) -> Result: + output_mesh = __build( options ) + vtk_utils.write_mesh( output_mesh, options.vtk_output ) + return Result( info=f"Mesh was written to {options.vtk_output.output}" ) -def check(vtk_input_file: str, options: Options) -> Result: +def check( vtk_input_file: str, options: Options ) -> Result: try: - return __check(options) + return __check( options ) except BaseException as e: - logging.error(e) - return Result(info="Something went wrong.") + logging.error( e ) + return Result( info="Something went wrong." ) diff --git a/geosx_mesh_doctor/checks/generate_fractures.py b/geosx_mesh_doctor/checks/generate_fractures.py index 22fbadc..c21325d 100644 --- a/geosx_mesh_doctor/checks/generate_fractures.py +++ b/geosx_mesh_doctor/checks/generate_fractures.py @@ -44,132 +44,129 @@ to_vtk_id_list, ) from .vtk_polyhedron import ( - FaceStream, -) + FaceStream, ) -class FracturePolicy(Enum): +class FracturePolicy( Enum ): FIELD = 0 INTERNAL_SURFACES = 1 -@dataclass(frozen=True) +@dataclass( frozen=True ) class Options: policy: FracturePolicy field: str - field_values: FrozenSet[int] + field_values: FrozenSet[ int ] vtk_output: VtkOutput vtk_fracture_output: VtkOutput -@dataclass(frozen=True) +@dataclass( frozen=True ) class Result: info: str -@dataclass(frozen=True) +@dataclass( frozen=True ) class FractureInfo: - node_to_cells: Mapping[int, Iterable[int]] # For each _fracture_ node, gives all the cells that use this node. - face_nodes: Iterable[Collection[int]] # For each fracture face, returns the nodes of this face + node_to_cells: Mapping[ int, Iterable[ int ] ] # For each _fracture_ node, gives all the cells that use this node. + face_nodes: Iterable[ Collection[ int ] ] # For each fracture face, returns the nodes of this face -def build_node_to_cells(mesh: vtkUnstructuredGrid, - face_nodes: Iterable[Iterable[int]]) -> Mapping[int, Iterable[int]]: - node_to_cells: Dict[int, Set[int]] = defaultdict(set) # TODO normally, just a list and not a set should be enough. +def build_node_to_cells( mesh: vtkUnstructuredGrid, + face_nodes: Iterable[ Iterable[ int ] ] ) -> Mapping[ int, Iterable[ int ] ]: + node_to_cells: Dict[ int, + Set[ int ] ] = defaultdict( set ) # TODO normally, just a list and not a set should be enough. - fracture_nodes: Set[int] = set() + fracture_nodes: Set[ int ] = set() for fns in face_nodes: for n in fns: - fracture_nodes.add(n) + fracture_nodes.add( n ) - for cell_id in tqdm(range(mesh.GetNumberOfCells()), desc="Computing the node to cells mapping"): - cell_points: FrozenSet[int] = frozenset(vtk_iter(mesh.GetCell(cell_id).GetPointIds())) - intersection: Iterable[int] = cell_points & fracture_nodes + for cell_id in tqdm( range( mesh.GetNumberOfCells() ), desc="Computing the node to cells mapping" ): + cell_points: FrozenSet[ int ] = frozenset( vtk_iter( mesh.GetCell( cell_id ).GetPointIds() ) ) + intersection: Iterable[ int ] = cell_points & fracture_nodes for node in intersection: - node_to_cells[node].add(cell_id) + node_to_cells[ node ].add( cell_id ) return node_to_cells -def __build_fracture_info_from_fields(mesh: vtkUnstructuredGrid, - f: Sequence[int], - field_values: FrozenSet[int]) -> FractureInfo: - cells_to_faces: Dict[int, List[int]] = defaultdict(list) +def __build_fracture_info_from_fields( mesh: vtkUnstructuredGrid, f: Sequence[ int ], + field_values: FrozenSet[ int ] ) -> FractureInfo: + cells_to_faces: Dict[ int, List[ int ] ] = defaultdict( list ) # For each face of each cell, we search for the unique neighbor cell (if it exists). # Then, if the 2 values of the two cells match the field requirements, # we store the cell and its local face index: this is indeed part of the surface that we'll need to be split. cell: vtkCell - for cell_id in tqdm(range(mesh.GetNumberOfCells()), desc="Computing the cell to faces mapping"): - if f[cell_id] not in field_values: # No need to consider a cell if its field value is not in the target range. + for cell_id in tqdm( range( mesh.GetNumberOfCells() ), desc="Computing the cell to faces mapping" ): + if f[ cell_id ] not in field_values: # No need to consider a cell if its field value is not in the target range. continue - cell = mesh.GetCell(cell_id) - for i in range(cell.GetNumberOfFaces()): + cell = mesh.GetCell( cell_id ) + for i in range( cell.GetNumberOfFaces() ): neighbor_cell_ids = vtkIdList() - mesh.GetCellNeighbors(cell_id, cell.GetFace(i).GetPointIds(), neighbor_cell_ids) + mesh.GetCellNeighbors( cell_id, cell.GetFace( i ).GetPointIds(), neighbor_cell_ids ) assert neighbor_cell_ids.GetNumberOfIds() < 2 - for j in range(neighbor_cell_ids.GetNumberOfIds()): # It's 0 or 1... - neighbor_cell_id = neighbor_cell_ids.GetId(j) - if f[neighbor_cell_id] != f[cell_id] and f[neighbor_cell_id] in field_values: - cells_to_faces[cell_id].append(i) # TODO add this (cell_is, face_id) information to the fracture_info? - face_nodes: List[Collection[int]] = list() - face_nodes_hashes: Set[FrozenSet[int]] = set() # A temporary not to add multiple times the same face. - for cell_id, faces_ids in tqdm(cells_to_faces.items(), desc="Extracting the faces of the fractures"): - cell = mesh.GetCell(cell_id) + for j in range( neighbor_cell_ids.GetNumberOfIds() ): # It's 0 or 1... + neighbor_cell_id = neighbor_cell_ids.GetId( j ) + if f[ neighbor_cell_id ] != f[ cell_id ] and f[ neighbor_cell_id ] in field_values: + cells_to_faces[ cell_id ].append( + i ) # TODO add this (cell_is, face_id) information to the fracture_info? + face_nodes: List[ Collection[ int ] ] = list() + face_nodes_hashes: Set[ FrozenSet[ int ] ] = set() # A temporary not to add multiple times the same face. + for cell_id, faces_ids in tqdm( cells_to_faces.items(), desc="Extracting the faces of the fractures" ): + cell = mesh.GetCell( cell_id ) for face_id in faces_ids: - fn: Collection[int] = tuple(vtk_iter(cell.GetFace(face_id).GetPointIds())) - fnh = frozenset(fn) + fn: Collection[ int ] = tuple( vtk_iter( cell.GetFace( face_id ).GetPointIds() ) ) + fnh = frozenset( fn ) if fnh not in face_nodes_hashes: - face_nodes_hashes.add(fnh) - face_nodes.append(fn) - node_to_cells: Mapping[int, Iterable[int]] = build_node_to_cells(mesh, face_nodes) + face_nodes_hashes.add( fnh ) + face_nodes.append( fn ) + node_to_cells: Mapping[ int, Iterable[ int ] ] = build_node_to_cells( mesh, face_nodes ) - return FractureInfo(node_to_cells=node_to_cells, face_nodes=face_nodes) + return FractureInfo( node_to_cells=node_to_cells, face_nodes=face_nodes ) -def __build_fracture_info_from_internal_surfaces(mesh: vtkUnstructuredGrid, - f: Sequence[int], - field_values: FrozenSet[int]) -> FractureInfo: - node_to_cells: Dict[int, List[int]] = {} - face_nodes: List[Collection[int]] = [] - for cell_id in tqdm(range(mesh.GetNumberOfCells()), desc="Computing the face to nodes mapping"): - cell = mesh.GetCell(cell_id) +def __build_fracture_info_from_internal_surfaces( mesh: vtkUnstructuredGrid, f: Sequence[ int ], + field_values: FrozenSet[ int ] ) -> FractureInfo: + node_to_cells: Dict[ int, List[ int ] ] = {} + face_nodes: List[ Collection[ int ] ] = [] + for cell_id in tqdm( range( mesh.GetNumberOfCells() ), desc="Computing the face to nodes mapping" ): + cell = mesh.GetCell( cell_id ) if cell.GetCellDimension() == 2: - if f[cell_id] in field_values: + if f[ cell_id ] in field_values: nodes = [] - for v in range(cell.GetNumberOfPoints()): - point_id: int = cell.GetPointId(v) - node_to_cells[point_id] = [] - nodes.append(point_id) - face_nodes.append(tuple(nodes)) - - for cell_id in tqdm(range(mesh.GetNumberOfCells()), desc="Computing the node to cells mapping"): - cell = mesh.GetCell(cell_id) + for v in range( cell.GetNumberOfPoints() ): + point_id: int = cell.GetPointId( v ) + node_to_cells[ point_id ] = [] + nodes.append( point_id ) + face_nodes.append( tuple( nodes ) ) + + for cell_id in tqdm( range( mesh.GetNumberOfCells() ), desc="Computing the node to cells mapping" ): + cell = mesh.GetCell( cell_id ) if cell.GetCellDimension() == 3: - for v in range(cell.GetNumberOfPoints()): - if cell.GetPointId(v) in node_to_cells: - node_to_cells[cell.GetPointId(v)].append(cell_id) + for v in range( cell.GetNumberOfPoints() ): + if cell.GetPointId( v ) in node_to_cells: + node_to_cells[ cell.GetPointId( v ) ].append( cell_id ) - return FractureInfo(node_to_cells=node_to_cells, face_nodes=face_nodes) + return FractureInfo( node_to_cells=node_to_cells, face_nodes=face_nodes ) -def build_fracture_info(mesh: vtkUnstructuredGrid, - options: Options) -> FractureInfo: +def build_fracture_info( mesh: vtkUnstructuredGrid, options: Options ) -> FractureInfo: field = options.field field_values = options.field_values cell_data = mesh.GetCellData() - if cell_data.HasArray(field): - f = vtk_to_numpy(cell_data.GetArray(field)) + if cell_data.HasArray( field ): + f = vtk_to_numpy( cell_data.GetArray( field ) ) else: - raise ValueError(f"Cell field {field} does not exist in mesh, nothing done") + raise ValueError( f"Cell field {field} does not exist in mesh, nothing done" ) if options.policy == FracturePolicy.FIELD: - return __build_fracture_info_from_fields(mesh, f, field_values) + return __build_fracture_info_from_fields( mesh, f, field_values ) elif options.policy == FracturePolicy.INTERNAL_SURFACES: - return __build_fracture_info_from_internal_surfaces(mesh, f, field_values) + return __build_fracture_info_from_internal_surfaces( mesh, f, field_values ) -def build_cell_to_cell_graph(mesh: vtkUnstructuredGrid, - fracture: FractureInfo) -> networkx.Graph: +def build_cell_to_cell_graph( mesh: vtkUnstructuredGrid, fracture: FractureInfo ) -> networkx.Graph: """ Connects all the cells that touch the fracture by at least one node. Two cells are connected when they share at least a face which is not a face of the fracture. @@ -180,39 +177,38 @@ def build_cell_to_cell_graph(mesh: vtkUnstructuredGrid, """ # Faces are identified by their nodes. But the order of those nodes may vary while referring to the same face. # Therefore we compute some kinds of hashes of those face to easily detect if a face is part of the fracture. - tmp: List[FrozenSet[int]] = [] + tmp: List[ FrozenSet[ int ] ] = [] for fn in fracture.face_nodes: - tmp.append(frozenset(fn)) - face_hashes: FrozenSet[FrozenSet[int]] = frozenset(tmp) + tmp.append( frozenset( fn ) ) + face_hashes: FrozenSet[ FrozenSet[ int ] ] = frozenset( tmp ) # We extract the list of the cells that touch the fracture by at least one node. - cells: Set[int] = set() + cells: Set[ int ] = set() for cell_ids in fracture.node_to_cells.values(): for cell_id in cell_ids: - cells.add(cell_id) + cells.add( cell_id ) # Using the last precomputed containers, we're now building the dict which connects # every face (hash) of the fracture to the cells that touch the face... - face_to_cells: Dict[FrozenSet[int], List[int]] = defaultdict(list) - for cell_id in tqdm(cells, desc="Computing the cell to cell graph"): - cell: vtkCell = mesh.GetCell(cell_id) - for face_id in range(cell.GetNumberOfFaces()): - face_hash: FrozenSet[int] = frozenset(vtk_iter(cell.GetFace(face_id).GetPointIds())) + face_to_cells: Dict[ FrozenSet[ int ], List[ int ] ] = defaultdict( list ) + for cell_id in tqdm( cells, desc="Computing the cell to cell graph" ): + cell: vtkCell = mesh.GetCell( cell_id ) + for face_id in range( cell.GetNumberOfFaces() ): + face_hash: FrozenSet[ int ] = frozenset( vtk_iter( cell.GetFace( face_id ).GetPointIds() ) ) if face_hash not in face_hashes: - face_to_cells[face_hash].append(cell_id) + face_to_cells[ face_hash ].append( cell_id ) # ... eventually, when a face touches two cells, this means that those two cells share the same face # and should be connected in the final cell to cell graph. cell_to_cell = networkx.Graph() - cell_to_cell.add_nodes_from(cells) - cell_to_cell.add_edges_from(filter(lambda cs: len(cs) == 2, face_to_cells.values())) + cell_to_cell.add_nodes_from( cells ) + cell_to_cell.add_edges_from( filter( lambda cs: len( cs ) == 2, face_to_cells.values() ) ) return cell_to_cell -def __identify_split(num_points: int, - cell_to_cell: networkx.Graph, - node_to_cells: Mapping[int, Iterable[int]]) -> Mapping[int, Mapping[int, int]]: +def __identify_split( num_points: int, cell_to_cell: networkx.Graph, + node_to_cells: Mapping[ int, Iterable[ int ] ] ) -> Mapping[ int, Mapping[ int, int ] ]: """ For each cell, compute the node indices replacements. :param num_points: Number of points in the whole mesh (not the fracture). @@ -229,34 +225,35 @@ class NewIndex: Note that the first time an index is met, the index itself is returned: we do not want to change an index if we do not have to. """ - def __init__(self, num_nodes: int): + + def __init__( self, num_nodes: int ): self.__current_last_index = num_nodes - 1 - self.__seen: Set[int] = set() + self.__seen: Set[ int ] = set() - def __call__(self, index: int) -> int: + def __call__( self, index: int ) -> int: if index in self.__seen: self.__current_last_index += 1 return self.__current_last_index else: - self.__seen.add(index) + self.__seen.add( index ) return index - build_new_index = NewIndex(num_points) - result: Dict[int, Dict[int, int]] = defaultdict(dict) - for node, cells in tqdm(sorted(node_to_cells.items()), # Iteration over `sorted` nodes to have a predictable result for tests. - desc="Identifying the node splits"): - for connected_cells in networkx.connected_components(cell_to_cell.subgraph(cells)): + build_new_index = NewIndex( num_points ) + result: Dict[ int, Dict[ int, int ] ] = defaultdict( dict ) + for node, cells in tqdm( + sorted( node_to_cells.items() ), # Iteration over `sorted` nodes to have a predictable result for tests. + desc="Identifying the node splits" ): + for connected_cells in networkx.connected_components( cell_to_cell.subgraph( cells ) ): # Each group of connect cells need around `node` must consider the same `node`. # Separate groups must have different (duplicated) nodes. - new_index: int = build_new_index(node) + new_index: int = build_new_index( node ) for cell in connected_cells: - result[cell][node] = new_index + result[ cell ][ node ] = new_index return result -def __copy_fields(old_mesh: vtkUnstructuredGrid, - new_mesh: vtkUnstructuredGrid, - collocated_nodes: Sequence[int]) -> None: +def __copy_fields( old_mesh: vtkUnstructuredGrid, new_mesh: vtkUnstructuredGrid, + collocated_nodes: Sequence[ int ] ) -> None: """ Copies the fields from the old mesh to the new one. Point data will be duplicated for collocated nodes. @@ -268,62 +265,62 @@ def __copy_fields(old_mesh: vtkUnstructuredGrid, # Copying the cell data. # The cells are the same, just their nodes support have changed. input_cell_data = old_mesh.GetCellData() - for i in range(input_cell_data.GetNumberOfArrays()): - input_array = input_cell_data.GetArray(i) - logging.info(f"Copying cell field \"{input_array.GetName()}\".") - new_mesh.GetCellData().AddArray(input_array) + for i in range( input_cell_data.GetNumberOfArrays() ): + input_array = input_cell_data.GetArray( i ) + logging.info( f"Copying cell field \"{input_array.GetName()}\"." ) + new_mesh.GetCellData().AddArray( input_array ) # Copying field data. This data is a priori not related to geometry. input_field_data = old_mesh.GetFieldData() - for i in range(input_field_data.GetNumberOfArrays()): - input_array = input_field_data.GetArray(i) - logging.info(f"Copying field data \"{input_array.GetName()}\".") - new_mesh.GetFieldData().AddArray(input_array) + for i in range( input_field_data.GetNumberOfArrays() ): + input_array = input_field_data.GetArray( i ) + logging.info( f"Copying field data \"{input_array.GetName()}\"." ) + new_mesh.GetFieldData().AddArray( input_array ) # Copying the point data. input_point_data = old_mesh.GetPointData() - for i in range(input_point_data.GetNumberOfArrays()): - input_array = input_point_data.GetArray(i) - logging.info(f"Copying point field \"{input_array.GetName()}\"") + for i in range( input_point_data.GetNumberOfArrays() ): + input_array = input_point_data.GetArray( i ) + logging.info( f"Copying point field \"{input_array.GetName()}\"" ) tmp = input_array.NewInstance() - tmp.SetName(input_array.GetName()) - tmp.SetNumberOfComponents(input_array.GetNumberOfComponents()) - tmp.SetNumberOfTuples(new_mesh.GetNumberOfPoints()) - for p in range(tmp.GetNumberOfTuples()): - tmp.SetTuple(p, input_array.GetTuple(collocated_nodes[p])) - new_mesh.GetPointData().AddArray(tmp) + tmp.SetName( input_array.GetName() ) + tmp.SetNumberOfComponents( input_array.GetNumberOfComponents() ) + tmp.SetNumberOfTuples( new_mesh.GetNumberOfPoints() ) + for p in range( tmp.GetNumberOfTuples() ): + tmp.SetTuple( p, input_array.GetTuple( collocated_nodes[ p ] ) ) + new_mesh.GetPointData().AddArray( tmp ) -def __perform_split(old_mesh: vtkUnstructuredGrid, - cell_to_node_mapping: Mapping[int, Mapping[int, int]]) -> vtkUnstructuredGrid: +def __perform_split( old_mesh: vtkUnstructuredGrid, + cell_to_node_mapping: Mapping[ int, Mapping[ int, int ] ] ) -> vtkUnstructuredGrid: """ Split the main 3d mesh based on the node duplication information contained in @p cell_to_node_mapping :param old_mesh: The main 3d mesh. :param cell_to_node_mapping: For each cell, gives the nodes that must be duplicated and their new index. :return: The main 3d mesh split at the fracture location. """ - added_points: Set[int] = set() + added_points: Set[ int ] = set() for node_mapping in cell_to_node_mapping.values(): for i, o in node_mapping.items(): if i != o: - added_points.add(o) - num_new_points: int = old_mesh.GetNumberOfPoints() + len(added_points) + added_points.add( o ) + num_new_points: int = old_mesh.GetNumberOfPoints() + len( added_points ) # Creating the new points for the new mesh. old_points: vtkPoints = old_mesh.GetPoints() new_points = vtkPoints() - new_points.SetNumberOfPoints(num_new_points) - collocated_nodes = numpy.ones(num_new_points, dtype=int) * -1 + new_points.SetNumberOfPoints( num_new_points ) + collocated_nodes = numpy.ones( num_new_points, dtype=int ) * -1 # Copying old points into the new container. - for p in range(old_points.GetNumberOfPoints()): - new_points.SetPoint(p, old_points.GetPoint(p)) - collocated_nodes[p] = p + for p in range( old_points.GetNumberOfPoints() ): + new_points.SetPoint( p, old_points.GetPoint( p ) ) + collocated_nodes[ p ] = p # Creating the new collocated/duplicated points based on the old points positions. for node_mapping in cell_to_node_mapping.values(): for i, o in node_mapping.items(): if i != o: - new_points.SetPoint(o, old_points.GetPoint(i)) - collocated_nodes[o] = i + new_points.SetPoint( o, old_points.GetPoint( i ) ) + collocated_nodes[ o ] = i collocated_nodes.flags.writeable = False # We are creating a new mesh. @@ -335,43 +332,42 @@ def __perform_split(old_mesh: vtkUnstructuredGrid, # Maybe in the future using a `DeepCopy` of the vtkCellArray can be considered? # The cell point ids could be modified in place then. new_mesh = old_mesh.NewInstance() - new_mesh.SetPoints(new_points) - new_mesh.Allocate(old_mesh.GetNumberOfCells()) + new_mesh.SetPoints( new_points ) + new_mesh.Allocate( old_mesh.GetNumberOfCells() ) - for c in tqdm(range(old_mesh.GetNumberOfCells()), desc="Performing the mesh split"): - node_mapping: Mapping[int, int] = cell_to_node_mapping.get(c, {}) - cell: vtkCell = old_mesh.GetCell(c) + for c in tqdm( range( old_mesh.GetNumberOfCells() ), desc="Performing the mesh split" ): + node_mapping: Mapping[ int, int ] = cell_to_node_mapping.get( c, {} ) + cell: vtkCell = old_mesh.GetCell( c ) cell_type: int = cell.GetCellType() # For polyhedron, we'll manipulate the face stream directly. if cell_type == VTK_POLYHEDRON: face_stream = vtkIdList() - old_mesh.GetFaceStream(c, face_stream) - new_face_nodes: List[List[int]] = [] - for face_nodes in FaceStream.build_from_vtk_id_list(face_stream).face_nodes: + old_mesh.GetFaceStream( c, face_stream ) + new_face_nodes: List[ List[ int ] ] = [] + for face_nodes in FaceStream.build_from_vtk_id_list( face_stream ).face_nodes: new_point_ids = [] for current_point_id in face_nodes: - new_point_id: int = node_mapping.get(current_point_id, current_point_id) - new_point_ids.append(new_point_id) - new_face_nodes.append(new_point_ids) - new_mesh.InsertNextCell(cell_type, to_vtk_id_list(FaceStream(new_face_nodes).dump())) + new_point_id: int = node_mapping.get( current_point_id, current_point_id ) + new_point_ids.append( new_point_id ) + new_face_nodes.append( new_point_ids ) + new_mesh.InsertNextCell( cell_type, to_vtk_id_list( FaceStream( new_face_nodes ).dump() ) ) else: # For the standard cells, we extract the point ids of the cell directly. # Then the values will be (potentially) overwritten in place, before being sent back into the cell. cell_point_ids: vtkIdList = cell.GetPointIds() - for i in range(cell_point_ids.GetNumberOfIds()): - current_point_id: int = cell_point_ids.GetId(i) - new_point_id: int = node_mapping.get(current_point_id, current_point_id) - cell_point_ids.SetId(i, new_point_id) - new_mesh.InsertNextCell(cell_type, cell_point_ids) + for i in range( cell_point_ids.GetNumberOfIds() ): + current_point_id: int = cell_point_ids.GetId( i ) + new_point_id: int = node_mapping.get( current_point_id, current_point_id ) + cell_point_ids.SetId( i, new_point_id ) + new_mesh.InsertNextCell( cell_type, cell_point_ids ) - __copy_fields(old_mesh, new_mesh, collocated_nodes) + __copy_fields( old_mesh, new_mesh, collocated_nodes ) return new_mesh -def __generate_fracture_mesh(mesh_points: vtkPoints, - fracture_info: FractureInfo, - cell_to_node_mapping: Mapping[int, Mapping[int, int]]) -> vtkUnstructuredGrid: +def __generate_fracture_mesh( mesh_points: vtkPoints, fracture_info: FractureInfo, + cell_to_node_mapping: Mapping[ int, Mapping[ int, int ] ] ) -> vtkUnstructuredGrid: """ Generates the mesh of the fracture. :param mesh_points: The points of the main 3d mesh. @@ -379,104 +375,105 @@ def __generate_fracture_mesh(mesh_points: vtkPoints, :param cell_to_node_mapping: For each cell, gives the nodes that must be duplicated and their new index. :return: The fracture mesh. """ - logging.info("Generating the meshes") + logging.info( "Generating the meshes" ) - is_node_duplicated = numpy.zeros(mesh_points.GetNumberOfPoints(), dtype=bool) # defaults to False + is_node_duplicated = numpy.zeros( mesh_points.GetNumberOfPoints(), dtype=bool ) # defaults to False for node_mapping in cell_to_node_mapping.values(): for i, o in node_mapping.items(): - if not is_node_duplicated[i]: - is_node_duplicated[i] = i != o + if not is_node_duplicated[ i ]: + is_node_duplicated[ i ] = i != o # Some elements can have all their nodes not duplicated. # In this case, it's mandatory not get rid of this element # because the neighboring 3d elements won't follow. - face_nodes: List[Collection[int]] = [] - discarded_face_nodes: Set[Iterable[int]] = set() + face_nodes: List[ Collection[ int ] ] = [] + discarded_face_nodes: Set[ Iterable[ int ] ] = set() for ns in fracture_info.face_nodes: - if any(map(is_node_duplicated.__getitem__, ns)): - face_nodes.append(ns) + if any( map( is_node_duplicated.__getitem__, ns ) ): + face_nodes.append( ns ) else: - discarded_face_nodes.add(ns) + discarded_face_nodes.add( ns ) if discarded_face_nodes: # tmp = [] # for dfns in discarded_face_nodes: # tmp.append(", ".join(map(str, dfns))) - msg: str = "(" + '), ('.join(map(lambda dfns: ", ".join(map(str, dfns)), discarded_face_nodes)) + ")" + msg: str = "(" + '), ('.join( map( lambda dfns: ", ".join( map( str, dfns ) ), discarded_face_nodes ) ) + ")" # logging.info(f"The {len(tmp)} faces made of nodes ({'), ('.join(tmp)}) were/was discarded from the fracture mesh because none of their/its nodes were duplicated.") # print(f"The {len(tmp)} faces made of nodes ({'), ('.join(tmp)}) were/was discarded from the fracture mesh because none of their/its nodes were duplicated.") - print(f"The faces made of nodes [{msg}] were/was discarded from the fracture mesh because none of their/its nodes were duplicated.") + print( + f"The faces made of nodes [{msg}] were/was discarded from the fracture mesh because none of their/its nodes were duplicated." + ) - fracture_nodes_tmp = numpy.ones(mesh_points.GetNumberOfPoints(), dtype=int) * -1 + fracture_nodes_tmp = numpy.ones( mesh_points.GetNumberOfPoints(), dtype=int ) * -1 for ns in face_nodes: for n in ns: - fracture_nodes_tmp[n] = n - fracture_nodes: Collection[int] = tuple(filter(lambda n: n > -1, fracture_nodes_tmp)) - num_points: int = len(fracture_nodes) + fracture_nodes_tmp[ n ] = n + fracture_nodes: Collection[ int ] = tuple( filter( lambda n: n > -1, fracture_nodes_tmp ) ) + num_points: int = len( fracture_nodes ) points = vtkPoints() - points.SetNumberOfPoints(num_points) - node_3d_to_node_2d: Dict[int, int] = {} # Building the node mapping, from 3d mesh nodes to 2d fracture nodes. - for i, n in enumerate(fracture_nodes): - coords: Tuple[float, float, float] = mesh_points.GetPoint(n) - points.SetPoint(i, coords) - node_3d_to_node_2d[n] = i + points.SetNumberOfPoints( num_points ) + node_3d_to_node_2d: Dict[ int, int ] = {} # Building the node mapping, from 3d mesh nodes to 2d fracture nodes. + for i, n in enumerate( fracture_nodes ): + coords: Tuple[ float, float, float ] = mesh_points.GetPoint( n ) + points.SetPoint( i, coords ) + node_3d_to_node_2d[ n ] = i polygons = vtkCellArray() for ns in face_nodes: polygon = vtkPolygon() - polygon.GetPointIds().SetNumberOfIds(len(ns)) - for i, n in enumerate(ns): - polygon.GetPointIds().SetId(i, node_3d_to_node_2d[n]) - polygons.InsertNextCell(polygon) + polygon.GetPointIds().SetNumberOfIds( len( ns ) ) + for i, n in enumerate( ns ): + polygon.GetPointIds().SetId( i, node_3d_to_node_2d[ n ] ) + polygons.InsertNextCell( polygon ) - buckets: Dict[int, Set[int]] = defaultdict(set) + buckets: Dict[ int, Set[ int ] ] = defaultdict( set ) for node_mapping in cell_to_node_mapping.values(): for i, o in node_mapping.items(): - k: Optional[int] = node_3d_to_node_2d.get(min(i, o)) + k: Optional[ int ] = node_3d_to_node_2d.get( min( i, o ) ) if k is not None: - buckets[k].update((i, o)) + buckets[ k ].update( ( i, o ) ) - assert set(buckets.keys()) == set(range(num_points)) - max_collocated_nodes: int = max(map(len, buckets.values())) if buckets.values() else 0 - collocated_nodes = numpy.ones((num_points, max_collocated_nodes), dtype=int) * -1 + assert set( buckets.keys() ) == set( range( num_points ) ) + max_collocated_nodes: int = max( map( len, buckets.values() ) ) if buckets.values() else 0 + collocated_nodes = numpy.ones( ( num_points, max_collocated_nodes ), dtype=int ) * -1 for i, bucket in buckets.items(): - for j, val in enumerate(bucket): - collocated_nodes[i, j] = val - array = numpy_to_vtk(collocated_nodes, array_type=VTK_ID_TYPE) - array.SetName("collocated_nodes") + for j, val in enumerate( bucket ): + collocated_nodes[ i, j ] = val + array = numpy_to_vtk( collocated_nodes, array_type=VTK_ID_TYPE ) + array.SetName( "collocated_nodes" ) fracture_mesh = vtkUnstructuredGrid() # We could be using vtkPolyData, but it's not supported by GEOS for now. - fracture_mesh.SetPoints(points) + fracture_mesh.SetPoints( points ) if polygons.GetNumberOfCells() > 0: - fracture_mesh.SetCells([VTK_POLYGON] * polygons.GetNumberOfCells(), polygons) - fracture_mesh.GetPointData().AddArray(array) + fracture_mesh.SetCells( [ VTK_POLYGON ] * polygons.GetNumberOfCells(), polygons ) + fracture_mesh.GetPointData().AddArray( array ) return fracture_mesh -def __split_mesh_on_fracture(mesh: vtkUnstructuredGrid, - options: Options) -> Tuple[vtkUnstructuredGrid, vtkUnstructuredGrid]: - fracture: FractureInfo = build_fracture_info(mesh, options) - cell_to_cell: networkx.Graph = build_cell_to_cell_graph(mesh, fracture) - cell_to_node_mapping: Mapping[int, Mapping[int, int]] = __identify_split(mesh.GetNumberOfPoints(), - cell_to_cell, - fracture.node_to_cells) - output_mesh: vtkUnstructuredGrid = __perform_split(mesh, cell_to_node_mapping) - fractured_mesh: vtkUnstructuredGrid = __generate_fracture_mesh(mesh.GetPoints(), fracture, cell_to_node_mapping) +def __split_mesh_on_fracture( mesh: vtkUnstructuredGrid, + options: Options ) -> Tuple[ vtkUnstructuredGrid, vtkUnstructuredGrid ]: + fracture: FractureInfo = build_fracture_info( mesh, options ) + cell_to_cell: networkx.Graph = build_cell_to_cell_graph( mesh, fracture ) + cell_to_node_mapping: Mapping[ int, Mapping[ int, int ] ] = __identify_split( mesh.GetNumberOfPoints(), + cell_to_cell, fracture.node_to_cells ) + output_mesh: vtkUnstructuredGrid = __perform_split( mesh, cell_to_node_mapping ) + fractured_mesh: vtkUnstructuredGrid = __generate_fracture_mesh( mesh.GetPoints(), fracture, cell_to_node_mapping ) return output_mesh, fractured_mesh -def __check(mesh, options: Options) -> Result: - output_mesh, fracture_mesh = __split_mesh_on_fracture(mesh, options) - vtk_utils.write_mesh(output_mesh, options.vtk_output) - vtk_utils.write_mesh(fracture_mesh, options.vtk_fracture_output) +def __check( mesh, options: Options ) -> Result: + output_mesh, fracture_mesh = __split_mesh_on_fracture( mesh, options ) + vtk_utils.write_mesh( output_mesh, options.vtk_output ) + vtk_utils.write_mesh( fracture_mesh, options.vtk_fracture_output ) # TODO provide statistics about what was actually performed (size of the fracture, number of split nodes...). - return Result(info="OK") + return Result( info="OK" ) -def check(vtk_input_file: str, options: Options) -> Result: +def check( vtk_input_file: str, options: Options ) -> Result: try: - mesh = vtk_utils.read_mesh(vtk_input_file) - return __check(mesh, options) + mesh = vtk_utils.read_mesh( vtk_input_file ) + return __check( mesh, options ) except BaseException as e: - logging.error(e) - return Result(info="Something went wrong") + logging.error( e ) + return Result( info="Something went wrong" ) diff --git a/geosx_mesh_doctor/checks/generate_global_ids.py b/geosx_mesh_doctor/checks/generate_global_ids.py index 80474e2..1bf9cf2 100644 --- a/geosx_mesh_doctor/checks/generate_global_ids.py +++ b/geosx_mesh_doctor/checks/generate_global_ids.py @@ -6,25 +6,22 @@ from . import vtk_utils from .vtk_utils import ( - VtkOutput, -) + VtkOutput, ) -@dataclass(frozen=True) +@dataclass( frozen=True ) class Options: vtk_output: VtkOutput generate_cells_global_ids: bool generate_points_global_ids: bool -@dataclass(frozen=True) +@dataclass( frozen=True ) class Result: info: str -def __build_global_ids(mesh, - generate_cells_global_ids: bool, - generate_points_global_ids: bool) -> None: +def __build_global_ids( mesh, generate_cells_global_ids: bool, generate_points_global_ids: bool ) -> None: """ Adds the global ids for cells and points in place into the mesh instance. :param mesh: @@ -33,36 +30,36 @@ def __build_global_ids(mesh, # Building GLOBAL_IDS for points and cells.g GLOBAL_IDS for points and cells. # First for points... if mesh.GetPointData().GetGlobalIds(): - logging.error("Mesh already has globals ids for points; nothing done.") + logging.error( "Mesh already has globals ids for points; nothing done." ) elif generate_points_global_ids: point_global_ids = vtkIdTypeArray() - point_global_ids.SetName("GLOBAL_IDS_POINTS") - point_global_ids.Allocate(mesh.GetNumberOfPoints()) - for i in range(mesh.GetNumberOfPoints()): - point_global_ids.InsertNextValue(i) - mesh.GetPointData().SetGlobalIds(point_global_ids) + point_global_ids.SetName( "GLOBAL_IDS_POINTS" ) + point_global_ids.Allocate( mesh.GetNumberOfPoints() ) + for i in range( mesh.GetNumberOfPoints() ): + point_global_ids.InsertNextValue( i ) + mesh.GetPointData().SetGlobalIds( point_global_ids ) # ... then for cells. if mesh.GetCellData().GetGlobalIds(): - logging.error("Mesh already has globals ids for cells; nothing done.") + logging.error( "Mesh already has globals ids for cells; nothing done." ) elif generate_cells_global_ids: cells_global_ids = vtkIdTypeArray() - cells_global_ids.SetName("GLOBAL_IDS_CELLS") - cells_global_ids.Allocate(mesh.GetNumberOfCells()) - for i in range(mesh.GetNumberOfCells()): - cells_global_ids.InsertNextValue(i) - mesh.GetCellData().SetGlobalIds(cells_global_ids) + cells_global_ids.SetName( "GLOBAL_IDS_CELLS" ) + cells_global_ids.Allocate( mesh.GetNumberOfCells() ) + for i in range( mesh.GetNumberOfCells() ): + cells_global_ids.InsertNextValue( i ) + mesh.GetCellData().SetGlobalIds( cells_global_ids ) -def __check(mesh, options: Options) -> Result: - __build_global_ids(mesh, options.generate_cells_global_ids, options.generate_points_global_ids) - vtk_utils.write_mesh(mesh, options.vtk_output) - return Result(info=f"Mesh was written to {options.vtk_output.output}") +def __check( mesh, options: Options ) -> Result: + __build_global_ids( mesh, options.generate_cells_global_ids, options.generate_points_global_ids ) + vtk_utils.write_mesh( mesh, options.vtk_output ) + return Result( info=f"Mesh was written to {options.vtk_output.output}" ) -def check(vtk_input_file: str, options: Options) -> Result: +def check( vtk_input_file: str, options: Options ) -> Result: try: - mesh = vtk_utils.read_mesh(vtk_input_file) - return __check(mesh, options) + mesh = vtk_utils.read_mesh( vtk_input_file ) + return __check( mesh, options ) except BaseException as e: - logging.error(e) - return Result(info="Something went wrong.") + logging.error( e ) + return Result( info="Something went wrong." ) diff --git a/geosx_mesh_doctor/checks/non_conformal.py b/geosx_mesh_doctor/checks/non_conformal.py index 43f26e2..1706222 100644 --- a/geosx_mesh_doctor/checks/non_conformal.py +++ b/geosx_mesh_doctor/checks/non_conformal.py @@ -21,14 +21,11 @@ vtkUnstructuredGrid, ) from vtkmodules.vtkCommonTransforms import ( - vtkTransform, -) + vtkTransform, ) from vtkmodules.vtkFiltersCore import ( - vtkPolyDataNormals, -) + vtkPolyDataNormals, ) from vtkmodules.vtkFiltersGeometry import ( - vtkDataSetSurfaceFilter, -) + vtkDataSetSurfaceFilter, ) from vtkmodules.vtkFiltersModeling import ( vtkCollisionDetectionFilter, vtkLinearExtrusionFilter, @@ -40,22 +37,21 @@ from . import vtk_utils from .vtk_polyhedron import ( - vtk_iter, -) + vtk_iter, ) from . import triangle_distance -@dataclass(frozen=True) +@dataclass( frozen=True ) class Options: angle_tolerance: float point_tolerance: float face_tolerance: float -@dataclass(frozen=True) +@dataclass( frozen=True ) class Result: - non_conformal_cells: List[Tuple[int, int]] + non_conformal_cells: List[ Tuple[ int, int ] ] class BoundaryMesh: @@ -67,32 +63,36 @@ class BoundaryMesh: Therefore, we reorient the polyhedron cells ourselves, so we're sure that they point outwards. And then we compute the boundary meshes for both meshes, given that the computing options are not identical. """ - def __init__(self, mesh: vtkUnstructuredGrid): + + def __init__( self, mesh: vtkUnstructuredGrid ): """ Builds a boundary mesh. :param mesh: The 3d mesh. """ # Building the boundary meshes - boundary_mesh, __normals, self.__original_cells = BoundaryMesh.__build_boundary_mesh(mesh) - cells_to_reorient = filter(lambda c: mesh.GetCell(c).GetCellType() == VTK_POLYHEDRON, - map(self.__original_cells.GetValue, - range(self.__original_cells.GetNumberOfValues()))) - reoriented_mesh = reorient_mesh(mesh, cells_to_reorient) - self.re_boundary_mesh, re_normals, _ = BoundaryMesh.__build_boundary_mesh(reoriented_mesh, consistency=False) + boundary_mesh, __normals, self.__original_cells = BoundaryMesh.__build_boundary_mesh( mesh ) + cells_to_reorient = filter( + lambda c: mesh.GetCell( c ).GetCellType() == VTK_POLYHEDRON, + map( self.__original_cells.GetValue, range( self.__original_cells.GetNumberOfValues() ) ) ) + reoriented_mesh = reorient_mesh( mesh, cells_to_reorient ) + self.re_boundary_mesh, re_normals, _ = BoundaryMesh.__build_boundary_mesh( reoriented_mesh, consistency=False ) num_cells = boundary_mesh.GetNumberOfCells() # Precomputing the underlying cell type - self.__is_underlying_cell_type_a_polyhedron = numpy.zeros(num_cells, dtype=bool) - for ic in range(num_cells): - self.__is_underlying_cell_type_a_polyhedron[ic] = mesh.GetCell(self.__original_cells.GetValue(ic)).GetCellType() == VTK_POLYHEDRON + self.__is_underlying_cell_type_a_polyhedron = numpy.zeros( num_cells, dtype=bool ) + for ic in range( num_cells ): + self.__is_underlying_cell_type_a_polyhedron[ ic ] = mesh.GetCell( + self.__original_cells.GetValue( ic ) ).GetCellType() == VTK_POLYHEDRON # Precomputing the normals - self.__normals: numpy.ndarray = numpy.empty((num_cells, 3), dtype=numpy.double, order='C') # Do not modify the storage layout - for ic in range(num_cells): - if self.__is_underlying_cell_type_a_polyhedron[ic]: - self.__normals[ic, :] = re_normals.GetTuple3(ic) + self.__normals: numpy.ndarray = numpy.empty( ( num_cells, 3 ), dtype=numpy.double, + order='C' ) # Do not modify the storage layout + for ic in range( num_cells ): + if self.__is_underlying_cell_type_a_polyhedron[ ic ]: + self.__normals[ ic, : ] = re_normals.GetTuple3( ic ) else: - self.__normals[ic, :] = __normals.GetTuple3(ic) + self.__normals[ ic, : ] = __normals.GetTuple3( ic ) + @staticmethod - def __build_boundary_mesh(mesh: vtkUnstructuredGrid, consistency=True) -> Tuple[vtkUnstructuredGrid, Any, Any]: + def __build_boundary_mesh( mesh: vtkUnstructuredGrid, consistency=True ) -> Tuple[ vtkUnstructuredGrid, Any, Any ]: """ From a 3d mesh, build the envelope meshes. :param mesh: The input 3d mesh. @@ -107,76 +107,76 @@ def __build_boundary_mesh(mesh: vtkUnstructuredGrid, consistency=True) -> Tuple[ # Note that we do not need the original points, but we could keep them as well if needed original_cells_key = "ORIGINAL_CELLS" - f.SetOriginalCellIdsName(original_cells_key) + f.SetOriginalCellIdsName( original_cells_key ) boundary_mesh = vtkPolyData() - f.UnstructuredGridExecute(mesh, boundary_mesh) + f.UnstructuredGridExecute( mesh, boundary_mesh ) n = vtkPolyDataNormals() - n.SetConsistency(consistency) - n.SetAutoOrientNormals(consistency) + n.SetConsistency( consistency ) + n.SetAutoOrientNormals( consistency ) n.FlipNormalsOff() n.ComputeCellNormalsOn() - n.SetInputData(boundary_mesh) + n.SetInputData( boundary_mesh ) n.Update() - normals = n.GetOutput().GetCellData().GetArray("Normals") + normals = n.GetOutput().GetCellData().GetArray( "Normals" ) assert normals assert normals.GetNumberOfComponents() == 3 assert normals.GetNumberOfTuples() == boundary_mesh.GetNumberOfCells() - original_cells = boundary_mesh.GetCellData().GetArray(original_cells_key) + original_cells = boundary_mesh.GetCellData().GetArray( original_cells_key ) assert original_cells return boundary_mesh, normals, original_cells - def GetNumberOfCells(self) -> int: + def GetNumberOfCells( self ) -> int: """ The number of cells. :return: An integer. """ return self.re_boundary_mesh.GetNumberOfCells() - def GetNumberOfPoints(self) -> int: + def GetNumberOfPoints( self ) -> int: """ The number of points. :return: An integer. """ return self.re_boundary_mesh.GetNumberOfPoints() - def bounds(self, i) -> Tuple[float, float, float, float, float, float]: + def bounds( self, i ) -> Tuple[ float, float, float, float, float, float ]: """ The boundrary box of cell `i`. :param i: The boundary cell index. :return: The vtk bounding box. """ - return self.re_boundary_mesh.GetCell(i).GetBounds() + return self.re_boundary_mesh.GetCell( i ).GetBounds() - def normals(self, i) -> numpy.ndarray: + def normals( self, i ) -> numpy.ndarray: """ The normal of cell `i`. This normal will be directed outwards :param i: The boundary cell index. :return: The normal as a length-3 numpy array. """ - return self.__normals[i] + return self.__normals[ i ] - def GetCell(self, i) -> vtkCell: + def GetCell( self, i ) -> vtkCell: """ Cell i of the boundary mesh. This cell will have its normal directed outwards. :param i: The boundary cell index. :return: The cell instance. :warning: This member function relies on the vtkUnstructuredGrid.GetCell member function which is not thread safe. """ - return self.re_boundary_mesh.GetCell(i) + return self.re_boundary_mesh.GetCell( i ) - def GetPoint(self, i) -> Tuple[float, float, float]: + def GetPoint( self, i ) -> Tuple[ float, float, float ]: """ Point i of the boundary mesh. :param i: The boundary point index. :return: A length-3 tuple containing the coordinates of the point. :warning: This member function relies on the vtkUnstructuredGrid.GetPoint member function which is not thread safe. """ - return self.re_boundary_mesh.GetPoint(i) + return self.re_boundary_mesh.GetPoint( i ) @property - def original_cells(self): + def original_cells( self ): """ Returns the 2d boundary cell to the 3d cell index of the original mesh. :return: A 1d array. @@ -184,7 +184,7 @@ def original_cells(self): return self.__original_cells -def build_poly_data_for_extrusion(i: int, boundary_mesh: BoundaryMesh) -> vtkPolyData: +def build_poly_data_for_extrusion( i: int, boundary_mesh: BoundaryMesh ) -> vtkPolyData: """ Creates a vtkPolyData containing the unique cell `i` of the boundary mesh. This operation is needed to use the vtk extrusion filter. @@ -192,26 +192,26 @@ def build_poly_data_for_extrusion(i: int, boundary_mesh: BoundaryMesh) -> vtkPol :param boundary_mesh: :return: The created vtkPolyData. """ - cell = boundary_mesh.GetCell(i) + cell = boundary_mesh.GetCell( i ) copied_cell = cell.NewInstance() - copied_cell.DeepCopy(cell) + copied_cell.DeepCopy( cell ) points_ids_mapping = [] - for i in range(copied_cell.GetNumberOfPoints()): - copied_cell.GetPointIds().SetId(i, i) - points_ids_mapping.append(cell.GetPointId(i)) + for i in range( copied_cell.GetNumberOfPoints() ): + copied_cell.GetPointIds().SetId( i, i ) + points_ids_mapping.append( cell.GetPointId( i ) ) polygons = vtkCellArray() - polygons.InsertNextCell(copied_cell) + polygons.InsertNextCell( copied_cell ) points = vtkPoints() - points.SetNumberOfPoints(len(points_ids_mapping)) - for i, v in enumerate(points_ids_mapping): - points.SetPoint(i, boundary_mesh.GetPoint(v)) + points.SetNumberOfPoints( len( points_ids_mapping ) ) + for i, v in enumerate( points_ids_mapping ): + points.SetPoint( i, boundary_mesh.GetPoint( v ) ) polygon_poly_data = vtkPolyData() - polygon_poly_data.SetPoints(points) - polygon_poly_data.SetPolys(polygons) + polygon_poly_data.SetPoints( points ) + polygon_poly_data.SetPolys( polygons ) return polygon_poly_data -def are_points_conformal(point_tolerance: float, cell_i: vtkCell, cell_j: vtkCell) -> bool: +def are_points_conformal( point_tolerance: float, cell_i: vtkCell, cell_j: vtkCell ) -> bool: """ Checks if points of cell `i` matches, one by one, the points of cell `j`. :param point_tolerance: The point tolerance to consider that two points match. @@ -225,16 +225,16 @@ def are_points_conformal(point_tolerance: float, cell_i: vtkCell, cell_j: vtkCel point_locator = vtkStaticPointLocator() points = vtkPointSet() - points.SetPoints(cell_i.GetPoints()) - point_locator.SetDataSet(points) + points.SetPoints( cell_i.GetPoints() ) + point_locator.SetDataSet( points ) point_locator.BuildLocator() found_points = set() - for ip in range(cell_j.GetNumberOfPoints()): - p = cell_j.GetPoints().GetPoint(ip) - squared_dist = vtk_reference(0.) # unused - found_point = point_locator.FindClosestPointWithinRadius(point_tolerance, p, squared_dist) - found_points.add(found_point) - return found_points == set(range(cell_i.GetNumberOfPoints())) + for ip in range( cell_j.GetNumberOfPoints() ): + p = cell_j.GetPoints().GetPoint( ip ) + squared_dist = vtk_reference( 0. ) # unused + found_point = point_locator.FindClosestPointWithinRadius( point_tolerance, p, squared_dist ) + found_points.add( found_point ) + return found_points == set( range( cell_i.GetNumberOfPoints() ) ) class Extruder: @@ -242,12 +242,15 @@ class Extruder: Computes and stores all the extrusions of the boundary faces. The main reason for this class is to be lazy and cache the extrusions. """ - def __init__(self, boundary_mesh: BoundaryMesh, face_tolerance: float): - self.__extrusions: List[vtkPolyData] = [None, ] * boundary_mesh.GetNumberOfCells() + + def __init__( self, boundary_mesh: BoundaryMesh, face_tolerance: float ): + self.__extrusions: List[ vtkPolyData ] = [ + None, + ] * boundary_mesh.GetNumberOfCells() self.__boundary_mesh = boundary_mesh self.__face_tolerance = face_tolerance - def __extrude(self, polygon_poly_data, normal) -> vtkPolyData: + def __extrude( self, polygon_poly_data, normal ) -> vtkPolyData: """ Extrude the polygon data to create a volume that will be used for intersection. :param polygon_poly_data: The data to extrude @@ -256,31 +259,29 @@ def __extrude(self, polygon_poly_data, normal) -> vtkPolyData: """ extruder = vtkLinearExtrusionFilter() extruder.SetExtrusionTypeToVectorExtrusion() - extruder.SetVector(normal) - extruder.SetScaleFactor(self.__face_tolerance / 2.) - extruder.SetInputData(polygon_poly_data) + extruder.SetVector( normal ) + extruder.SetScaleFactor( self.__face_tolerance / 2. ) + extruder.SetInputData( polygon_poly_data ) extruder.Update() return extruder.GetOutput() - def __getitem__(self, i) -> vtkPolyData: + def __getitem__( self, i ) -> vtkPolyData: """ Returns the vtk extrusion for boundary element i. :param i: The cell index. :return: The vtk instance. """ - extrusion = self.__extrusions[i] + extrusion = self.__extrusions[ i ] if extrusion: return extrusion - extrusion = self.__extrude(build_poly_data_for_extrusion(i, self.__boundary_mesh), - self.__boundary_mesh.normals(i)) - self.__extrusions[i] = extrusion + extrusion = self.__extrude( build_poly_data_for_extrusion( i, self.__boundary_mesh ), + self.__boundary_mesh.normals( i ) ) + self.__extrusions[ i ] = extrusion return extrusion -def are_faces_conformal_using_extrusions(extrusions: Extruder, - i: int, j: int, - boundary_mesh: vtkUnstructuredGrid, - point_tolerance: float) -> bool: +def are_faces_conformal_using_extrusions( extrusions: Extruder, i: int, j: int, boundary_mesh: vtkUnstructuredGrid, + point_tolerance: float ) -> bool: """ Tests if two boundary faces are conformal, checking for intersection between their normal extruded volumes. :param extrusions: The extrusions cache. @@ -292,28 +293,27 @@ def are_faces_conformal_using_extrusions(extrusions: Extruder, """ collision = vtkCollisionDetectionFilter() collision.SetCollisionModeToFirstContact() - collision.SetInputData(0, extrusions[i]) - collision.SetInputData(1, extrusions[j]) + collision.SetInputData( 0, extrusions[ i ] ) + collision.SetInputData( 1, extrusions[ j ] ) m_i = vtkTransform() m_j = vtkTransform() - collision.SetTransform(0, m_i) - collision.SetTransform(1, m_j) + collision.SetTransform( 0, m_i ) + collision.SetTransform( 1, m_j ) collision.Update() if collision.GetNumberOfContacts() == 0: return True # Duplicating data not to risk anything w.r.t. thread safety of the GetCell function. - cell_i = boundary_mesh.GetCell(i) + cell_i = boundary_mesh.GetCell( i ) copied_cell_i = cell_i.NewInstance() - copied_cell_i.DeepCopy(cell_i) + copied_cell_i.DeepCopy( cell_i ) - return are_points_conformal(point_tolerance, copied_cell_i, boundary_mesh.GetCell(j)) + return are_points_conformal( point_tolerance, copied_cell_i, boundary_mesh.GetCell( j ) ) -def are_faces_conformal_using_distances(i: int, j: int, - boundary_mesh: vtkUnstructuredGrid, - face_tolerance: float, point_tolerance: float) -> bool: +def are_faces_conformal_using_distances( i: int, j: int, boundary_mesh: vtkUnstructuredGrid, face_tolerance: float, + point_tolerance: float ) -> bool: """ Tests if two boundary faces are conformal, checking the minimal distance between triangulated surfaces. :param i: The cell index of the first cell. @@ -323,42 +323,42 @@ def are_faces_conformal_using_distances(i: int, j: int, :param point_tolerance: The point tolerance to consider that two points match. :return: A boolean. """ - cp_i = boundary_mesh.GetCell(i).NewInstance() - cp_i.DeepCopy(boundary_mesh.GetCell(i)) - cp_j = boundary_mesh.GetCell(j).NewInstance() - cp_j.DeepCopy(boundary_mesh.GetCell(j)) + cp_i = boundary_mesh.GetCell( i ).NewInstance() + cp_i.DeepCopy( boundary_mesh.GetCell( i ) ) + cp_j = boundary_mesh.GetCell( j ).NewInstance() + cp_j.DeepCopy( boundary_mesh.GetCell( j ) ) - def triangulate(cell): + def triangulate( cell ): assert cell.GetCellDimension() == 2 __points_ids = vtkIdList() __points = vtkPoints() - cell.Triangulate(0, __points_ids, __points) - __points_ids = tuple(vtk_iter(__points_ids)) - assert len(__points_ids) % 3 == 0 + cell.Triangulate( 0, __points_ids, __points ) + __points_ids = tuple( vtk_iter( __points_ids ) ) + assert len( __points_ids ) % 3 == 0 assert __points.GetNumberOfPoints() % 3 == 0 return __points_ids, __points - points_ids_i, points_i = triangulate(cp_i) - points_ids_j, points_j = triangulate(cp_j) + points_ids_i, points_i = triangulate( cp_i ) + points_ids_j, points_j = triangulate( cp_j ) - def build_numpy_triangles(points_ids): + def build_numpy_triangles( points_ids ): __triangles = [] - for __i in range(0, len(points_ids), 3): + for __i in range( 0, len( points_ids ), 3 ): __t = [] - for __pi in points_ids[__i: __i + 3]: - __t.append(boundary_mesh.GetPoint(__pi)) - __triangles.append(numpy.array(__t, dtype=float)) + for __pi in points_ids[ __i:__i + 3 ]: + __t.append( boundary_mesh.GetPoint( __pi ) ) + __triangles.append( numpy.array( __t, dtype=float ) ) return __triangles - triangles_i = build_numpy_triangles(points_ids_i) - triangles_j = build_numpy_triangles(points_ids_j) + triangles_i = build_numpy_triangles( points_ids_i ) + triangles_j = build_numpy_triangles( points_ids_j ) min_dist = numpy.inf - for ti, tj in [(ti, tj) for ti in triangles_i for tj in triangles_j]: + for ti, tj in [ ( ti, tj ) for ti in triangles_i for tj in triangles_j ]: # Note that here, we compute the exact distance to compare with the threshold. # We could improve by exiting the iterative distance computation as soon as # we're sure we're smaller than the threshold. No need of the exact solution. - dist, _, _ = triangle_distance.distance_between_two_triangles(ti, tj) + dist, _, _ = triangle_distance.distance_between_two_triangles( ti, tj ) if dist < min_dist: min_dist = dist if min_dist < face_tolerance: @@ -366,67 +366,68 @@ def build_numpy_triangles(points_ids): if min_dist > face_tolerance: return True - return are_points_conformal(point_tolerance, cp_i, cp_j) + return are_points_conformal( point_tolerance, cp_i, cp_j ) -def __check(mesh: vtkUnstructuredGrid, options: Options) -> Result: +def __check( mesh: vtkUnstructuredGrid, options: Options ) -> Result: """ Checks if the mesh is "conformal" (i.e. if some of its boundary faces may not be too close to each other without matching nodes). :param mesh: The vtk mesh :param options: The check options. :return: The Result instance. """ - boundary_mesh = BoundaryMesh(mesh) - cos_theta = abs(math.cos(numpy.deg2rad(options.angle_tolerance))) + boundary_mesh = BoundaryMesh( mesh ) + cos_theta = abs( math.cos( numpy.deg2rad( options.angle_tolerance ) ) ) num_cells = boundary_mesh.GetNumberOfCells() # Computing the exact number of cells per node - num_cells_per_node = numpy.zeros(boundary_mesh.GetNumberOfPoints(), dtype=int) - for ic in range(boundary_mesh.GetNumberOfCells()): - c = boundary_mesh.GetCell(ic) + num_cells_per_node = numpy.zeros( boundary_mesh.GetNumberOfPoints(), dtype=int ) + for ic in range( boundary_mesh.GetNumberOfCells() ): + c = boundary_mesh.GetCell( ic ) point_ids = c.GetPointIds() - for point_id in vtk_iter(point_ids): - num_cells_per_node[point_id] += 1 + for point_id in vtk_iter( point_ids ): + num_cells_per_node[ point_id ] += 1 cell_locator = vtkStaticCellLocator() cell_locator.Initialize() - cell_locator.SetNumberOfCellsPerNode(num_cells_per_node.max()) - cell_locator.SetDataSet(boundary_mesh.re_boundary_mesh) + cell_locator.SetNumberOfCellsPerNode( num_cells_per_node.max() ) + cell_locator.SetDataSet( boundary_mesh.re_boundary_mesh ) cell_locator.BuildLocator() # Precomputing the bounding boxes. # The options are important to directly interact with memory in C++. - bounding_boxes = numpy.empty((boundary_mesh.GetNumberOfCells(), 6), dtype=numpy.double, order="C") - for i in range(boundary_mesh.GetNumberOfCells()): - bb = vtkBoundingBox(boundary_mesh.bounds(i)) - bb.Inflate(2 * options.face_tolerance) - assert bounding_boxes[i, :].data.contiguous # Do not modify the storage layout since vtk deals with raw memory here. - bb.GetBounds(bounding_boxes[i, :]) + bounding_boxes = numpy.empty( ( boundary_mesh.GetNumberOfCells(), 6 ), dtype=numpy.double, order="C" ) + for i in range( boundary_mesh.GetNumberOfCells() ): + bb = vtkBoundingBox( boundary_mesh.bounds( i ) ) + bb.Inflate( 2 * options.face_tolerance ) + assert bounding_boxes[ + i, : ].data.contiguous # Do not modify the storage layout since vtk deals with raw memory here. + bb.GetBounds( bounding_boxes[ i, : ] ) non_conformal_cells = [] - extrusions = Extruder(boundary_mesh, options.face_tolerance) + extrusions = Extruder( boundary_mesh, options.face_tolerance ) close_cells = vtkIdList() # Looping on all the pairs of boundary cells. We'll hopefully discard most of the pairs. - for i in tqdm(range(num_cells), desc="Non conformal elements"): - cell_locator.FindCellsWithinBounds(bounding_boxes[i], close_cells) - for j in vtk_iter(close_cells): + for i in tqdm( range( num_cells ), desc="Non conformal elements" ): + cell_locator.FindCellsWithinBounds( bounding_boxes[ i ], close_cells ) + for j in vtk_iter( close_cells ): if j < i: continue # Discarding pairs that are not facing each others (with a threshold). - normal_i, normal_j = boundary_mesh.normals(i), boundary_mesh.normals(j) - if numpy.dot(normal_i, normal_j) > -cos_theta: # opposite directions only (can be facing or not) + normal_i, normal_j = boundary_mesh.normals( i ), boundary_mesh.normals( j ) + if numpy.dot( normal_i, normal_j ) > -cos_theta: # opposite directions only (can be facing or not) continue # At this point, back-to-back and face-to-face pairs of elements are considered. - if not are_faces_conformal_using_extrusions(extrusions, i, j, boundary_mesh, options.point_tolerance): - non_conformal_cells.append((i, j)) + if not are_faces_conformal_using_extrusions( extrusions, i, j, boundary_mesh, options.point_tolerance ): + non_conformal_cells.append( ( i, j ) ) # Extracting the original 3d element index (and not the index of the boundary mesh). tmp = [] for i, j in non_conformal_cells: - tmp.append((boundary_mesh.original_cells.GetValue(i), boundary_mesh.original_cells.GetValue(j))) + tmp.append( ( boundary_mesh.original_cells.GetValue( i ), boundary_mesh.original_cells.GetValue( j ) ) ) - return Result(non_conformal_cells=tmp) + return Result( non_conformal_cells=tmp ) -def check(vtk_input_file: str, options: Options) -> Result: - mesh = vtk_utils.read_mesh(vtk_input_file) - return __check(mesh, options) +def check( vtk_input_file: str, options: Options ) -> Result: + mesh = vtk_utils.read_mesh( vtk_input_file ) + return __check( mesh, options ) diff --git a/geosx_mesh_doctor/checks/reorient_mesh.py b/geosx_mesh_doctor/checks/reorient_mesh.py index efb664b..a206439 100644 --- a/geosx_mesh_doctor/checks/reorient_mesh.py +++ b/geosx_mesh_doctor/checks/reorient_mesh.py @@ -27,11 +27,9 @@ vtkTetra, ) from vtkmodules.vtkFiltersCore import ( - vtkTriangleFilter, -) + vtkTriangleFilter, ) from .vtk_utils import ( - to_vtk_id_list, -) + to_vtk_id_list, ) from .vtk_polyhedron import ( FaceStream, @@ -39,7 +37,7 @@ ) -def __compute_volume(mesh_points: vtkPoints, face_stream: FaceStream) -> float: +def __compute_volume( mesh_points: vtkPoints, face_stream: FaceStream ) -> float: """ Computes the volume of a polyhedron element (defined by its face_stream). :param mesh_points: The mesh points, needed to compute the volume. @@ -54,44 +52,43 @@ def __compute_volume(mesh_points: vtkPoints, face_stream: FaceStream) -> float: polygons = vtkCellArray() for face_nodes in face_stream.face_nodes: polygon = vtkPolygon() - polygon.GetPointIds().SetNumberOfIds(len(face_nodes)) + polygon.GetPointIds().SetNumberOfIds( len( face_nodes ) ) # We use the same global points numbering for the polygons than for the input mesh. # There will be a lot of points in the poly data that won't be used as a support for the polygons. # But the algorithm deals with it, and it's actually faster (and easier) to do this # than to renumber and allocate a new fit-for-purpose set of points just for the polygons. - for i, point_id in enumerate(face_nodes): - polygon.GetPointIds().SetId(i, point_id) - polygons.InsertNextCell(polygon) + for i, point_id in enumerate( face_nodes ): + polygon.GetPointIds().SetId( i, point_id ) + polygons.InsertNextCell( polygon ) polygon_poly_data = vtkPolyData() - polygon_poly_data.SetPoints(mesh_points) - polygon_poly_data.SetPolys(polygons) + polygon_poly_data.SetPoints( mesh_points ) + polygon_poly_data.SetPolys( polygons ) f = vtkTriangleFilter() - f.SetInputData(polygon_poly_data) + f.SetInputData( polygon_poly_data ) f.Update() triangles = f.GetOutput() # Computing the barycenter that will be used as the tip of all the tetra which mesh the polyhedron. # (The basis of all the tetra being the triangles of the envelope). # We could take any point, not only the barycenter. # But in order to work with figure of the same magnitude, let's compute the barycenter. - tmp_barycenter = numpy.empty((face_stream.num_support_points, 3), dtype=float) - for i, point_id in enumerate(face_stream.support_point_ids): - tmp_barycenter[i, :] = mesh_points.GetPoint(point_id) - barycenter = tmp_barycenter[:, 0].mean(), tmp_barycenter[:, 1].mean(), tmp_barycenter[:, 2].mean() + tmp_barycenter = numpy.empty( ( face_stream.num_support_points, 3 ), dtype=float ) + for i, point_id in enumerate( face_stream.support_point_ids ): + tmp_barycenter[ i, : ] = mesh_points.GetPoint( point_id ) + barycenter = tmp_barycenter[ :, 0 ].mean(), tmp_barycenter[ :, 1 ].mean(), tmp_barycenter[ :, 2 ].mean() # Looping on all the triangles of the envelope of the polyhedron, creating the matching tetra. # Then the volume of all the tetra are added to get the final polyhedron volume. cell_volume = 0. - for i in range(triangles.GetNumberOfCells()): - triangle = triangles.GetCell(i) + for i in range( triangles.GetNumberOfCells() ): + triangle = triangles.GetCell( i ) assert triangle.GetCellType() == VTK_TRIANGLE p = triangle.GetPoints() - cell_volume += vtkTetra.ComputeVolume(barycenter, p.GetPoint(0), p.GetPoint(1), p.GetPoint(2)) + cell_volume += vtkTetra.ComputeVolume( barycenter, p.GetPoint( 0 ), p.GetPoint( 1 ), p.GetPoint( 2 ) ) return cell_volume -def __select_and_flip_faces(mesh_points: vtkPoints, - colors: Dict[FrozenSet[int], int], - face_stream: FaceStream) -> FaceStream: +def __select_and_flip_faces( mesh_points: vtkPoints, colors: Dict[ FrozenSet[ int ], int ], + face_stream: FaceStream ) -> FaceStream: """ Given a polyhedra, given that we were able to paint the faces in two colors, we now need to select which faces/color to flip such that the volume of the element is positive. @@ -101,46 +98,47 @@ def __select_and_flip_faces(mesh_points: vtkPoints, :return: The face stream that leads to a positive volume. """ # Flipping either color 0 or 1. - color_to_nodes: Dict[int, List[int]] = {0: [], 1: []} + color_to_nodes: Dict[ int, List[ int ] ] = { 0: [], 1: [] } for connected_components_indices, color in colors.items(): - color_to_nodes[color] += connected_components_indices + color_to_nodes[ color ] += connected_components_indices # This implementation works even if there is one unique color. # Admittedly, there will be one face stream that won't be flipped. - fs: Tuple[FaceStream, FaceStream] = face_stream.flip_faces(color_to_nodes[0]), face_stream.flip_faces(color_to_nodes[1]) - volumes = __compute_volume(mesh_points, fs[0]), __compute_volume(mesh_points, fs[1]) + fs: Tuple[ FaceStream, FaceStream ] = ( face_stream.flip_faces( color_to_nodes[ 0 ] ), + face_stream.flip_faces( color_to_nodes[ 1 ] ) ) + volumes = __compute_volume( mesh_points, fs[ 0 ] ), __compute_volume( mesh_points, fs[ 1 ] ) # We keep the flipped element for which the volume is largest # (i.e. positive, since they should be the opposite of each other). - return fs[numpy.argmax(volumes)] + return fs[ numpy.argmax( volumes ) ] -def __reorient_element(mesh_points: vtkPoints, face_stream_ids: vtkIdList) -> vtkIdList: +def __reorient_element( mesh_points: vtkPoints, face_stream_ids: vtkIdList ) -> vtkIdList: """ Considers a vtk face stream and flips the appropriate faces to get an element with normals directed outwards. :param mesh_points: The mesh points, needed to compute the volume. :param face_stream_ids: The raw vtk face stream, not converted into a more practical python class. :return: The raw vtk face stream with faces properly flipped. """ - face_stream = FaceStream.build_from_vtk_id_list(face_stream_ids) - face_graph = build_face_to_face_connectivity_through_edges(face_stream, add_compatibility=True) + face_stream = FaceStream.build_from_vtk_id_list( face_stream_ids ) + face_graph = build_face_to_face_connectivity_through_edges( face_stream, add_compatibility=True ) # Removing the non-compatible connections to build the non-connected components. g = networkx.Graph() - g.add_nodes_from(face_graph.nodes) - g.add_edges_from(filter(lambda uvd: uvd[2]["compatible"] == "+", face_graph.edges(data=True))) - connected_components = tuple(networkx.connected_components(g)) + g.add_nodes_from( face_graph.nodes ) + g.add_edges_from( filter( lambda uvd: uvd[ 2 ][ "compatible" ] == "+", face_graph.edges( data=True ) ) ) + connected_components = tuple( networkx.connected_components( g ) ) # Squashing all the connected nodes that need to receive the normal direction flip (or not) together. - quotient_graph = networkx.algorithms.quotient_graph(face_graph, connected_components) + quotient_graph = networkx.algorithms.quotient_graph( face_graph, connected_components ) # Coloring the new graph lets us know how which cluster of faces need to eventually receive the same flip. # W.r.t. the nature of our problem (a normal can be directed inwards or outwards), # two colors should be enough to color the face graph. # `colors` maps the nodes of each connected component to its color. - colors: Dict[FrozenSet[int], int] = networkx.algorithms.greedy_color(quotient_graph) - assert len(colors) in (1, 2) + colors: Dict[ FrozenSet[ int ], int ] = networkx.algorithms.greedy_color( quotient_graph ) + assert len( colors ) in ( 1, 2 ) # We now compute the face stream which generates outwards normal vectors. - flipped_face_stream = __select_and_flip_faces(mesh_points, colors, face_stream) - return to_vtk_id_list(flipped_face_stream.dump()) + flipped_face_stream = __select_and_flip_faces( mesh_points, colors, face_stream ) + return to_vtk_id_list( flipped_face_stream.dump() ) -def reorient_mesh(mesh, cell_indices: Iterator[int]) -> vtkUnstructuredGrid: +def reorient_mesh( mesh, cell_indices: Iterator[ int ] ) -> vtkUnstructuredGrid: """ Reorient the polyhedron elements such that they all have their normals directed outwards. :param mesh: The input vtk mesh. @@ -149,30 +147,31 @@ def reorient_mesh(mesh, cell_indices: Iterator[int]) -> vtkUnstructuredGrid: """ num_cells = mesh.GetNumberOfCells() # Building an indicator/predicate from the list - needs_to_be_reoriented = numpy.zeros(num_cells, dtype=bool) + needs_to_be_reoriented = numpy.zeros( num_cells, dtype=bool ) for ic in cell_indices: - needs_to_be_reoriented[ic] = True + needs_to_be_reoriented[ ic ] = True output_mesh = mesh.NewInstance() # I did not manage to call `output_mesh.CopyStructure(mesh)` because I could not modify the polyhedron in place. # Therefore, I insert the cells one by one... - output_mesh.SetPoints(mesh.GetPoints()) - logging.info("Reorienting the polyhedron cells to enforce normals directed outward.") - with tqdm(total=needs_to_be_reoriented.sum(), desc="Reorienting polyhedra") as progress_bar: # For smoother progress, we only update on reoriented elements. - for ic in range(num_cells): - cell = mesh.GetCell(ic) + output_mesh.SetPoints( mesh.GetPoints() ) + logging.info( "Reorienting the polyhedron cells to enforce normals directed outward." ) + with tqdm( total=needs_to_be_reoriented.sum(), desc="Reorienting polyhedra" + ) as progress_bar: # For smoother progress, we only update on reoriented elements. + for ic in range( num_cells ): + cell = mesh.GetCell( ic ) cell_type = cell.GetCellType() if cell_type == VTK_POLYHEDRON: face_stream_ids = vtkIdList() - mesh.GetFaceStream(ic, face_stream_ids) - if needs_to_be_reoriented[ic]: - new_face_stream_ids = __reorient_element(mesh.GetPoints(), face_stream_ids) + mesh.GetFaceStream( ic, face_stream_ids ) + if needs_to_be_reoriented[ ic ]: + new_face_stream_ids = __reorient_element( mesh.GetPoints(), face_stream_ids ) else: new_face_stream_ids = face_stream_ids - output_mesh.InsertNextCell(VTK_POLYHEDRON, new_face_stream_ids) + output_mesh.InsertNextCell( VTK_POLYHEDRON, new_face_stream_ids ) else: - output_mesh.InsertNextCell(cell_type, cell.GetPointIds()) - if needs_to_be_reoriented[ic]: - progress_bar.update(1) + output_mesh.InsertNextCell( cell_type, cell.GetPointIds() ) + if needs_to_be_reoriented[ ic ]: + progress_bar.update( 1 ) assert output_mesh.GetNumberOfCells() == mesh.GetNumberOfCells() return output_mesh diff --git a/geosx_mesh_doctor/checks/self_intersecting_elements.py b/geosx_mesh_doctor/checks/self_intersecting_elements.py index 0e98d4f..2b6c8e0 100644 --- a/geosx_mesh_doctor/checks/self_intersecting_elements.py +++ b/geosx_mesh_doctor/checks/self_intersecting_elements.py @@ -5,40 +5,34 @@ List, ) -from vtkmodules.vtkFiltersGeneral import ( - vtkCellValidator -) -from vtkmodules.vtkCommonCore import ( - vtkOutputWindow, - vtkFileOutputWindow -) +from vtkmodules.vtkFiltersGeneral import ( vtkCellValidator ) +from vtkmodules.vtkCommonCore import ( vtkOutputWindow, vtkFileOutputWindow ) from vtkmodules.util.numpy_support import ( - vtk_to_numpy, -) + vtk_to_numpy, ) from . import vtk_utils -@dataclass(frozen=True) +@dataclass( frozen=True ) class Options: tolerance: float -@dataclass(frozen=True) +@dataclass( frozen=True ) class Result: - wrong_number_of_points_elements: Collection[int] - intersecting_edges_elements: Collection[int] - intersecting_faces_elements: Collection[int] - non_contiguous_edges_elements: Collection[int] - non_convex_elements: Collection[int] - faces_are_oriented_incorrectly_elements: Collection[int] + wrong_number_of_points_elements: Collection[ int ] + intersecting_edges_elements: Collection[ int ] + intersecting_faces_elements: Collection[ int ] + non_contiguous_edges_elements: Collection[ int ] + non_convex_elements: Collection[ int ] + faces_are_oriented_incorrectly_elements: Collection[ int ] -def __check(mesh, options: Options) -> Result: +def __check( mesh, options: Options ) -> Result: err_out = vtkFileOutputWindow() - err_out.SetFileName("/dev/null") # vtkCellValidator outputs loads for each cell... + err_out.SetFileName( "/dev/null" ) # vtkCellValidator outputs loads for each cell... vtk_std_err_out = vtkOutputWindow() - vtk_std_err_out.SetInstance(err_out) + vtk_std_err_out.SetInstance( err_out ) valid = 0x0 wrong_number_of_points = 0x01 @@ -48,45 +42,45 @@ def __check(mesh, options: Options) -> Result: non_convex = 0x10 faces_are_oriented_incorrectly = 0x20 - wrong_number_of_points_elements: List[int] = [] - intersecting_edges_elements: List[int] = [] - intersecting_faces_elements: List[int] = [] - non_contiguous_edges_elements: List[int] = [] - non_convex_elements: List[int] = [] - faces_are_oriented_incorrectly_elements: List[int] = [] + wrong_number_of_points_elements: List[ int ] = [] + intersecting_edges_elements: List[ int ] = [] + intersecting_faces_elements: List[ int ] = [] + non_contiguous_edges_elements: List[ int ] = [] + non_convex_elements: List[ int ] = [] + faces_are_oriented_incorrectly_elements: List[ int ] = [] f = vtkCellValidator() - f.SetTolerance(options.tolerance) + f.SetTolerance( options.tolerance ) - f.SetInputData(mesh) + f.SetInputData( mesh ) f.Update() output = f.GetOutput() - validity = output.GetCellData().GetArray("ValidityState") # Could not change name using the vtk interface. + validity = output.GetCellData().GetArray( "ValidityState" ) # Could not change name using the vtk interface. assert validity is not None - validity = vtk_to_numpy(validity) - for i, v in enumerate(validity): + validity = vtk_to_numpy( validity ) + for i, v in enumerate( validity ): if not v & valid: if v & wrong_number_of_points: - wrong_number_of_points_elements.append(i) + wrong_number_of_points_elements.append( i ) if v & intersecting_edges: - intersecting_edges_elements.append(i) + intersecting_edges_elements.append( i ) if v & intersecting_faces: - intersecting_faces_elements.append(i) + intersecting_faces_elements.append( i ) if v & non_contiguous_edges: - non_contiguous_edges_elements.append(i) + non_contiguous_edges_elements.append( i ) if v & non_convex: - non_convex_elements.append(i) + non_convex_elements.append( i ) if v & faces_are_oriented_incorrectly: - faces_are_oriented_incorrectly_elements.append(i) - return Result(wrong_number_of_points_elements=wrong_number_of_points_elements, - intersecting_edges_elements=intersecting_edges_elements, - intersecting_faces_elements=intersecting_faces_elements, - non_contiguous_edges_elements=non_contiguous_edges_elements, - non_convex_elements=non_convex_elements, - faces_are_oriented_incorrectly_elements=faces_are_oriented_incorrectly_elements) + faces_are_oriented_incorrectly_elements.append( i ) + return Result( wrong_number_of_points_elements=wrong_number_of_points_elements, + intersecting_edges_elements=intersecting_edges_elements, + intersecting_faces_elements=intersecting_faces_elements, + non_contiguous_edges_elements=non_contiguous_edges_elements, + non_convex_elements=non_convex_elements, + faces_are_oriented_incorrectly_elements=faces_are_oriented_incorrectly_elements ) -def check(vtk_input_file: str, options: Options) -> Result: - mesh = vtk_utils.read_mesh(vtk_input_file) - return __check(mesh, options) +def check( vtk_input_file: str, options: Options ) -> Result: + mesh = vtk_utils.read_mesh( vtk_input_file ) + return __check( mesh, options ) diff --git a/geosx_mesh_doctor/checks/supported_elements.py b/geosx_mesh_doctor/checks/supported_elements.py index 84c5fcb..755d240 100644 --- a/geosx_mesh_doctor/checks/supported_elements.py +++ b/geosx_mesh_doctor/checks/supported_elements.py @@ -17,8 +17,7 @@ import numpy from vtkmodules.vtkCommonCore import ( - vtkIdList, -) + vtkIdList, ) from vtkmodules.vtkCommonDataModel import ( vtkCellTypes, vtkUnstructuredGrid, @@ -32,80 +31,82 @@ VTK_WEDGE, ) from vtkmodules.util.numpy_support import ( - vtk_to_numpy, -) + vtk_to_numpy, ) from . import vtk_utils from .vtk_utils import vtk_iter from .vtk_polyhedron import build_face_to_face_connectivity_through_edges, FaceStream -@dataclass(frozen=True) +@dataclass( frozen=True ) class Options: num_proc: int chunk_size: int -@dataclass(frozen=True) +@dataclass( frozen=True ) class Result: - unsupported_std_elements_types: FrozenSet[int] # list of unsupported types - unsupported_polyhedron_elements: FrozenSet[int] # list of polyhedron elements that could not be converted to supported std elements + unsupported_std_elements_types: FrozenSet[ int ] # list of unsupported types + unsupported_polyhedron_elements: FrozenSet[ + int ] # list of polyhedron elements that could not be converted to supported std elements -MESH: Optional[vtkUnstructuredGrid] = None # for multiprocessing, vtkUnstructuredGrid cannot be pickled. Let's use a global variable instead. +MESH: Optional[ + vtkUnstructuredGrid ] = None # for multiprocessing, vtkUnstructuredGrid cannot be pickled. Let's use a global variable instead. class IsPolyhedronConvertible: - def __init__(self, mesh: vtkUnstructuredGrid): + + def __init__( self, mesh: vtkUnstructuredGrid ): global MESH # for multiprocessing, vtkUnstructuredGrid cannot be pickled. Let's use a global variable instead. MESH = mesh - def build_prism_graph(n: int, name: str) -> networkx.Graph: + def build_prism_graph( n: int, name: str ) -> networkx.Graph: """ Builds the face to face connectivities (through edges) for prism graphs. :param n: The number of nodes of the basis (i.e. the pentagonal prims gets n = 5) :param name: A human-readable name for logging purpose. :return: A graph instance. """ - tmp = networkx.cycle_graph(n) - for node in range(n): - tmp.add_edge(node, n) - tmp.add_edge(node, n + 1) + tmp = networkx.cycle_graph( n ) + for node in range( n ): + tmp.add_edge( node, n ) + tmp.add_edge( node, n + 1 ) tmp.name = name return tmp # Building the reference graphs - tet_graph = networkx.complete_graph(4) + tet_graph = networkx.complete_graph( 4 ) tet_graph.name = "Tetrahedron" - pyr_graph = build_prism_graph(4, "Pyramid") - pyr_graph.remove_node(5) # Removing a node also removes its associated edges. - self.__reference_graphs: Mapping[int, Iterable[networkx.Graph]] = { - 4: (tet_graph,), - 5: (pyr_graph, build_prism_graph(3, "Wedge")), - 6: (build_prism_graph(4, "Hexahedron"),), - 7: (build_prism_graph(5, "Prism5"),), - 8: (build_prism_graph(6, "Prism6"),), - 9: (build_prism_graph(7, "Prism7"),), - 10: (build_prism_graph(8, "Prism8"),), - 11: (build_prism_graph(9, "Prism9"),), - 12: (build_prism_graph(10, "Prism10"),), - 13: (build_prism_graph(11, "Prism11"),), + pyr_graph = build_prism_graph( 4, "Pyramid" ) + pyr_graph.remove_node( 5 ) # Removing a node also removes its associated edges. + self.__reference_graphs: Mapping[ int, Iterable[ networkx.Graph ] ] = { + 4: ( tet_graph, ), + 5: ( pyr_graph, build_prism_graph( 3, "Wedge" ) ), + 6: ( build_prism_graph( 4, "Hexahedron" ), ), + 7: ( build_prism_graph( 5, "Prism5" ), ), + 8: ( build_prism_graph( 6, "Prism6" ), ), + 9: ( build_prism_graph( 7, "Prism7" ), ), + 10: ( build_prism_graph( 8, "Prism8" ), ), + 11: ( build_prism_graph( 9, "Prism9" ), ), + 12: ( build_prism_graph( 10, "Prism10" ), ), + 13: ( build_prism_graph( 11, "Prism11" ), ), } - def __is_polyhedron_supported(self, face_stream) -> str: + def __is_polyhedron_supported( self, face_stream ) -> str: """ Checks if a polyhedron can be converted into a supported cell. If so, returns the name of the type. If not, the returned name will be empty. :param face_stream: The polyhedron. :return: The name of the supported type or an empty string. """ - cell_graph = build_face_to_face_connectivity_through_edges(face_stream, add_compatibility=True) - for reference_graph in self.__reference_graphs[cell_graph.order()]: - if networkx.is_isomorphic(reference_graph, cell_graph): - return str(reference_graph.name) + cell_graph = build_face_to_face_connectivity_through_edges( face_stream, add_compatibility=True ) + for reference_graph in self.__reference_graphs[ cell_graph.order() ]: + if networkx.is_isomorphic( reference_graph, cell_graph ): + return str( reference_graph.name ) return "" - def __call__(self, ic: int) -> int: + def __call__( self, ic: int ) -> int: """ Checks if a vtk polyhedron cell can be converted into a supported GEOSX element. :param ic: The index element. @@ -113,51 +114,47 @@ def __call__(self, ic: int) -> int: """ global MESH assert MESH is not None - if MESH.GetCellType(ic) != VTK_POLYHEDRON: + if MESH.GetCellType( ic ) != VTK_POLYHEDRON: return -1 pt_ids = vtkIdList() - MESH.GetFaceStream(ic, pt_ids) - face_stream = FaceStream.build_from_vtk_id_list(pt_ids) - converted_type_name = self.__is_polyhedron_supported(face_stream) + MESH.GetFaceStream( ic, pt_ids ) + face_stream = FaceStream.build_from_vtk_id_list( pt_ids ) + converted_type_name = self.__is_polyhedron_supported( face_stream ) if converted_type_name: - logging.debug(f"Polyhedron cell {ic} can be converted into \"{converted_type_name}\"") + logging.debug( f"Polyhedron cell {ic} can be converted into \"{converted_type_name}\"" ) return -1 else: - logging.debug(f"Polyhedron cell {ic} cannot be converted into any supported element.") + logging.debug( f"Polyhedron cell {ic} cannot be converted into any supported element." ) return ic -def __check(mesh: vtkUnstructuredGrid, options: Options) -> Result: - if hasattr(mesh, "GetDistinctCellTypesArray"): # For more recent versions of vtk. - cell_types = set(vtk_to_numpy(mesh.GetDistinctCellTypesArray())) +def __check( mesh: vtkUnstructuredGrid, options: Options ) -> Result: + if hasattr( mesh, "GetDistinctCellTypesArray" ): # For more recent versions of vtk. + cell_types = set( vtk_to_numpy( mesh.GetDistinctCellTypesArray() ) ) else: cell_types = vtkCellTypes() - mesh.GetCellTypes(cell_types) - cell_types = set(vtk_iter(cell_types)) + mesh.GetCellTypes( cell_types ) + cell_types = set( vtk_iter( cell_types ) ) supported_cell_types = { - VTK_HEXAGONAL_PRISM, - VTK_HEXAHEDRON, - VTK_PENTAGONAL_PRISM, - VTK_POLYHEDRON, - VTK_PYRAMID, - VTK_TETRA, - VTK_VOXEL, + VTK_HEXAGONAL_PRISM, VTK_HEXAHEDRON, VTK_PENTAGONAL_PRISM, VTK_POLYHEDRON, VTK_PYRAMID, VTK_TETRA, VTK_VOXEL, VTK_WEDGE } unsupported_std_elements_types = cell_types - supported_cell_types # Dealing with polyhedron elements. num_cells = mesh.GetNumberOfCells() - result = numpy.ones(num_cells, dtype=int) * -1 - with multiprocessing.Pool(processes=options.num_proc) as pool: - generator = pool.imap_unordered(IsPolyhedronConvertible(mesh), range(num_cells), chunksize=options.chunk_size) - for i, val in enumerate(tqdm(generator, total=num_cells, desc="Testing support for elements")): - result[i] = val - unsupported_polyhedron_elements = [i for i in result if i > -1] - return Result(unsupported_std_elements_types=frozenset(unsupported_std_elements_types), - unsupported_polyhedron_elements=frozenset(unsupported_polyhedron_elements)) - - -def check(vtk_input_file: str, options: Options) -> Result: - mesh: vtkUnstructuredGrid = vtk_utils.read_mesh(vtk_input_file) - return __check(mesh, options) + result = numpy.ones( num_cells, dtype=int ) * -1 + with multiprocessing.Pool( processes=options.num_proc ) as pool: + generator = pool.imap_unordered( IsPolyhedronConvertible( mesh ), + range( num_cells ), + chunksize=options.chunk_size ) + for i, val in enumerate( tqdm( generator, total=num_cells, desc="Testing support for elements" ) ): + result[ i ] = val + unsupported_polyhedron_elements = [ i for i in result if i > -1 ] + return Result( unsupported_std_elements_types=frozenset( unsupported_std_elements_types ), + unsupported_polyhedron_elements=frozenset( unsupported_polyhedron_elements ) ) + + +def check( vtk_input_file: str, options: Options ) -> Result: + mesh: vtkUnstructuredGrid = vtk_utils.read_mesh( vtk_input_file ) + return __check( mesh, options ) diff --git a/geosx_mesh_doctor/checks/triangle_distance.py b/geosx_mesh_doctor/checks/triangle_distance.py index ef1f3c9..dbee0a5 100644 --- a/geosx_mesh_doctor/checks/triangle_distance.py +++ b/geosx_mesh_doctor/checks/triangle_distance.py @@ -6,7 +6,7 @@ from numpy.linalg import norm -def __div_clamp(num: float, den :float) -> float: +def __div_clamp( num: float, den: float ) -> float: """ Computes the division `num / den`. and clamps the result between 0 and 1. If `den` is zero, the result of the division is set to 0. @@ -25,8 +25,8 @@ def __div_clamp(num: float, den :float) -> float: return tmp -def distance_between_two_segments(x0: numpy.ndarray, d0: numpy.ndarray, - x1: numpy.ndarray, d1: numpy.ndarray) -> Tuple[numpy.ndarray, numpy.ndarray]: +def distance_between_two_segments( x0: numpy.ndarray, d0: numpy.ndarray, x1: numpy.ndarray, + d1: numpy.ndarray ) -> Tuple[ numpy.ndarray, numpy.ndarray ]: """ Compute the minimum distance between two segments. :param x0: First point of segment 0. @@ -41,11 +41,11 @@ def distance_between_two_segments(x0: numpy.ndarray, d0: numpy.ndarray, # In the reference, the indices start at 1, while in this implementation, they start at 0. tmp: numpy.ndarray = x1 - x0 - D0: float = numpy.dot(d0, d0) # As such, this is D1 in the reference paper. - D1: float = numpy.dot(d1, d1) - R: float = numpy.dot(d0, d1) - S0: float = numpy.dot(d0, tmp) - S1: float = numpy.dot(d1, tmp) + D0: float = numpy.dot( d0, d0 ) # As such, this is D1 in the reference paper. + D1: float = numpy.dot( d1, d1 ) + R: float = numpy.dot( d0, d1 ) + S0: float = numpy.dot( d0, tmp ) + S1: float = numpy.dot( d1, tmp ) # `t0` parameterizes line 0: # - when t0 = 0 the point is p0. @@ -56,18 +56,20 @@ def distance_between_two_segments(x0: numpy.ndarray, d0: numpy.ndarray, # They'll be considered along the line using `div_clamp`. # Step 2: Computing t0 using eq (11). - t0: float = __div_clamp(S0 * D1 - S1 * R, D0 * D1 - R * R) + t0: float = __div_clamp( S0 * D1 - S1 * R, D0 * D1 - R * R ) # Step 3: compute t1 for point on line 1 closest to point at t0. - t1: float = __div_clamp(t0 * R - S1, D1) # Eq (10, right) - sol_1: numpy.ndarray = x1 + t1 * d1 # Eq (3) - t0: float = __div_clamp(t1 * R + S0, D0) # Eq (10, left) - sol_0: numpy.ndarray = x0 + t0 * d0 # Eq (4) + t1: float = __div_clamp( t0 * R - S1, D1 ) # Eq (10, right) + sol_1: numpy.ndarray = x1 + t1 * d1 # Eq (3) + t0: float = __div_clamp( t1 * R + S0, D0 ) # Eq (10, left) + sol_0: numpy.ndarray = x0 + t0 * d0 # Eq (4) return sol_0, sol_1 -def __compute_nodes_to_triangle_distance(tri_0, edges_0, tri_1) -> Tuple[Union[float, None], Union[numpy.ndarray, None], Union[numpy.ndarray, None], bool]: +def __compute_nodes_to_triangle_distance( + tri_0, edges_0, + tri_1 ) -> Tuple[ Union[ float, None ], Union[ numpy.ndarray, None ], Union[ numpy.ndarray, None ], bool ]: """ Computes the distance from nodes of `tri_1` points onto `tri_0`. :param tri_0: First triangle. @@ -78,42 +80,42 @@ def __compute_nodes_to_triangle_distance(tri_0, edges_0, tri_1) -> Tuple[Union[f then the first three arguments are None. The boolean being still defined. """ are_disjoint: bool = False - tri_0_normal: numpy.ndarray = numpy.cross(edges_0[0], edges_0[1]) - tri_0_normal_norm: float = numpy.dot(tri_0_normal, tri_0_normal) + tri_0_normal: numpy.ndarray = numpy.cross( edges_0[ 0 ], edges_0[ 1 ] ) + tri_0_normal_norm: float = numpy.dot( tri_0_normal, tri_0_normal ) # Forget about degenerate cases. - if tri_0_normal_norm > numpy.finfo(float).eps: + if tri_0_normal_norm > numpy.finfo( float ).eps: # Build projection lengths of `tri_1` points. - tri_1_proj = numpy.empty(3, dtype=float) - for i in range(3): - tri_1_proj[i] = numpy.dot(tri_0[0] - tri_1[i], tri_0_normal) + tri_1_proj = numpy.empty( 3, dtype=float ) + for i in range( 3 ): + tri_1_proj[ i ] = numpy.dot( tri_0[ 0 ] - tri_1[ i ], tri_0_normal ) # Considering `tri_0` separates the space in 2, # let's check if `tri_1` is on one side only. # If so, let's take the closest point. point: int = -1 - if numpy.all(tri_1_proj > 0): - point = numpy.argmin(tri_1_proj) - elif numpy.all(tri_1_proj < 0): - point = numpy.argmax(tri_1_proj) + if numpy.all( tri_1_proj > 0 ): + point = numpy.argmin( tri_1_proj ) + elif numpy.all( tri_1_proj < 0 ): + point = numpy.argmax( tri_1_proj ) # So if `tri_1` is actually "on one side", # point `tri_1[point]` is candidate to be the closest point. if point > -1: are_disjoint = True # But we must check that its projection is inside `tri_0`. - if numpy.dot(tri_1[point] - tri_0[0], numpy.cross(tri_0_normal, edges_0[0])) > 0: - if numpy.dot(tri_1[point] - tri_0[1], numpy.cross(tri_0_normal, edges_0[1])) > 0: - if numpy.dot(tri_1[point] - tri_0[2], numpy.cross(tri_0_normal, edges_0[2])) > 0: + if numpy.dot( tri_1[ point ] - tri_0[ 0 ], numpy.cross( tri_0_normal, edges_0[ 0 ] ) ) > 0: + if numpy.dot( tri_1[ point ] - tri_0[ 1 ], numpy.cross( tri_0_normal, edges_0[ 1 ] ) ) > 0: + if numpy.dot( tri_1[ point ] - tri_0[ 2 ], numpy.cross( tri_0_normal, edges_0[ 2 ] ) ) > 0: # It is! - sol_0 = tri_1[point] - sol_1 = tri_1[point] + (tri_1_proj[point] / tri_0_normal_norm) * tri_0_normal - return norm(sol_1 - sol_0), sol_0, sol_1, are_disjoint + sol_0 = tri_1[ point ] + sol_1 = tri_1[ point ] + ( tri_1_proj[ point ] / tri_0_normal_norm ) * tri_0_normal + return norm( sol_1 - sol_0 ), sol_0, sol_1, are_disjoint return None, None, None, are_disjoint -def distance_between_two_triangles(tri_0: numpy.ndarray, - tri_1: numpy.ndarray) -> Tuple[float, numpy.ndarray, numpy.ndarray]: +def distance_between_two_triangles( tri_0: numpy.ndarray, + tri_1: numpy.ndarray ) -> Tuple[ float, numpy.ndarray, numpy.ndarray ]: """ Returns the minimum distance between two triangles, and the two points where this minimum occurs. If the two triangles touch, then distance is exactly 0. @@ -123,39 +125,39 @@ def distance_between_two_triangles(tri_0: numpy.ndarray, :return: The distance and the two points. """ # Compute vectors along the 6 sides - edges_0 = numpy.empty((3, 3), dtype=float) - edges_1 = numpy.empty((3, 3), dtype=float) - for i in range(3): - edges_0[i][:] = tri_0[(i + 1) % 3] - tri_0[i] - edges_1[i][:] = tri_1[(i + 1) % 3] - tri_1[i] - - min_sol_0 = numpy.empty(3, dtype=float) - min_sol_1 = numpy.empty(3, dtype=float) + edges_0 = numpy.empty( ( 3, 3 ), dtype=float ) + edges_1 = numpy.empty( ( 3, 3 ), dtype=float ) + for i in range( 3 ): + edges_0[ i ][ : ] = tri_0[ ( i + 1 ) % 3 ] - tri_0[ i ] + edges_1[ i ][ : ] = tri_1[ ( i + 1 ) % 3 ] - tri_1[ i ] + + min_sol_0 = numpy.empty( 3, dtype=float ) + min_sol_1 = numpy.empty( 3, dtype=float ) are_disjoint: bool = False min_dist = numpy.inf # Looping over all the pair of edges. - for i, j in itertools.product(range(3), repeat=2): + for i, j in itertools.product( range( 3 ), repeat=2 ): # Find the closest points on edges i and j. - sol_0, sol_1 = distance_between_two_segments(tri_0[i], edges_0[i], tri_1[j], edges_1[j]) + sol_0, sol_1 = distance_between_two_segments( tri_0[ i ], edges_0[ i ], tri_1[ j ], edges_1[ j ] ) # Computing the distance between the two solutions. delta_sol = sol_1 - sol_0 - dist: float = numpy.dot(delta_sol, delta_sol) + dist: float = numpy.dot( delta_sol, delta_sol ) # Update minimum if relevant and check if it's the closest pair of points. if dist <= min_dist: - min_sol_0[:] = sol_0 - min_sol_1[:] = sol_1 + min_sol_0[ : ] = sol_0 + min_sol_1[ : ] = sol_1 min_dist = dist # `tri_0[(i + 2) % 3]` is the points opposite to edges_0[i] where the closest point sol_0 lies. # Computing those scalar products and checking the signs somehow let us determine # if the triangles are getting closer to each other when approaching the sol_(0|1) nodes. # If so, we have a minimum. - a: float = numpy.dot(tri_0[(i + 2) % 3] - sol_0, delta_sol) - b: float = numpy.dot(tri_1[(j + 2) % 3] - sol_1, delta_sol) + a: float = numpy.dot( tri_0[ ( i + 2 ) % 3 ] - sol_0, delta_sol ) + b: float = numpy.dot( tri_1[ ( j + 2 ) % 3 ] - sol_1, delta_sol ) if a <= 0 <= b: - return sqrt(dist), sol_0, sol_1 + return sqrt( dist ), sol_0, sol_1 if a < 0: a = 0 @@ -168,12 +170,12 @@ def distance_between_two_triangles(tri_0: numpy.ndarray, are_disjoint = True # No edge pair contained the closest points. # Checking the node/face situation. - distance, sol_0, sol_1, are_disjoint_tmp = __compute_nodes_to_triangle_distance(tri_0, edges_0, tri_1) + distance, sol_0, sol_1, are_disjoint_tmp = __compute_nodes_to_triangle_distance( tri_0, edges_0, tri_1 ) if distance: return distance, sol_0, sol_1 are_disjoint = are_disjoint or are_disjoint_tmp - distance, sol_0, sol_1, are_disjoint_tmp = __compute_nodes_to_triangle_distance(tri_1, edges_1, tri_0) + distance, sol_0, sol_1, are_disjoint_tmp = __compute_nodes_to_triangle_distance( tri_1, edges_1, tri_0 ) if distance: return distance, sol_0, sol_1 are_disjoint = are_disjoint or are_disjoint_tmp @@ -181,6 +183,6 @@ def distance_between_two_triangles(tri_0: numpy.ndarray, # If the triangles do not overlap, let's return the minimum found during the edges loop. # (maybe an edge was parallel to the other face, and we could not decide for a unique closest point). if are_disjoint: - return sqrt(min_dist), min_sol_0, min_sol_1 + return sqrt( min_dist ), min_sol_0, min_sol_1 else: # Surely overlapping or degenerate triangles. - return 0., numpy.zeros(3, dtype=float), numpy.zeros(3, dtype=float) + return 0., numpy.zeros( 3, dtype=float ), numpy.zeros( 3, dtype=float ) diff --git a/geosx_mesh_doctor/checks/vtk_polyhedron.py b/geosx_mesh_doctor/checks/vtk_polyhedron.py index e246a57..0f09310 100644 --- a/geosx_mesh_doctor/checks/vtk_polyhedron.py +++ b/geosx_mesh_doctor/checks/vtk_polyhedron.py @@ -11,27 +11,25 @@ ) from vtkmodules.vtkCommonCore import ( - vtkIdList, -) + vtkIdList, ) import networkx from .vtk_utils import ( - vtk_iter, -) + vtk_iter, ) -@dataclass(frozen=True) +@dataclass( frozen=True ) class Options: dummy: float -@dataclass(frozen=True) +@dataclass( frozen=True ) class Result: dummy: float -def parse_face_stream(ids: vtkIdList) -> Sequence[Sequence[int]]: +def parse_face_stream( ids: vtkIdList ) -> Sequence[ Sequence[ int ] ]: """ Parses the face stream raw information and converts it into a tuple of tuple of integers, each tuple of integer being the nodes of a face. @@ -39,117 +37,118 @@ def parse_face_stream(ids: vtkIdList) -> Sequence[Sequence[int]]: :return: The tuple of tuple of integers. """ result = [] - it = vtk_iter(ids) - num_faces = next(it) + it = vtk_iter( ids ) + num_faces = next( it ) try: while True: - num_nodes = next(it) + num_nodes = next( it ) tmp = [] - for i in range(num_nodes): - tmp.append(next(it)) - result.append(tuple(tmp)) + for i in range( num_nodes ): + tmp.append( next( it ) ) + result.append( tuple( tmp ) ) except StopIteration: pass - assert len(result) == num_faces - assert sum(map(len, result)) + len(result) + 1 == ids.GetNumberOfIds() + assert len( result ) == num_faces + assert sum( map( len, result ) ) + len( result ) + 1 == ids.GetNumberOfIds() - return tuple(result) + return tuple( result ) class FaceStream: """ Helper class to manipulate the vtk face streams. """ - def __init__(self, data: Sequence[Sequence[int]]): + + def __init__( self, data: Sequence[ Sequence[ int ] ] ): # self.__data contains the list of faces nodes, like it appears in vtk face streams. # Except that the additional size information is removed # in favor of the __len__ of the containers. - self.__data: Sequence[Sequence[int]] = data + self.__data: Sequence[ Sequence[ int ] ] = data @staticmethod - def build_from_vtk_id_list(ids: vtkIdList): + def build_from_vtk_id_list( ids: vtkIdList ): """ Builds a FaceStream from the raw vtk face stream. :param ids: The vtk face stream. :return: A new FaceStream instance. """ - return FaceStream(parse_face_stream(ids)) + return FaceStream( parse_face_stream( ids ) ) @property - def face_nodes(self) -> Iterable[Sequence[int]]: + def face_nodes( self ) -> Iterable[ Sequence[ int ] ]: """ Iterate on the nodes of all the faces. :return: An iterator. """ - return iter(self.__data) + return iter( self.__data ) @property - def num_faces(self) -> int: + def num_faces( self ) -> int: """ Number of faces in the face stream :return: An integer """ - return len(self.__data) + return len( self.__data ) @property - def support_point_ids(self) -> Collection[int]: + def support_point_ids( self ) -> Collection[ int ]: """ The list of all (unique) support points of the face stream, in no specific order. :return: The set of all the point ids. """ - tmp: List[int] = [] + tmp: List[ int ] = [] for nodes in self.face_nodes: tmp += nodes - return frozenset(tmp) + return frozenset( tmp ) @property - def num_support_points(self) -> int: + def num_support_points( self ) -> int: """ The number of unique support nodes of the polyhedron. :return: An integer. """ - return len(self.support_point_ids) + return len( self.support_point_ids ) - def __getitem__(self, face_index) -> Sequence[int]: + def __getitem__( self, face_index ) -> Sequence[ int ]: """ The support point ids for the `face_index` face. :param face_index: The face index (within the face stream). :return: A tuple containing all the point ids. """ - return self.__data[face_index] + return self.__data[ face_index ] - def flip_faces(self, face_indices): + def flip_faces( self, face_indices ): """ Returns a new FaceStream instance with the face indices defined in face_indices flipped., :param face_indices: The faces (local) indices to flip. :return: A newly created instance. """ result = [] - for face_index, face_nodes in enumerate(self.__data): - result.append(tuple(reversed(face_nodes)) if face_index in face_indices else face_nodes) - return FaceStream(tuple(result)) + for face_index, face_nodes in enumerate( self.__data ): + result.append( tuple( reversed( face_nodes ) ) if face_index in face_indices else face_nodes ) + return FaceStream( tuple( result ) ) - def dump(self) -> Sequence[int]: + def dump( self ) -> Sequence[ int ]: """ Returns the face stream awaited by vtk, but in a python container. The content can be used, once converted to a vtkIdList, to define another polyhedron in vtk. :return: The face stream in a python container. """ - result = [len(self.__data)] + result = [ len( self.__data ) ] for face_nodes in self.__data: - result.append(len(face_nodes)) + result.append( len( face_nodes ) ) result += face_nodes - return tuple(result) + return tuple( result ) - def __repr__(self): - result = [str(len(self.__data))] + def __repr__( self ): + result = [ str( len( self.__data ) ) ] for face_nodes in self.__data: - result.append(str(len(face_nodes))) - result.append(", ".join(map(str, face_nodes))) - return ",\n".join(result) + result.append( str( len( face_nodes ) ) ) + result.append( ", ".join( map( str, face_nodes ) ) ) + return ",\n".join( result ) -def build_face_to_face_connectivity_through_edges(face_stream: FaceStream, add_compatibility=False) -> networkx.Graph: +def build_face_to_face_connectivity_through_edges( face_stream: FaceStream, add_compatibility=False ) -> networkx.Graph: """ Given a face stream/polyhedron, builds the connections between the faces. Those connections happen when two faces share an edge. @@ -161,52 +160,52 @@ def build_face_to_face_connectivity_through_edges(face_stream: FaceStream, add_c :return: A graph which nodes are actually the faces of the polyhedron. Two nodes of the graph are connected if they share an edge. """ - edges_to_face_indices: Dict[FrozenSet[int], List[int]] = defaultdict(list) - for face_index, face_nodes in enumerate(face_stream.face_nodes): + edges_to_face_indices: Dict[ FrozenSet[ int ], List[ int ] ] = defaultdict( list ) + for face_index, face_nodes in enumerate( face_stream.face_nodes ): # Each edge is defined by two nodes. We do a small trick to loop on consecutive points. - face_indices: Tuple[int, int] - for face_indices in zip(face_nodes, face_nodes[1:] + (face_nodes[0], )): - edges_to_face_indices[frozenset(face_indices)].append(face_index) + face_indices: Tuple[ int, int ] + for face_indices in zip( face_nodes, face_nodes[ 1: ] + ( face_nodes[ 0 ], ) ): + edges_to_face_indices[ frozenset( face_indices ) ].append( face_index ) # We are doing here some small validations w.r.t. the connections of the faces # which may only make sense in the context of numerical simulations. # As such, an error will be thrown in case the polyhedron is not closed. # So there may be a lack of absolute genericity, and the code may evolve if needed. for face_indices in edges_to_face_indices.values(): - assert len(face_indices) == 2 + assert len( face_indices ) == 2 # Computing the graph degree for validation - degrees: Dict[int, int] = defaultdict(int) + degrees: Dict[ int, int ] = defaultdict( int ) for face_indices in edges_to_face_indices.values(): for face_index in face_indices: - degrees[face_index] += 1 + degrees[ face_index ] += 1 for face_index, degree in degrees.items(): - assert len(face_stream[face_index]) == degree + assert len( face_stream[ face_index ] ) == degree # Validation that there is one unique edge connecting two faces. - face_indices_to_edge_index = defaultdict(list) + face_indices_to_edge_index = defaultdict( list ) for edge_index, face_indices in edges_to_face_indices.items(): - face_indices_to_edge_index[frozenset(face_indices)].append(edge_index) + face_indices_to_edge_index[ frozenset( face_indices ) ].append( edge_index ) for edge_indices in face_indices_to_edge_index.values(): - assert len(edge_indices) == 1 + assert len( edge_indices ) == 1 # Connecting the faces. Neighbor faces with consistent normals (i.e. facing both inward or outward) # will be connected. This will allow us to extract connected components with consistent orientations. # Another step will then reconcile all the components such that all the normals of the cell # will consistently point outward. graph = networkx.Graph() - graph.add_nodes_from(range(face_stream.num_faces)) + graph.add_nodes_from( range( face_stream.num_faces ) ) for edge, face_indices in edges_to_face_indices.items(): face_index_0, face_index_1 = face_indices - face_nodes_0 = face_stream[face_index_0] + (face_stream[face_index_0][0],) - face_nodes_1 = face_stream[face_index_1] + (face_stream[face_index_1][0],) + face_nodes_0 = face_stream[ face_index_0 ] + ( face_stream[ face_index_0 ][ 0 ], ) + face_nodes_1 = face_stream[ face_index_1 ] + ( face_stream[ face_index_1 ][ 0 ], ) node_0, node_1 = edge - order_0 = 1 if face_nodes_0[face_nodes_0.index(node_0) + 1] == node_1 else -1 - order_1 = 1 if face_nodes_1[face_nodes_1.index(node_0) + 1] == node_1 else -1 + order_0 = 1 if face_nodes_0[ face_nodes_0.index( node_0 ) + 1 ] == node_1 else -1 + order_1 = 1 if face_nodes_1[ face_nodes_1.index( node_0 ) + 1 ] == node_1 else -1 # Same order of nodes means that the normals of the faces # are _not_ both in the same "direction" (inward or outward). if order_0 * order_1 == 1: if add_compatibility: - graph.add_edge(face_index_0, face_index_1, compatible="-") + graph.add_edge( face_index_0, face_index_1, compatible="-" ) else: if add_compatibility: - graph.add_edge(face_index_0, face_index_1, compatible="+") + graph.add_edge( face_index_0, face_index_1, compatible="+" ) else: - graph.add_edge(face_index_0, face_index_1) + graph.add_edge( face_index_0, face_index_1 ) return graph diff --git a/geosx_mesh_doctor/checks/vtk_utils.py b/geosx_mesh_doctor/checks/vtk_utils.py index 2604609..9beb375 100644 --- a/geosx_mesh_doctor/checks/vtk_utils.py +++ b/geosx_mesh_doctor/checks/vtk_utils.py @@ -9,11 +9,9 @@ ) from vtkmodules.vtkCommonCore import ( - vtkIdList, -) + vtkIdList, ) from vtkmodules.vtkCommonDataModel import ( - vtkUnstructuredGrid, -) + vtkUnstructuredGrid, ) from vtkmodules.vtkIOLegacy import ( vtkUnstructuredGridWriter, vtkUnstructuredGridReader, @@ -24,103 +22,102 @@ ) -@dataclass(frozen=True) +@dataclass( frozen=True ) class VtkOutput: output: str is_data_mode_binary: bool -def to_vtk_id_list(data) -> vtkIdList: +def to_vtk_id_list( data ) -> vtkIdList: result = vtkIdList() - result.Allocate(len(data)) + result.Allocate( len( data ) ) for d in data: - result.InsertNextId(d) + result.InsertNextId( d ) return result -def vtk_iter(l) -> Iterator[Any]: +def vtk_iter( l ) -> Iterator[ Any ]: """ Utility function transforming a vtk "container" (e.g. vtkIdList) into an iterable to be used for building built-ins python containers. :param l: A vtk container. :return: The iterator. """ - if hasattr(l, "GetNumberOfIds"): - for i in range(l.GetNumberOfIds()): - yield l.GetId(i) - elif hasattr(l, "GetNumberOfTypes"): - for i in range(l.GetNumberOfTypes()): - yield l.GetCellType(i) + if hasattr( l, "GetNumberOfIds" ): + for i in range( l.GetNumberOfIds() ): + yield l.GetId( i ) + elif hasattr( l, "GetNumberOfTypes" ): + for i in range( l.GetNumberOfTypes() ): + yield l.GetCellType( i ) -def __read_vtk(vtk_input_file: str) -> Optional[vtkUnstructuredGrid]: +def __read_vtk( vtk_input_file: str ) -> Optional[ vtkUnstructuredGrid ]: reader = vtkUnstructuredGridReader() - logging.info(f"Testing file format \"{vtk_input_file}\" using legacy format reader...") - reader.SetFileName(vtk_input_file) + logging.info( f"Testing file format \"{vtk_input_file}\" using legacy format reader..." ) + reader.SetFileName( vtk_input_file ) if reader.IsFileUnstructuredGrid(): - logging.info(f"Reader matches. Reading file \"{vtk_input_file}\" using legacy format reader.") + logging.info( f"Reader matches. Reading file \"{vtk_input_file}\" using legacy format reader." ) reader.Update() return reader.GetOutput() else: - logging.info("Reader did not match the input file format.") + logging.info( "Reader did not match the input file format." ) return None -def __read_vtu(vtk_input_file: str) -> Optional[vtkUnstructuredGrid]: +def __read_vtu( vtk_input_file: str ) -> Optional[ vtkUnstructuredGrid ]: reader = vtkXMLUnstructuredGridReader() - logging.info(f"Testing file format \"{vtk_input_file}\" using XML format reader...") - if reader.CanReadFile(vtk_input_file): - reader.SetFileName(vtk_input_file) - logging.info(f"Reader matches. Reading file \"{vtk_input_file}\" using XML format reader.") + logging.info( f"Testing file format \"{vtk_input_file}\" using XML format reader..." ) + if reader.CanReadFile( vtk_input_file ): + reader.SetFileName( vtk_input_file ) + logging.info( f"Reader matches. Reading file \"{vtk_input_file}\" using XML format reader." ) reader.Update() return reader.GetOutput() else: - logging.info("Reader did not match the input file format.") + logging.info( "Reader did not match the input file format." ) return None -def read_mesh(vtk_input_file: str) -> vtkUnstructuredGrid: +def read_mesh( vtk_input_file: str ) -> vtkUnstructuredGrid: """ Read the vtk file and builds an unstructured grid from it. :param vtk_input_file: The file name. The extension will be used to guess the file format. If first guess does not work, eventually all the others reader available will be tested. :return: A unstructured grid. """ - file_extension = os.path.splitext(vtk_input_file)[-1] - extension_to_reader = {".vtk": __read_vtk, - ".vtu": __read_vtu} + file_extension = os.path.splitext( vtk_input_file )[ -1 ] + extension_to_reader = { ".vtk": __read_vtk, ".vtu": __read_vtu } # Testing first the reader that should match if file_extension in extension_to_reader: - output_mesh = extension_to_reader.pop(file_extension)(vtk_input_file) + output_mesh = extension_to_reader.pop( file_extension )( vtk_input_file ) if output_mesh: return output_mesh # If it does not match, then test all the others. for reader in extension_to_reader.values(): - output_mesh = reader(vtk_input_file) + output_mesh = reader( vtk_input_file ) if output_mesh: return output_mesh # No reader did work. Dying. - logging.critical(f"Could not find the appropriate VTK reader for file \"{vtk_input_file}\". Dying...") - sys.exit(1) + logging.critical( f"Could not find the appropriate VTK reader for file \"{vtk_input_file}\". Dying..." ) + sys.exit( 1 ) -def __write_vtk(mesh: vtkUnstructuredGrid, output: str) -> int: - logging.info(f"Writing mesh into file \"{output}\" using legacy format.") +def __write_vtk( mesh: vtkUnstructuredGrid, output: str ) -> int: + logging.info( f"Writing mesh into file \"{output}\" using legacy format." ) writer = vtkUnstructuredGridWriter() - writer.SetFileName(output) - writer.SetInputData(mesh) + writer.SetFileName( output ) + writer.SetInputData( mesh ) return writer.Write() -def __write_vtu(mesh: vtkUnstructuredGrid, output: str, is_data_mode_binary: bool) -> int: - logging.info(f"Writing mesh into file \"{output}\" using XML format.") +def __write_vtu( mesh: vtkUnstructuredGrid, output: str, is_data_mode_binary: bool ) -> int: + logging.info( f"Writing mesh into file \"{output}\" using XML format." ) writer = vtkXMLUnstructuredGridWriter() - writer.SetFileName(output) - writer.SetInputData(mesh) + writer.SetFileName( output ) + writer.SetInputData( mesh ) writer.SetDataModeToBinary() if is_data_mode_binary else writer.SetDataModeToAscii() return writer.Write() -def write_mesh(mesh: vtkUnstructuredGrid, vtk_output: VtkOutput) -> int: +def write_mesh( mesh: vtkUnstructuredGrid, vtk_output: VtkOutput ) -> int: """ Writes the mesh to disk. Nothing will be done if the file already exists. @@ -128,16 +125,16 @@ def write_mesh(mesh: vtkUnstructuredGrid, vtk_output: VtkOutput) -> int: :param vtk_output: Where to write. The file extension will be used to select the VTK file format. :return: 0 in case of success. """ - if os.path.exists(vtk_output.output): - logging.error(f"File \"{vtk_output.output}\" already exists, nothing done.") + if os.path.exists( vtk_output.output ): + logging.error( f"File \"{vtk_output.output}\" already exists, nothing done." ) return 1 - file_extension = os.path.splitext(vtk_output.output)[-1] + file_extension = os.path.splitext( vtk_output.output )[ -1 ] if file_extension == ".vtk": - success_code = __write_vtk(mesh, vtk_output.output) + success_code = __write_vtk( mesh, vtk_output.output ) elif file_extension == ".vtu": - success_code = __write_vtu(mesh, vtk_output.output, vtk_output.is_data_mode_binary) + success_code = __write_vtu( mesh, vtk_output.output, vtk_output.is_data_mode_binary ) else: # No writer found did work. Dying. - logging.critical(f"Could not find the appropriate VTK writer for extension \"{file_extension}\". Dying...") - sys.exit(1) + logging.critical( f"Could not find the appropriate VTK writer for extension \"{file_extension}\". Dying..." ) + sys.exit( 1 ) return 0 if success_code else 2 # the Write member function return 1 in case of success, 0 otherwise. diff --git a/geosx_mesh_doctor/mesh_doctor.py b/geosx_mesh_doctor/mesh_doctor.py index f28cc7e..0e652ac 100644 --- a/geosx_mesh_doctor/mesh_doctor.py +++ b/geosx_mesh_doctor/mesh_doctor.py @@ -1,11 +1,11 @@ import sys try: - min_python_version = (3, 7) + min_python_version = ( 3, 7 ) assert sys.version_info >= min_python_version except AssertionError as e: - print(f"Please update python to at least version {'.'.join(map(str, min_python_version))}.") - sys.exit(1) + print( f"Please update python to at least version {'.'.join(map(str, min_python_version))}." ) + sys.exit( 1 ) import logging @@ -15,20 +15,20 @@ def main(): - logging.basicConfig(format='[%(asctime)s][%(levelname)s] %(message)s') - parse_and_set_verbosity(sys.argv) + logging.basicConfig( format='[%(asctime)s][%(levelname)s] %(message)s' ) + parse_and_set_verbosity( sys.argv ) main_parser, all_checks, all_checks_helpers = register.register() - args = main_parser.parse_args(sys.argv[1:]) - logging.info(f"Checking mesh \"{args.vtk_input_file}\".") - check_options = all_checks_helpers[args.subparsers].convert(vars(args)) + args = main_parser.parse_args( sys.argv[ 1: ] ) + logging.info( f"Checking mesh \"{args.vtk_input_file}\"." ) + check_options = all_checks_helpers[ args.subparsers ].convert( vars( args ) ) try: - check = all_checks[args.subparsers] + check = all_checks[ args.subparsers ] except KeyError as e: - logging.critical(f"Check {args.subparsers} is not a valid check.") - sys.exit(1) - helper: CheckHelper = all_checks_helpers[args.subparsers] - result = check(args.vtk_input_file, check_options) - helper.display_results(check_options, result) + logging.critical( f"Check {args.subparsers} is not a valid check." ) + sys.exit( 1 ) + helper: CheckHelper = all_checks_helpers[ args.subparsers ] + result = check( args.vtk_input_file, check_options ) + helper.display_results( check_options, result ) if __name__ == '__main__': diff --git a/geosx_mesh_doctor/parsing/__init__.py b/geosx_mesh_doctor/parsing/__init__.py index 0d06f73..679f880 100644 --- a/geosx_mesh_doctor/parsing/__init__.py +++ b/geosx_mesh_doctor/parsing/__init__.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from typing import Callable, Any - COLLOCATES_NODES = "collocated_nodes" ELEMENT_VOLUMES = "element_volumes" FIX_ELEMENTS_ORDERINGS = "fix_elements_orderings" @@ -14,8 +13,8 @@ SUPPORTED_ELEMENTS = "supported_elements" -@dataclass(frozen=True) +@dataclass( frozen=True ) class CheckHelper: - fill_subparser: Callable[[Any], argparse.ArgumentParser] - convert: Callable[[Any], Any] - display_results: Callable[[Any, Any], None] + fill_subparser: Callable[ [ Any ], argparse.ArgumentParser ] + convert: Callable[ [ Any ], Any ] + display_results: Callable[ [ Any, Any ], None ] diff --git a/geosx_mesh_doctor/parsing/cli_parsing.py b/geosx_mesh_doctor/parsing/cli_parsing.py index a2eb20e..a34010b 100644 --- a/geosx_mesh_doctor/parsing/cli_parsing.py +++ b/geosx_mesh_doctor/parsing/cli_parsing.py @@ -10,33 +10,33 @@ __QUIET_FLAG = "q" -def parse_and_set_verbosity(cli_args: List[str]) -> None: +def parse_and_set_verbosity( cli_args: List[ str ] ) -> None: """ Parse the verbosity flag only. And sets the logger's level accordingly. :param cli_args: The list of arguments (as strings) :return: None """ - dummy_verbosity_parser = argparse.ArgumentParser(add_help=False) - dummy_verbosity_parser.add_argument('-' + __VERBOSITY_FLAG, - '--' + __VERBOSE_KEY, - action='count', - default=2, - dest=__VERBOSE_KEY) - dummy_verbosity_parser.add_argument('-' + __QUIET_FLAG, - '--' + __QUIET_KEY, - action='count', - default=0, - dest=__QUIET_KEY) - args = dummy_verbosity_parser.parse_known_args(cli_args[1:])[0] - d = vars(args) - v = d[__VERBOSE_KEY] - d[__QUIET_KEY] - verbosity = logging.CRITICAL - (10 * v) + dummy_verbosity_parser = argparse.ArgumentParser( add_help=False ) + dummy_verbosity_parser.add_argument( '-' + __VERBOSITY_FLAG, + '--' + __VERBOSE_KEY, + action='count', + default=2, + dest=__VERBOSE_KEY ) + dummy_verbosity_parser.add_argument( '-' + __QUIET_FLAG, + '--' + __QUIET_KEY, + action='count', + default=0, + dest=__QUIET_KEY ) + args = dummy_verbosity_parser.parse_known_args( cli_args[ 1: ] )[ 0 ] + d = vars( args ) + v = d[ __VERBOSE_KEY ] - d[ __QUIET_KEY ] + verbosity = logging.CRITICAL - ( 10 * v ) if verbosity < logging.DEBUG: verbosity = logging.DEBUG elif verbosity > logging.CRITICAL: verbosity = logging.CRITICAL - logging.getLogger().setLevel(verbosity) - logging.info(f"Logger level set to \"{logging.getLevelName(verbosity)}\"") + logging.getLogger().setLevel( verbosity ) + logging.info( f"Logger level set to \"{logging.getLevelName(verbosity)}\"" ) def init_parser() -> argparse.ArgumentParser: @@ -47,27 +47,28 @@ def init_parser() -> argparse.ArgumentParser: An option may be missing because of an unloaded module. Increase verbosity (-{__VERBOSITY_FLAG}, -{__VERBOSITY_FLAG * 2}) to get full information. """ - formatter = lambda prog: argparse.RawTextHelpFormatter(prog, max_help_position=8) - parser = argparse.ArgumentParser(description='Inspects meshes for GEOSX.', - epilog=textwrap.dedent(epilog_msg), - formatter_class=formatter) + formatter = lambda prog: argparse.RawTextHelpFormatter( prog, max_help_position=8 ) + parser = argparse.ArgumentParser( description='Inspects meshes for GEOSX.', + epilog=textwrap.dedent( epilog_msg ), + formatter_class=formatter ) # Nothing will be done with this verbosity/quiet input. # It's only here for the `--help` message. # `parse_verbosity` does the real parsing instead. - parser.add_argument('-' + __VERBOSITY_FLAG, - action='count', - default=2, - dest=__VERBOSE_KEY, - help=f"Use -{__VERBOSITY_FLAG} 'INFO', -{__VERBOSITY_FLAG * 2} for 'DEBUG'. Defaults to 'WARNING'.") - parser.add_argument('-' + __QUIET_FLAG, - action='count', - default=0, - dest=__QUIET_KEY, - help=f"Use -{__QUIET_FLAG} to reduce the verbosity of the output.") - parser.add_argument('-i', - '--vtk-input-file', - metavar='VTK_MESH_FILE', - type=str, - required=True, - dest=vtk_input_file_key) + parser.add_argument( + '-' + __VERBOSITY_FLAG, + action='count', + default=2, + dest=__VERBOSE_KEY, + help=f"Use -{__VERBOSITY_FLAG} 'INFO', -{__VERBOSITY_FLAG * 2} for 'DEBUG'. Defaults to 'WARNING'." ) + parser.add_argument( '-' + __QUIET_FLAG, + action='count', + default=0, + dest=__QUIET_KEY, + help=f"Use -{__QUIET_FLAG} to reduce the verbosity of the output." ) + parser.add_argument( '-i', + '--vtk-input-file', + metavar='VTK_MESH_FILE', + type=str, + required=True, + dest=vtk_input_file_key ) return parser diff --git a/geosx_mesh_doctor/parsing/collocated_nodes_parsing.py b/geosx_mesh_doctor/parsing/collocated_nodes_parsing.py index 421ae95..fd737a8 100644 --- a/geosx_mesh_doctor/parsing/collocated_nodes_parsing.py +++ b/geosx_mesh_doctor/parsing/collocated_nodes_parsing.py @@ -12,38 +12,38 @@ __TOLERANCE = "tolerance" -def convert(parsed_options) -> Options: - return Options(parsed_options[__TOLERANCE]) +def convert( parsed_options ) -> Options: + return Options( parsed_options[ __TOLERANCE ] ) -def fill_subparser(subparsers) -> None: - p = subparsers.add_parser(COLLOCATES_NODES, - help="Checks if nodes are collocated.") - p.add_argument('--' + __TOLERANCE, - type=float, - required=True, - help="[float]: The absolute distance between two nodes for them to be considered collocated.") +def fill_subparser( subparsers ) -> None: + p = subparsers.add_parser( COLLOCATES_NODES, help="Checks if nodes are collocated." ) + p.add_argument( '--' + __TOLERANCE, + type=float, + required=True, + help="[float]: The absolute distance between two nodes for them to be considered collocated." ) -def display_results(options: Options, result: Result): - all_collocated_nodes: List[int] = [] +def display_results( options: Options, result: Result ): + all_collocated_nodes: List[ int ] = [] for bucket in result.nodes_buckets: for node in bucket: - all_collocated_nodes.append(node) - all_collocated_nodes: FrozenSet[int] = frozenset(all_collocated_nodes) # Surely useless + all_collocated_nodes.append( node ) + all_collocated_nodes: FrozenSet[ int ] = frozenset( all_collocated_nodes ) # Surely useless if all_collocated_nodes: - logging.error(f"You have {len(all_collocated_nodes)} collocated nodes (tolerance = {options.tolerance}).") + logging.error( f"You have {len(all_collocated_nodes)} collocated nodes (tolerance = {options.tolerance})." ) - logging.info("Here are all the buckets of collocated nodes.") - tmp: List[str] = [] + logging.info( "Here are all the buckets of collocated nodes." ) + tmp: List[ str ] = [] for bucket in result.nodes_buckets: - tmp.append(f"({', '.join(map(str, bucket))})") - logging.info(f"({', '.join(tmp)})") + tmp.append( f"({', '.join(map(str, bucket))})" ) + logging.info( f"({', '.join(tmp)})" ) else: - logging.error(f"You have no collocated node (tolerance = {options.tolerance}).") + logging.error( f"You have no collocated node (tolerance = {options.tolerance})." ) if result.wrong_support_elements: - tmp: str = ", ".join(map(str, result.wrong_support_elements)) - logging.error(f"You have {len(result.wrong_support_elements)} elements with duplicated support nodes.\n" + tmp) + tmp: str = ", ".join( map( str, result.wrong_support_elements ) ) + logging.error( f"You have {len(result.wrong_support_elements)} elements with duplicated support nodes.\n" + + tmp ) else: - logging.error("You have no element with duplicated support nodes.") + logging.error( "You have no element with duplicated support nodes." ) diff --git a/geosx_mesh_doctor/parsing/element_volumes_parsing.py b/geosx_mesh_doctor/parsing/element_volumes_parsing.py index 3b19682..3c126cd 100644 --- a/geosx_mesh_doctor/parsing/element_volumes_parsing.py +++ b/geosx_mesh_doctor/parsing/element_volumes_parsing.py @@ -8,27 +8,28 @@ __MIN_VOLUME_DEFAULT = 0. -def fill_subparser(subparsers) -> None: - p = subparsers.add_parser(ELEMENT_VOLUMES, - help=f"Checks if the volumes of the elements are greater than \"{__MIN_VOLUME}\".") - p.add_argument('--' + __MIN_VOLUME, - type=float, - metavar=__MIN_VOLUME_DEFAULT, - default=__MIN_VOLUME_DEFAULT, - required=True, - help=f"[float]: The minimum acceptable volume. Defaults to {__MIN_VOLUME_DEFAULT}.") - - -def convert(parsed_options) -> Options: +def fill_subparser( subparsers ) -> None: + p = subparsers.add_parser( ELEMENT_VOLUMES, + help=f"Checks if the volumes of the elements are greater than \"{__MIN_VOLUME}\"." ) + p.add_argument( '--' + __MIN_VOLUME, + type=float, + metavar=__MIN_VOLUME_DEFAULT, + default=__MIN_VOLUME_DEFAULT, + required=True, + help=f"[float]: The minimum acceptable volume. Defaults to {__MIN_VOLUME_DEFAULT}." ) + + +def convert( parsed_options ) -> Options: """ From the parsed cli options, return the converted options for elements volumes check. :param options_str: Parsed cli options. :return: Options instance. """ - return Options(min_volume=parsed_options[__MIN_VOLUME]) + return Options( min_volume=parsed_options[ __MIN_VOLUME ] ) -def display_results(options: Options, result: Result): - logging.error(f"You have {len(result.element_volumes)} elements with volumes smaller than {options.min_volume}.") +def display_results( options: Options, result: Result ): + logging.error( f"You have {len(result.element_volumes)} elements with volumes smaller than {options.min_volume}." ) if result.element_volumes: - logging.error("The elements indices and their volumes are:\n" + "\n".join(map(str, result.element_volumes))) + logging.error( "The elements indices and their volumes are:\n" + + "\n".join( map( str, result.element_volumes ) ) ) diff --git a/geosx_mesh_doctor/parsing/fix_elements_orderings_parsing.py b/geosx_mesh_doctor/parsing/fix_elements_orderings_parsing.py index c105792..7a62006 100644 --- a/geosx_mesh_doctor/parsing/fix_elements_orderings_parsing.py +++ b/geosx_mesh_doctor/parsing/fix_elements_orderings_parsing.py @@ -15,7 +15,6 @@ from . import vtk_output_parsing, FIX_ELEMENTS_ORDERINGS - __CELL_TYPE_MAPPING = { "Hexahedron": VTK_HEXAHEDRON, "Prism5": VTK_PENTAGONAL_PRISM, @@ -37,22 +36,21 @@ } -def fill_subparser(subparsers) -> None: - p = subparsers.add_parser(FIX_ELEMENTS_ORDERINGS, - help="Reorders the support nodes for the given cell types.") +def fill_subparser( subparsers ) -> None: + p = subparsers.add_parser( FIX_ELEMENTS_ORDERINGS, help="Reorders the support nodes for the given cell types." ) for key, vtk_key in __CELL_TYPE_MAPPING.items(): - tmp = list(range(__CELL_TYPE_SUPPORT_SIZE[vtk_key])) - random.Random(4).shuffle(tmp) - p.add_argument('--' + key, - type=str, - metavar=",".join(map(str, tmp)), - default=None, - required=False, - help=f"[list of integers]: node permutation for \"{key}\".") - vtk_output_parsing.fill_vtk_output_subparser(p) + tmp = list( range( __CELL_TYPE_SUPPORT_SIZE[ vtk_key ] ) ) + random.Random( 4 ).shuffle( tmp ) + p.add_argument( '--' + key, + type=str, + metavar=",".join( map( str, tmp ) ), + default=None, + required=False, + help=f"[list of integers]: node permutation for \"{key}\"." ) + vtk_output_parsing.fill_vtk_output_subparser( p ) -def convert(parsed_options) -> Options: +def convert( parsed_options ) -> Options: """ From the parsed cli options, return the converted options for self intersecting elements check. :param options_str: Parsed cli options. @@ -60,25 +58,24 @@ def convert(parsed_options) -> Options: """ cell_type_to_ordering = {} for key, vtk_key in __CELL_TYPE_MAPPING.items(): - raw_mapping = parsed_options[key] + raw_mapping = parsed_options[ key ] if raw_mapping: - tmp = tuple(map(int, raw_mapping.split(","))) - if not set(tmp) == set(range(__CELL_TYPE_SUPPORT_SIZE[vtk_key])): + tmp = tuple( map( int, raw_mapping.split( "," ) ) ) + if not set( tmp ) == set( range( __CELL_TYPE_SUPPORT_SIZE[ vtk_key ] ) ): err_msg = f"Permutation {raw_mapping} for type {key} is not valid." - logging.error(err_msg) - raise ValueError(err_msg) - cell_type_to_ordering[vtk_key] = tmp - vtk_output = vtk_output_parsing.convert(parsed_options) - return Options(vtk_output=vtk_output, - cell_type_to_ordering=cell_type_to_ordering) + logging.error( err_msg ) + raise ValueError( err_msg ) + cell_type_to_ordering[ vtk_key ] = tmp + vtk_output = vtk_output_parsing.convert( parsed_options ) + return Options( vtk_output=vtk_output, cell_type_to_ordering=cell_type_to_ordering ) -def display_results(options: Options, result: Result): +def display_results( options: Options, result: Result ): if result.output: - logging.info(f"New mesh was written to file '{result.output}'") + logging.info( f"New mesh was written to file '{result.output}'" ) if result.unchanged_cell_types: - logging.info(f"Those vtk types were not reordered: [{', '.join(map(str, result.unchanged_cell_types))}].") + logging.info( f"Those vtk types were not reordered: [{', '.join(map(str, result.unchanged_cell_types))}]." ) else: - logging.info("All the cells of the mesh were reordered.") + logging.info( "All the cells of the mesh were reordered." ) else: - logging.info("No output file was written.") + logging.info( "No output file was written." ) diff --git a/geosx_mesh_doctor/parsing/generate_cube_parsing.py b/geosx_mesh_doctor/parsing/generate_cube_parsing.py index 41c0e04..3c8a17d 100644 --- a/geosx_mesh_doctor/parsing/generate_cube_parsing.py +++ b/geosx_mesh_doctor/parsing/generate_cube_parsing.py @@ -5,83 +5,83 @@ from . import vtk_output_parsing, generate_global_ids_parsing, GENERATE_CUBE from .generate_global_ids_parsing import GlobalIdsInfo - __X, __Y, __Z, __NX, __NY, __NZ = "x", "y", "z", "nx", "ny", "nz" __FIELDS = "fields" -def convert(parsed_options) -> Options: - def check_discretizations(x, nx, title): - if len(x) != len(nx) + 1: - raise ValueError(f"{title} information (\"{x}\" and \"{nx}\") does not have consistent size.") - check_discretizations(parsed_options[__X], parsed_options[__NX], __X) - check_discretizations(parsed_options[__Y], parsed_options[__NY], __Y) - check_discretizations(parsed_options[__Z], parsed_options[__NZ], __Z) +def convert( parsed_options ) -> Options: + + def check_discretizations( x, nx, title ): + if len( x ) != len( nx ) + 1: + raise ValueError( f"{title} information (\"{x}\" and \"{nx}\") does not have consistent size." ) + + check_discretizations( parsed_options[ __X ], parsed_options[ __NX ], __X ) + check_discretizations( parsed_options[ __Y ], parsed_options[ __NY ], __Y ) + check_discretizations( parsed_options[ __Z ], parsed_options[ __NZ ], __Z ) - def parse_fields(s): - name, support, dim = s.split(":") - if support not in ("CELLS", "POINTS"): - raise ValueError(f"Support {support} for field \"{name}\" must be one of \"CELLS\" or \"POINTS\".") + def parse_fields( s ): + name, support, dim = s.split( ":" ) + if support not in ( "CELLS", "POINTS" ): + raise ValueError( f"Support {support} for field \"{name}\" must be one of \"CELLS\" or \"POINTS\"." ) try: - dim = int(dim) + dim = int( dim ) assert dim > 0 except ValueError: - raise ValueError(f"Dimension {dim} cannot be converted to an integer.") + raise ValueError( f"Dimension {dim} cannot be converted to an integer." ) except AssertionError: - raise ValueError(f"Dimension {dim} must be a positive integer") - return FieldInfo(name=name, support=support, dimension=dim) + raise ValueError( f"Dimension {dim} must be a positive integer" ) + return FieldInfo( name=name, support=support, dimension=dim ) - gids: GlobalIdsInfo = generate_global_ids_parsing.convert_global_ids(parsed_options) + gids: GlobalIdsInfo = generate_global_ids_parsing.convert_global_ids( parsed_options ) - return Options(vtk_output=vtk_output_parsing.convert(parsed_options), - generate_cells_global_ids=gids.cells, - generate_points_global_ids=gids.points, - xs=parsed_options[__X], - ys=parsed_options[__Y], - zs=parsed_options[__Z], - nxs=parsed_options[__NX], - nys=parsed_options[__NY], - nzs=parsed_options[__NZ], - fields=tuple(map(parse_fields, parsed_options[__FIELDS]))) + return Options( vtk_output=vtk_output_parsing.convert( parsed_options ), + generate_cells_global_ids=gids.cells, + generate_points_global_ids=gids.points, + xs=parsed_options[ __X ], + ys=parsed_options[ __Y ], + zs=parsed_options[ __Z ], + nxs=parsed_options[ __NX ], + nys=parsed_options[ __NY ], + nzs=parsed_options[ __NZ ], + fields=tuple( map( parse_fields, parsed_options[ __FIELDS ] ) ) ) -def fill_subparser(subparsers) -> None: - p = subparsers.add_parser(GENERATE_CUBE, - help="Generate a cube and its fields.") - p.add_argument('--' + __X, - type=lambda s: tuple(map(float, s.split(":"))), - metavar="0:1.5:3", - help="[list of floats]: X coordinates of the points.") - p.add_argument('--' + __Y, - type=lambda s: tuple(map(float, s.split(":"))), - metavar="0:5:10", - help="[list of floats]: Y coordinates of the points.") - p.add_argument('--' + __Z, - type=lambda s: tuple(map(float, s.split(":"))), - metavar="0:1", - help="[list of floats]: Z coordinates of the points.") - p.add_argument('--' + __NX, - type=lambda s: tuple(map(int, s.split(":"))), - metavar="2:2", - help="[list of integers]: Number of elements in the X direction.") - p.add_argument('--' + __NY, - type=lambda s: tuple(map(int, s.split(":"))), - metavar="1:1", - help="[list of integers]: Number of elements in the Y direction.") - p.add_argument('--' + __NZ, - type=lambda s: tuple(map(int, s.split(":"))), - metavar="4", - help="[list of integers]: Number of elements in the Z direction.") - p.add_argument('--' + __FIELDS, - type=str, - metavar="name:support:dim", - nargs="+", - required=False, - default=(), - help="Create fields on CELLS or POINTS, with given dimension (typically 1 or 3).") - generate_global_ids_parsing.fill_generate_global_ids_subparser(p) - vtk_output_parsing.fill_vtk_output_subparser(p) +def fill_subparser( subparsers ) -> None: + p = subparsers.add_parser( GENERATE_CUBE, help="Generate a cube and its fields." ) + p.add_argument( '--' + __X, + type=lambda s: tuple( map( float, s.split( ":" ) ) ), + metavar="0:1.5:3", + help="[list of floats]: X coordinates of the points." ) + p.add_argument( '--' + __Y, + type=lambda s: tuple( map( float, s.split( ":" ) ) ), + metavar="0:5:10", + help="[list of floats]: Y coordinates of the points." ) + p.add_argument( '--' + __Z, + type=lambda s: tuple( map( float, s.split( ":" ) ) ), + metavar="0:1", + help="[list of floats]: Z coordinates of the points." ) + p.add_argument( '--' + __NX, + type=lambda s: tuple( map( int, s.split( ":" ) ) ), + metavar="2:2", + help="[list of integers]: Number of elements in the X direction." ) + p.add_argument( '--' + __NY, + type=lambda s: tuple( map( int, s.split( ":" ) ) ), + metavar="1:1", + help="[list of integers]: Number of elements in the Y direction." ) + p.add_argument( '--' + __NZ, + type=lambda s: tuple( map( int, s.split( ":" ) ) ), + metavar="4", + help="[list of integers]: Number of elements in the Z direction." ) + p.add_argument( '--' + __FIELDS, + type=str, + metavar="name:support:dim", + nargs="+", + required=False, + default=(), + help="Create fields on CELLS or POINTS, with given dimension (typically 1 or 3)." ) + generate_global_ids_parsing.fill_generate_global_ids_subparser( p ) + vtk_output_parsing.fill_vtk_output_subparser( p ) -def display_results(options: Options, result: Result): - logging.info(result.info) +def display_results( options: Options, result: Result ): + logging.info( result.info ) diff --git a/geosx_mesh_doctor/parsing/generate_fractures_parsing.py b/geosx_mesh_doctor/parsing/generate_fractures_parsing.py index 4789793..f384089 100644 --- a/geosx_mesh_doctor/parsing/generate_fractures_parsing.py +++ b/geosx_mesh_doctor/parsing/generate_fractures_parsing.py @@ -7,7 +7,7 @@ __POLICY = "policy" __FIELD_POLICY = "field" __INTERNAL_SURFACES_POLICY = "internal_surfaces" -__POLICIES = (__FIELD_POLICY, __INTERNAL_SURFACES_POLICY ) +__POLICIES = ( __FIELD_POLICY, __INTERNAL_SURFACES_POLICY ) __FIELD_NAME = "name" __FIELD_VALUES = "values" @@ -15,7 +15,7 @@ __FRACTURE_PREFIX = "fracture" -def convert_to_fracture_policy(s: str) -> FracturePolicy: +def convert_to_fracture_policy( s: str ) -> FracturePolicy: """ Converts the user input to the proper enum chosen. I do not want to use the auto conversion already available to force explicit conversion. @@ -26,41 +26,47 @@ def convert_to_fracture_policy(s: str) -> FracturePolicy: return FracturePolicy.FIELD elif s == __INTERNAL_SURFACES_POLICY: return FracturePolicy.INTERNAL_SURFACES - raise ValueError(f"Policy {s} is not valid. Please use one of \"{', '.join(map(str, __POLICIES))}\".") + raise ValueError( f"Policy {s} is not valid. Please use one of \"{', '.join(map(str, __POLICIES))}\"." ) -def fill_subparser(subparsers) -> None: - p = subparsers.add_parser(GENERATE_FRACTURES, - help="Splits the mesh to generate the faults and fractures. [EXPERIMENTAL]") - p.add_argument('--' + __POLICY, - type=convert_to_fracture_policy, - metavar=", ".join(__POLICIES), - required=True, - help=f"[string]: The criterion to define the surfaces that will be changed into fracture zones. " - f"Possible values are \"{', '.join(__POLICIES)}\"") - p.add_argument('--' + __FIELD_NAME, - type=str, - help=f"[string]: If the \"{__FIELD_POLICY}\" {__POLICY} is selected, defines which field will be considered to define the fractures. " - f"If the \"{__INTERNAL_SURFACES_POLICY}\" {__POLICY} is selected, defines the name of the attribute will be considered to identify the fractures. ") - p.add_argument('--' + __FIELD_VALUES, - type=str, - help=f"[list of comma separated integers]: If the \"{__FIELD_POLICY}\" {__POLICY} is selected, which changes of the field will be considered as a fracture. If the \"{__INTERNAL_SURFACES_POLICY}\" {__POLICY} is selected, list of the fracture attributes.") - vtk_output_parsing.fill_vtk_output_subparser(p) - vtk_output_parsing.fill_vtk_output_subparser(p, prefix=__FRACTURE_PREFIX) +def fill_subparser( subparsers ) -> None: + p = subparsers.add_parser( GENERATE_FRACTURES, + help="Splits the mesh to generate the faults and fractures. [EXPERIMENTAL]" ) + p.add_argument( '--' + __POLICY, + type=convert_to_fracture_policy, + metavar=", ".join( __POLICIES ), + required=True, + help=f"[string]: The criterion to define the surfaces that will be changed into fracture zones. " + f"Possible values are \"{', '.join(__POLICIES)}\"" ) + p.add_argument( + '--' + __FIELD_NAME, + type=str, + help= + f"[string]: If the \"{__FIELD_POLICY}\" {__POLICY} is selected, defines which field will be considered to define the fractures. " + f"If the \"{__INTERNAL_SURFACES_POLICY}\" {__POLICY} is selected, defines the name of the attribute will be considered to identify the fractures. " + ) + p.add_argument( + '--' + __FIELD_VALUES, + type=str, + help= + f"[list of comma separated integers]: If the \"{__FIELD_POLICY}\" {__POLICY} is selected, which changes of the field will be considered as a fracture. If the \"{__INTERNAL_SURFACES_POLICY}\" {__POLICY} is selected, list of the fracture attributes." + ) + vtk_output_parsing.fill_vtk_output_subparser( p ) + vtk_output_parsing.fill_vtk_output_subparser( p, prefix=__FRACTURE_PREFIX ) -def convert(parsed_options) -> Options: - policy = parsed_options[__POLICY] - field = parsed_options[__FIELD_NAME] - field_values = frozenset(map(int, parsed_options[__FIELD_VALUES].split(","))) - vtk_output = vtk_output_parsing.convert(parsed_options) - vtk_fracture_output = vtk_output_parsing.convert(parsed_options, prefix=__FRACTURE_PREFIX) - return Options(policy=policy, - field=field, - field_values=field_values, - vtk_output=vtk_output, - vtk_fracture_output=vtk_fracture_output) +def convert( parsed_options ) -> Options: + policy = parsed_options[ __POLICY ] + field = parsed_options[ __FIELD_NAME ] + field_values = frozenset( map( int, parsed_options[ __FIELD_VALUES ].split( "," ) ) ) + vtk_output = vtk_output_parsing.convert( parsed_options ) + vtk_fracture_output = vtk_output_parsing.convert( parsed_options, prefix=__FRACTURE_PREFIX ) + return Options( policy=policy, + field=field, + field_values=field_values, + vtk_output=vtk_output, + vtk_fracture_output=vtk_fracture_output ) -def display_results(options: Options, result: Result): +def display_results( options: Options, result: Result ): pass diff --git a/geosx_mesh_doctor/parsing/generate_global_ids_parsing.py b/geosx_mesh_doctor/parsing/generate_global_ids_parsing.py index 730599a..06efd43 100644 --- a/geosx_mesh_doctor/parsing/generate_global_ids_parsing.py +++ b/geosx_mesh_doctor/parsing/generate_global_ids_parsing.py @@ -5,53 +5,50 @@ from . import vtk_output_parsing, GENERATE_GLOBAL_IDS - __CELLS, __POINTS = "cells", "points" -@dataclass(frozen=True) +@dataclass( frozen=True ) class GlobalIdsInfo: cells: bool points: bool -def convert_global_ids(parsed_options) -> GlobalIdsInfo: - return GlobalIdsInfo(cells=parsed_options[__CELLS], - points=parsed_options[__POINTS]) +def convert_global_ids( parsed_options ) -> GlobalIdsInfo: + return GlobalIdsInfo( cells=parsed_options[ __CELLS ], points=parsed_options[ __POINTS ] ) -def convert(parsed_options) -> Options: - gids: GlobalIdsInfo = convert_global_ids(parsed_options) - return Options(vtk_output=vtk_output_parsing.convert(parsed_options), - generate_cells_global_ids=gids.cells, - generate_points_global_ids=gids.points) +def convert( parsed_options ) -> Options: + gids: GlobalIdsInfo = convert_global_ids( parsed_options ) + return Options( vtk_output=vtk_output_parsing.convert( parsed_options ), + generate_cells_global_ids=gids.cells, + generate_points_global_ids=gids.points ) -def fill_generate_global_ids_subparser(p): - p.add_argument('--' + __CELLS, - action="store_true", - help=f"[bool]: Generate global ids for cells. Defaults to true.") - p.add_argument('--no-' + __CELLS, - action="store_false", - dest=__CELLS, - help=f"[bool]: Don't generate global ids for cells.") - p.set_defaults(**{__CELLS: True}) - p.add_argument('--' + __POINTS, - action="store_true", - help=f"[bool]: Generate global ids for points. Defaults to true.") - p.add_argument('--no-' + __POINTS, - action="store_false", - dest=__POINTS, - help=f"[bool]: Don't generate global ids for points.") - p.set_defaults(**{__POINTS: True}) +def fill_generate_global_ids_subparser( p ): + p.add_argument( '--' + __CELLS, + action="store_true", + help=f"[bool]: Generate global ids for cells. Defaults to true." ) + p.add_argument( '--no-' + __CELLS, + action="store_false", + dest=__CELLS, + help=f"[bool]: Don't generate global ids for cells." ) + p.set_defaults( **{ __CELLS: True } ) + p.add_argument( '--' + __POINTS, + action="store_true", + help=f"[bool]: Generate global ids for points. Defaults to true." ) + p.add_argument( '--no-' + __POINTS, + action="store_false", + dest=__POINTS, + help=f"[bool]: Don't generate global ids for points." ) + p.set_defaults( **{ __POINTS: True } ) -def fill_subparser(subparsers) -> None: - p = subparsers.add_parser(GENERATE_GLOBAL_IDS, - help="Adds globals ids for points and cells.") - fill_generate_global_ids_subparser(p) - vtk_output_parsing.fill_vtk_output_subparser(p) +def fill_subparser( subparsers ) -> None: + p = subparsers.add_parser( GENERATE_GLOBAL_IDS, help="Adds globals ids for points and cells." ) + fill_generate_global_ids_subparser( p ) + vtk_output_parsing.fill_vtk_output_subparser( p ) -def display_results(options: Options, result: Result): - logging.info(result.info) +def display_results( options: Options, result: Result ): + logging.info( result.info ) diff --git a/geosx_mesh_doctor/parsing/non_conformal_parsing.py b/geosx_mesh_doctor/parsing/non_conformal_parsing.py index 33625f6..046c960 100644 --- a/geosx_mesh_doctor/parsing/non_conformal_parsing.py +++ b/geosx_mesh_doctor/parsing/non_conformal_parsing.py @@ -1,8 +1,8 @@ import logging from typing import ( - FrozenSet, - List, + FrozenSet, + List, ) from checks.non_conformal import Options, Result @@ -15,34 +15,35 @@ __ANGLE_TOLERANCE_DEFAULT = 10. -__ALL_KEYWORDS = {__ANGLE_TOLERANCE, __POINT_TOLERANCE, __FACE_TOLERANCE} +__ALL_KEYWORDS = { __ANGLE_TOLERANCE, __POINT_TOLERANCE, __FACE_TOLERANCE } -def convert(parsed_options) -> Options: - return Options(angle_tolerance=parsed_options[__ANGLE_TOLERANCE], - point_tolerance=parsed_options[__POINT_TOLERANCE], - face_tolerance=parsed_options[__FACE_TOLERANCE]) +def convert( parsed_options ) -> Options: + return Options( angle_tolerance=parsed_options[ __ANGLE_TOLERANCE ], + point_tolerance=parsed_options[ __POINT_TOLERANCE ], + face_tolerance=parsed_options[ __FACE_TOLERANCE ] ) -def fill_subparser(subparsers) -> None: - p = subparsers.add_parser(NON_CONFORMAL, - help="Detects non conformal elements. [EXPERIMENTAL]") - p.add_argument('--' + __ANGLE_TOLERANCE, - type=float, - metavar=__ANGLE_TOLERANCE_DEFAULT, - default=__ANGLE_TOLERANCE_DEFAULT, - help=f"[float]: angle tolerance in degrees. Defaults to {__ANGLE_TOLERANCE_DEFAULT}") - p.add_argument('--' + __POINT_TOLERANCE, - type=float, - help=f"[float]: tolerance for two points to be considered collocated.") - p.add_argument('--' + __FACE_TOLERANCE, - type=float, - help=f"[float]: tolerance for two faces to be considered \"touching\".") +def fill_subparser( subparsers ) -> None: + p = subparsers.add_parser( NON_CONFORMAL, help="Detects non conformal elements. [EXPERIMENTAL]" ) + p.add_argument( '--' + __ANGLE_TOLERANCE, + type=float, + metavar=__ANGLE_TOLERANCE_DEFAULT, + default=__ANGLE_TOLERANCE_DEFAULT, + help=f"[float]: angle tolerance in degrees. Defaults to {__ANGLE_TOLERANCE_DEFAULT}" ) + p.add_argument( '--' + __POINT_TOLERANCE, + type=float, + help=f"[float]: tolerance for two points to be considered collocated." ) + p.add_argument( '--' + __FACE_TOLERANCE, + type=float, + help=f"[float]: tolerance for two faces to be considered \"touching\"." ) -def display_results(options: Options, result: Result): - non_conformal_cells: List[int] = [] +def display_results( options: Options, result: Result ): + non_conformal_cells: List[ int ] = [] for i, j in result.non_conformal_cells: non_conformal_cells += i, j - non_conformal_cells: FrozenSet[int] = frozenset(non_conformal_cells) - logging.error(f"You have {len(non_conformal_cells)} non conformal cells.\n{', '.join(map(str, sorted(non_conformal_cells)))}") + non_conformal_cells: FrozenSet[ int ] = frozenset( non_conformal_cells ) + logging.error( + f"You have {len(non_conformal_cells)} non conformal cells.\n{', '.join(map(str, sorted(non_conformal_cells)))}" + ) diff --git a/geosx_mesh_doctor/parsing/self_intersecting_elements_parsing.py b/geosx_mesh_doctor/parsing/self_intersecting_elements_parsing.py index 70f5d6a..b5c2f54 100644 --- a/geosx_mesh_doctor/parsing/self_intersecting_elements_parsing.py +++ b/geosx_mesh_doctor/parsing/self_intersecting_elements_parsing.py @@ -7,30 +7,34 @@ from . import SELF_INTERSECTING_ELEMENTS __TOLERANCE = "min" -__TOLERANCE_DEFAULT = numpy.finfo(float).eps +__TOLERANCE_DEFAULT = numpy.finfo( float ).eps -def convert(parsed_options) -> Options: - tolerance = parsed_options[__TOLERANCE] +def convert( parsed_options ) -> Options: + tolerance = parsed_options[ __TOLERANCE ] if tolerance == 0: - logging.warning("Having tolerance set to 0 can induce lots of false positive results (adjacent faces may be considered intersecting).") + logging.warning( + "Having tolerance set to 0 can induce lots of false positive results (adjacent faces may be considered intersecting)." + ) elif tolerance < 0: - raise ValueError(f"Negative tolerance ({tolerance}) in the {SELF_INTERSECTING_ELEMENTS} check is not allowed.") - return Options(tolerance=tolerance) + raise ValueError( + f"Negative tolerance ({tolerance}) in the {SELF_INTERSECTING_ELEMENTS} check is not allowed." ) + return Options( tolerance=tolerance ) -def fill_subparser(subparsers) -> None: - p = subparsers.add_parser(SELF_INTERSECTING_ELEMENTS, - help="Checks if the faces of the elements are self intersecting.") - p.add_argument('--' + __TOLERANCE, - type=float, - required=False, - metavar=__TOLERANCE_DEFAULT, - default=__TOLERANCE_DEFAULT, - help=f"[float]: The tolerance in the computation. Defaults to your machine precision {__TOLERANCE_DEFAULT}.") +def fill_subparser( subparsers ) -> None: + p = subparsers.add_parser( SELF_INTERSECTING_ELEMENTS, + help="Checks if the faces of the elements are self intersecting." ) + p.add_argument( + '--' + __TOLERANCE, + type=float, + required=False, + metavar=__TOLERANCE_DEFAULT, + default=__TOLERANCE_DEFAULT, + help=f"[float]: The tolerance in the computation. Defaults to your machine precision {__TOLERANCE_DEFAULT}." ) -def display_results(options: Options, result: Result): - logging.error(f"You have {len(result.intersecting_faces_elements)} elements with self intersecting faces.") +def display_results( options: Options, result: Result ): + logging.error( f"You have {len(result.intersecting_faces_elements)} elements with self intersecting faces." ) if result.intersecting_faces_elements: - logging.error("The elements indices are:\n" + ", ".join(map(str, result.intersecting_faces_elements))) + logging.error( "The elements indices are:\n" + ", ".join( map( str, result.intersecting_faces_elements ) ) ) diff --git a/geosx_mesh_doctor/parsing/supported_elements_parsing.py b/geosx_mesh_doctor/parsing/supported_elements_parsing.py index c68905b..b43ae04 100644 --- a/geosx_mesh_doctor/parsing/supported_elements_parsing.py +++ b/geosx_mesh_doctor/parsing/supported_elements_parsing.py @@ -8,42 +8,47 @@ __CHUNK_SIZE = "chunck_size" __NUM_PROC = "nproc" - -__ALL_KEYWORDS = {__CHUNK_SIZE, __NUM_PROC} +__ALL_KEYWORDS = { __CHUNK_SIZE, __NUM_PROC } __CHUNK_SIZE_DEFAULT = 1 __NUM_PROC_DEFAULT = multiprocessing.cpu_count() -def convert(parsed_options) -> Options: - return Options(chunk_size=parsed_options[__CHUNK_SIZE], - num_proc=parsed_options[__NUM_PROC]) +def convert( parsed_options ) -> Options: + return Options( chunk_size=parsed_options[ __CHUNK_SIZE ], num_proc=parsed_options[ __NUM_PROC ] ) -def fill_subparser(subparsers) -> None: - p = subparsers.add_parser(SUPPORTED_ELEMENTS, - help="Check that all the elements of the mesh are supported by GEOSX.") - p.add_argument('--' + __CHUNK_SIZE, - type=int, - required=False, - metavar=__CHUNK_SIZE_DEFAULT, - default=__CHUNK_SIZE_DEFAULT, - help=f"[int]: Defaults chunk size for parallel processing to {__CHUNK_SIZE_DEFAULT}") - p.add_argument('--' + __NUM_PROC, - type=int, - required=False, - metavar=__NUM_PROC_DEFAULT, - default=__NUM_PROC_DEFAULT, - help=f"[int]: Number of threads used for parallel processing. Defaults to your CPU count {__NUM_PROC_DEFAULT}.") +def fill_subparser( subparsers ) -> None: + p = subparsers.add_parser( SUPPORTED_ELEMENTS, + help="Check that all the elements of the mesh are supported by GEOSX." ) + p.add_argument( '--' + __CHUNK_SIZE, + type=int, + required=False, + metavar=__CHUNK_SIZE_DEFAULT, + default=__CHUNK_SIZE_DEFAULT, + help=f"[int]: Defaults chunk size for parallel processing to {__CHUNK_SIZE_DEFAULT}" ) + p.add_argument( + '--' + __NUM_PROC, + type=int, + required=False, + metavar=__NUM_PROC_DEFAULT, + default=__NUM_PROC_DEFAULT, + help=f"[int]: Number of threads used for parallel processing. Defaults to your CPU count {__NUM_PROC_DEFAULT}." + ) -def display_results(options: Options, result: Result): +def display_results( options: Options, result: Result ): if result.unsupported_polyhedron_elements: - logging.error(f"There is/are {len(result.unsupported_polyhedron_elements)} polyhedra that may not be converted to supported elements.") - logging.error(f"The list of the unsupported polyhedra is\n{tuple(sorted(result.unsupported_polyhedron_elements))}.") + logging.error( + f"There is/are {len(result.unsupported_polyhedron_elements)} polyhedra that may not be converted to supported elements." + ) + logging.error( + f"The list of the unsupported polyhedra is\n{tuple(sorted(result.unsupported_polyhedron_elements))}." ) else: - logging.info("All the polyhedra (if any) can be converted to supported elements.") + logging.info( "All the polyhedra (if any) can be converted to supported elements." ) if result.unsupported_std_elements_types: - logging.error(f"There are unsupported vtk standard element types. The list of those vtk types is {tuple(sorted(result.unsupported_std_elements_types))}.") + logging.error( + f"There are unsupported vtk standard element types. The list of those vtk types is {tuple(sorted(result.unsupported_std_elements_types))}." + ) else: - logging.info("All the standard vtk element types (if any) are supported.") \ No newline at end of file + logging.info( "All the standard vtk element types (if any) are supported." ) diff --git a/geosx_mesh_doctor/parsing/vtk_output_parsing.py b/geosx_mesh_doctor/parsing/vtk_output_parsing.py index 6e9b7d5..be31037 100644 --- a/geosx_mesh_doctor/parsing/vtk_output_parsing.py +++ b/geosx_mesh_doctor/parsing/vtk_output_parsing.py @@ -4,42 +4,43 @@ from checks.vtk_utils import VtkOutput - __OUTPUT_FILE = "output" __OUTPUT_BINARY_MODE = "data-mode" __OUTPUT_BINARY_MODE_VALUES = "binary", "ascii" -__OUTPUT_BINARY_MODE_DEFAULT = __OUTPUT_BINARY_MODE_VALUES[0] +__OUTPUT_BINARY_MODE_DEFAULT = __OUTPUT_BINARY_MODE_VALUES[ 0 ] def get_vtk_output_help(): msg = \ f"""{__OUTPUT_FILE} [string]: The vtk output file destination. {__OUTPUT_BINARY_MODE} [string]: For ".vtu" output format, the data mode can be {" or ".join(__OUTPUT_BINARY_MODE_VALUES)}. Defaults to {__OUTPUT_BINARY_MODE_DEFAULT}.""" - return textwrap.dedent(msg) - - -def __build_arg(prefix, main): - return "-".join(filter(None, (prefix, main))) - - -def fill_vtk_output_subparser(parser, prefix="") -> None: - parser.add_argument('--' + __build_arg(prefix, __OUTPUT_FILE), - type=str, - required=True, - help=f"[string]: The vtk output file destination.") - parser.add_argument('--' + __build_arg(prefix, __OUTPUT_BINARY_MODE), - type=str, - metavar=", ".join(__OUTPUT_BINARY_MODE_VALUES), - default=__OUTPUT_BINARY_MODE_DEFAULT, - help=f"""[string]: For ".vtu" output format, the data mode can be {" or ".join(__OUTPUT_BINARY_MODE_VALUES)}. Defaults to {__OUTPUT_BINARY_MODE_DEFAULT}.""") - - -def convert(parsed_options, prefix="") -> VtkOutput: - output_key = __build_arg(prefix, __OUTPUT_FILE).replace("-", "_") - binary_mode_key = __build_arg(prefix, __OUTPUT_BINARY_MODE).replace("-", "_") - output = parsed_options[output_key] - if parsed_options[binary_mode_key] and os.path.splitext(output)[-1] == ".vtk": - logging.info("VTK data mode will be ignored for legacy file format \"vtk\".") - is_data_mode_binary: bool = parsed_options[binary_mode_key] == __OUTPUT_BINARY_MODE_DEFAULT - return VtkOutput(output=output, - is_data_mode_binary=is_data_mode_binary) + return textwrap.dedent( msg ) + + +def __build_arg( prefix, main ): + return "-".join( filter( None, ( prefix, main ) ) ) + + +def fill_vtk_output_subparser( parser, prefix="" ) -> None: + parser.add_argument( '--' + __build_arg( prefix, __OUTPUT_FILE ), + type=str, + required=True, + help=f"[string]: The vtk output file destination." ) + parser.add_argument( + '--' + __build_arg( prefix, __OUTPUT_BINARY_MODE ), + type=str, + metavar=", ".join( __OUTPUT_BINARY_MODE_VALUES ), + default=__OUTPUT_BINARY_MODE_DEFAULT, + help= + f"""[string]: For ".vtu" output format, the data mode can be {" or ".join(__OUTPUT_BINARY_MODE_VALUES)}. Defaults to {__OUTPUT_BINARY_MODE_DEFAULT}.""" + ) + + +def convert( parsed_options, prefix="" ) -> VtkOutput: + output_key = __build_arg( prefix, __OUTPUT_FILE ).replace( "-", "_" ) + binary_mode_key = __build_arg( prefix, __OUTPUT_BINARY_MODE ).replace( "-", "_" ) + output = parsed_options[ output_key ] + if parsed_options[ binary_mode_key ] and os.path.splitext( output )[ -1 ] == ".vtk": + logging.info( "VTK data mode will be ignored for legacy file format \"vtk\"." ) + is_data_mode_binary: bool = parsed_options[ binary_mode_key ] == __OUTPUT_BINARY_MODE_DEFAULT + return VtkOutput( output=output, is_data_mode_binary=is_data_mode_binary ) diff --git a/geosx_mesh_doctor/register.py b/geosx_mesh_doctor/register.py index a36001e..1626053 100644 --- a/geosx_mesh_doctor/register.py +++ b/geosx_mesh_doctor/register.py @@ -6,67 +6,62 @@ import parsing from parsing import CheckHelper, cli_parsing +__HELPERS: Dict[ str, Callable[ [ None ], CheckHelper ] ] = dict() +__CHECKS: Dict[ str, Callable[ [ None ], Any ] ] = dict() -__HELPERS: Dict[str, Callable[[None], CheckHelper]] = dict() -__CHECKS: Dict[str, Callable[[None], Any]] = dict() +def __load_module_check( module_name: str, check_fct="check" ): + module = importlib.import_module( "checks." + module_name ) + return getattr( module, check_fct ) -def __load_module_check(module_name: str, check_fct="check"): - module = importlib.import_module("checks." + module_name) - return getattr(module, check_fct) +def __load_module_check_helper( module_name: str, parsing_fct_suffix="_parsing" ): + module = importlib.import_module( "parsing." + module_name + parsing_fct_suffix ) + return CheckHelper( fill_subparser=module.fill_subparser, + convert=module.convert, + display_results=module.display_results ) -def __load_module_check_helper(module_name: str, parsing_fct_suffix="_parsing"): - module = importlib.import_module("parsing." + module_name + parsing_fct_suffix) - return CheckHelper(fill_subparser=module.fill_subparser, - convert=module.convert, - display_results=module.display_results) - -def __load_checks() -> Dict[str, Callable[[str, Any], Any]]: +def __load_checks() -> Dict[ str, Callable[ [ str, Any ], Any ] ]: """ Loads all the checks. This function acts like a protection layer if a module fails to load. A check that fails to load won't stop the process. :return: The checks. """ - loaded_checks: Dict[str, Callable[[str, Any], Any]] = dict() + loaded_checks: Dict[ str, Callable[ [ str, Any ], Any ] ] = dict() for check_name, check_provider in __CHECKS.items(): try: - loaded_checks[check_name] = check_provider() - logging.debug(f"Check \"{check_name}\" is loaded.") + loaded_checks[ check_name ] = check_provider() + logging.debug( f"Check \"{check_name}\" is loaded." ) except Exception as e: - logging.warning(f"Could not load module \"{check_name}\": {e}") + logging.warning( f"Could not load module \"{check_name}\": {e}" ) return loaded_checks -def register() -> Tuple[argparse.ArgumentParser, Dict[str, Callable[[str, Any], Any]], Dict[str, CheckHelper]]: +def register( +) -> Tuple[ argparse.ArgumentParser, Dict[ str, Callable[ [ str, Any ], Any ] ], Dict[ str, CheckHelper ] ]: """ Register all the parsing checks. Eventually initiate the registration of all the checks too. :return: The checks and the checks helpers. """ parser = cli_parsing.init_parser() - subparsers = parser.add_subparsers(help="Modules", dest="subparsers") + subparsers = parser.add_subparsers( help="Modules", dest="subparsers" ) + + def closure_trick( cn: str ): + __HELPERS[ check_name ] = lambda: __load_module_check_helper( cn ) + __CHECKS[ check_name ] = lambda: __load_module_check( cn ) - def closure_trick(cn: str): - __HELPERS[check_name] = lambda: __load_module_check_helper(cn) - __CHECKS[check_name] = lambda: __load_module_check(cn) # Register the modules to load here. - for check_name in (parsing.COLLOCATES_NODES, - parsing.ELEMENT_VOLUMES, - parsing.FIX_ELEMENTS_ORDERINGS, - parsing.GENERATE_CUBE, - parsing.GENERATE_FRACTURES, - parsing.GENERATE_GLOBAL_IDS, - parsing.NON_CONFORMAL, - parsing.SELF_INTERSECTING_ELEMENTS, - parsing.SUPPORTED_ELEMENTS): - closure_trick(check_name) - loaded_checks: Dict[str, Callable[[str, Any], Any]] = __load_checks() - loaded_checks_helpers: Dict[str, CheckHelper] = dict() + for check_name in ( parsing.COLLOCATES_NODES, parsing.ELEMENT_VOLUMES, parsing.FIX_ELEMENTS_ORDERINGS, + parsing.GENERATE_CUBE, parsing.GENERATE_FRACTURES, parsing.GENERATE_GLOBAL_IDS, + parsing.NON_CONFORMAL, parsing.SELF_INTERSECTING_ELEMENTS, parsing.SUPPORTED_ELEMENTS ): + closure_trick( check_name ) + loaded_checks: Dict[ str, Callable[ [ str, Any ], Any ] ] = __load_checks() + loaded_checks_helpers: Dict[ str, CheckHelper ] = dict() for check_name in loaded_checks.keys(): - h = __HELPERS[check_name]() - h.fill_subparser(subparsers) - loaded_checks_helpers[check_name] = h - logging.debug(f"Parsing for check \"{check_name}\" is loaded.") + h = __HELPERS[ check_name ]() + h.fill_subparser( subparsers ) + loaded_checks_helpers[ check_name ] = h + logging.debug( f"Parsing for check \"{check_name}\" is loaded." ) return parser, loaded_checks, loaded_checks_helpers diff --git a/geosx_mesh_doctor/setup.py b/geosx_mesh_doctor/setup.py index 1d0b991..dc03ac1 100644 --- a/geosx_mesh_doctor/setup.py +++ b/geosx_mesh_doctor/setup.py @@ -1,3 +1,3 @@ from setuptools import setup, find_packages -setup(name='mesh_doctor', version='0.0.1', packages=find_packages()) +setup( name='mesh_doctor', version='0.0.1', packages=find_packages() ) diff --git a/geosx_mesh_doctor/tests/test_cli_parsing.py b/geosx_mesh_doctor/tests/test_cli_parsing.py index 445b7c9..1989e24 100644 --- a/geosx_mesh_doctor/tests/test_cli_parsing.py +++ b/geosx_mesh_doctor/tests/test_cli_parsing.py @@ -9,8 +9,7 @@ import pytest from checks.vtk_utils import ( - VtkOutput, -) + VtkOutput, ) from checks.generate_fractures import ( FracturePolicy, @@ -23,36 +22,39 @@ ) -@dataclass(frozen=True) +@dataclass( frozen=True ) class TestCase: __test__ = False - cli_args: Sequence[str] + cli_args: Sequence[ str ] options: Options exception: bool = False -def __generate_generate_fractures_parsing_test_data() -> Iterator[TestCase]: +def __generate_generate_fractures_parsing_test_data() -> Iterator[ TestCase ]: field: str = "attribute" main_mesh: str = "output.vtu" fracture_mesh: str = "fracture.vtu" cli_gen: str = f"generate_fractures --policy {{}} --name {field} --values 0,1 --output {main_mesh} --fracture-output {fracture_mesh}" - all_cli_args = cli_gen.format("field").split(), cli_gen.format("internal_surfaces").split(), cli_gen.format("dummy").split() + all_cli_args = cli_gen.format( "field" ).split(), cli_gen.format( "internal_surfaces" ).split(), cli_gen.format( + "dummy" ).split() policies = FracturePolicy.FIELD, FracturePolicy.INTERNAL_SURFACES, FracturePolicy.FIELD exceptions = False, False, True - for cli_args, policy, exception in zip(all_cli_args, policies, exceptions): - options: Options = Options(policy=policy, field=field, field_values=frozenset((0, 1)), - vtk_output=VtkOutput(output=main_mesh, is_data_mode_binary=True), - vtk_fracture_output=VtkOutput(output=fracture_mesh, is_data_mode_binary=True)) - yield TestCase(cli_args, options, exception) + for cli_args, policy, exception in zip( all_cli_args, policies, exceptions ): + options: Options = Options( policy=policy, + field=field, + field_values=frozenset( ( 0, 1 ) ), + vtk_output=VtkOutput( output=main_mesh, is_data_mode_binary=True ), + vtk_fracture_output=VtkOutput( output=fracture_mesh, is_data_mode_binary=True ) ) + yield TestCase( cli_args, options, exception ) -def __f(test_case: TestCase): - parser = argparse.ArgumentParser(description='Testing.') +def __f( test_case: TestCase ): + parser = argparse.ArgumentParser( description='Testing.' ) subparsers = parser.add_subparsers() - fill_subparser(subparsers) - args = parser.parse_args(test_case.cli_args) - options = convert(vars(args)) + fill_subparser( subparsers ) + args = parser.parse_args( test_case.cli_args ) + options = convert( vars( args ) ) assert options.policy == test_case.options.policy assert options.field == test_case.options.field assert options.field_values == test_case.options.field_values @@ -60,13 +62,13 @@ def __f(test_case: TestCase): def test_display_results(): # Dummy test for code coverage only. Shame on me! - display_results(None, None) + display_results( None, None ) -@pytest.mark.parametrize("test_case", __generate_generate_fractures_parsing_test_data()) -def test(test_case: TestCase): +@pytest.mark.parametrize( "test_case", __generate_generate_fractures_parsing_test_data() ) +def test( test_case: TestCase ): if test_case.exception: - with pytest.raises(SystemExit): - __f(test_case) + with pytest.raises( SystemExit ): + __f( test_case ) else: - __f(test_case) + __f( test_case ) diff --git a/geosx_mesh_doctor/tests/test_collocated_nodes.py b/geosx_mesh_doctor/tests/test_collocated_nodes.py index 6936331..d0b6645 100644 --- a/geosx_mesh_doctor/tests/test_collocated_nodes.py +++ b/geosx_mesh_doctor/tests/test_collocated_nodes.py @@ -3,8 +3,7 @@ import pytest from vtkmodules.vtkCommonCore import ( - vtkPoints, -) + vtkPoints, ) from vtkmodules.vtkCommonDataModel import ( VTK_TETRA, vtkCellArray, @@ -15,62 +14,62 @@ from checks.collocated_nodes import Options, __check -def get_points() -> Iterator[Tuple[vtkPoints, int]]: +def get_points() -> Iterator[ Tuple[ vtkPoints, int ] ]: """ Generates the data for the cases. One case has two nodes at the exact same position. The other has two differente nodes :return: Generator to (vtk points, number of expected duplicated locations) """ - for p0, p1 in ((0, 0, 0), (1, 1, 1)), ((0, 0, 0), (0, 0, 0)): + for p0, p1 in ( ( 0, 0, 0 ), ( 1, 1, 1 ) ), ( ( 0, 0, 0 ), ( 0, 0, 0 ) ): points = vtkPoints() - points.SetNumberOfPoints(2) - points.SetPoint(0, p0) - points.SetPoint(1, p1) + points.SetNumberOfPoints( 2 ) + points.SetPoint( 0, p0 ) + points.SetPoint( 1, p1 ) num_nodes_bucket = 1 if p0 == p1 else 0 yield points, num_nodes_bucket -@pytest.mark.parametrize("data", get_points()) -def test_simple_collocated_points(data: Tuple[vtkPoints, int]): +@pytest.mark.parametrize( "data", get_points() ) +def test_simple_collocated_points( data: Tuple[ vtkPoints, int ] ): points, num_nodes_bucket = data mesh = vtkUnstructuredGrid() - mesh.SetPoints(points) + mesh.SetPoints( points ) - result = __check(mesh, Options(tolerance=1.e-12)) + result = __check( mesh, Options( tolerance=1.e-12 ) ) - assert len(result.wrong_support_elements) == 0 - assert len(result.nodes_buckets) == num_nodes_bucket + assert len( result.wrong_support_elements ) == 0 + assert len( result.nodes_buckets ) == num_nodes_bucket if num_nodes_bucket == 1: - assert len(result.nodes_buckets[0]) == points.GetNumberOfPoints() + assert len( result.nodes_buckets[ 0 ] ) == points.GetNumberOfPoints() def test_wrong_support_elements(): points = vtkPoints() - points.SetNumberOfPoints(4) - points.SetPoint(0, (0, 0, 0)) - points.SetPoint(1, (1, 0, 0)) - points.SetPoint(2, (0, 1, 0)) - points.SetPoint(3, (0, 0, 1)) + points.SetNumberOfPoints( 4 ) + points.SetPoint( 0, ( 0, 0, 0 ) ) + points.SetPoint( 1, ( 1, 0, 0 ) ) + points.SetPoint( 2, ( 0, 1, 0 ) ) + points.SetPoint( 3, ( 0, 0, 1 ) ) - cell_types = [VTK_TETRA] + cell_types = [ VTK_TETRA ] cells = vtkCellArray() - cells.AllocateExact(1, 4) + cells.AllocateExact( 1, 4 ) tet = vtkTetra() - tet.GetPointIds().SetId(0, 0) - tet.GetPointIds().SetId(1, 1) - tet.GetPointIds().SetId(2, 2) - tet.GetPointIds().SetId(3, 0) # Intentionally wrong - cells.InsertNextCell(tet) + tet.GetPointIds().SetId( 0, 0 ) + tet.GetPointIds().SetId( 1, 1 ) + tet.GetPointIds().SetId( 2, 2 ) + tet.GetPointIds().SetId( 3, 0 ) # Intentionally wrong + cells.InsertNextCell( tet ) mesh = vtkUnstructuredGrid() - mesh.SetPoints(points) - mesh.SetCells(cell_types, cells) + mesh.SetPoints( points ) + mesh.SetCells( cell_types, cells ) - result = __check(mesh, Options(tolerance=1.e-12)) + result = __check( mesh, Options( tolerance=1.e-12 ) ) - assert len(result.nodes_buckets) == 0 - assert len(result.wrong_support_elements) == 1 - assert result.wrong_support_elements[0] == 0 + assert len( result.nodes_buckets ) == 0 + assert len( result.wrong_support_elements ) == 1 + assert result.wrong_support_elements[ 0 ] == 0 diff --git a/geosx_mesh_doctor/tests/test_element_volumes.py b/geosx_mesh_doctor/tests/test_element_volumes.py index e37c22c..163300c 100644 --- a/geosx_mesh_doctor/tests/test_element_volumes.py +++ b/geosx_mesh_doctor/tests/test_element_volumes.py @@ -1,8 +1,7 @@ import numpy from vtkmodules.vtkCommonCore import ( - vtkPoints, -) + vtkPoints, ) from vtkmodules.vtkCommonDataModel import ( VTK_TETRA, vtkCellArray, @@ -16,33 +15,33 @@ def test_simple_tet(): # creating a simple tetrahedron points = vtkPoints() - points.SetNumberOfPoints(4) - points.SetPoint(0, (0, 0, 0)) - points.SetPoint(1, (1, 0, 0)) - points.SetPoint(2, (0, 1, 0)) - points.SetPoint(3, (0, 0, 1)) + points.SetNumberOfPoints( 4 ) + points.SetPoint( 0, ( 0, 0, 0 ) ) + points.SetPoint( 1, ( 1, 0, 0 ) ) + points.SetPoint( 2, ( 0, 1, 0 ) ) + points.SetPoint( 3, ( 0, 0, 1 ) ) - cell_types = [VTK_TETRA] + cell_types = [ VTK_TETRA ] cells = vtkCellArray() - cells.AllocateExact(1, 4) + cells.AllocateExact( 1, 4 ) tet = vtkTetra() - tet.GetPointIds().SetId(0, 0) - tet.GetPointIds().SetId(1, 1) - tet.GetPointIds().SetId(2, 2) - tet.GetPointIds().SetId(3, 3) - cells.InsertNextCell(tet) + tet.GetPointIds().SetId( 0, 0 ) + tet.GetPointIds().SetId( 1, 1 ) + tet.GetPointIds().SetId( 2, 2 ) + tet.GetPointIds().SetId( 3, 3 ) + cells.InsertNextCell( tet ) mesh = vtkUnstructuredGrid() - mesh.SetPoints(points) - mesh.SetCells(cell_types, cells) + mesh.SetPoints( points ) + mesh.SetCells( cell_types, cells ) - result = __check(mesh, Options(min_volume=1.)) + result = __check( mesh, Options( min_volume=1. ) ) - assert len(result.element_volumes) == 1 - assert result.element_volumes[0][0] == 0 - assert abs(result.element_volumes[0][1] - 1./6.) < 10 * numpy.finfo(float).eps + assert len( result.element_volumes ) == 1 + assert result.element_volumes[ 0 ][ 0 ] == 0 + assert abs( result.element_volumes[ 0 ][ 1 ] - 1. / 6. ) < 10 * numpy.finfo( float ).eps - result = __check(mesh, Options(min_volume=0.)) + result = __check( mesh, Options( min_volume=0. ) ) - assert len(result.element_volumes) == 0 + assert len( result.element_volumes ) == 0 diff --git a/geosx_mesh_doctor/tests/test_generate_cube.py b/geosx_mesh_doctor/tests/test_generate_cube.py index 4d93abd..78713f8 100644 --- a/geosx_mesh_doctor/tests/test_generate_cube.py +++ b/geosx_mesh_doctor/tests/test_generate_cube.py @@ -2,23 +2,19 @@ def test_generate_cube(): - options = Options( - vtk_output=None, - generate_cells_global_ids=True, - generate_points_global_ids=False, - xs=(0, 5, 10), - ys=(0, 4, 8), - zs=(0, 1), - nxs=(5, 2), - nys=(1, 1), - nzs=(1,), - fields=( - FieldInfo(name="test", dimension=2, support="CELLS"), - ) - ) - output = __build(options) + options = Options( vtk_output=None, + generate_cells_global_ids=True, + generate_points_global_ids=False, + xs=( 0, 5, 10 ), + ys=( 0, 4, 8 ), + zs=( 0, 1 ), + nxs=( 5, 2 ), + nys=( 1, 1 ), + nzs=( 1, ), + fields=( FieldInfo( name="test", dimension=2, support="CELLS" ), ) ) + output = __build( options ) assert output.GetNumberOfCells() == 14 assert output.GetNumberOfPoints() == 48 - assert output.GetCellData().GetArray("test").GetNumberOfComponents() == 2 + assert output.GetCellData().GetArray( "test" ).GetNumberOfComponents() == 2 assert output.GetCellData().GetGlobalIds() assert not output.GetPointData().GetGlobalIds() diff --git a/geosx_mesh_doctor/tests/test_generate_fractures.py b/geosx_mesh_doctor/tests/test_generate_fractures.py index f197731..077d9d7 100644 --- a/geosx_mesh_doctor/tests/test_generate_fractures.py +++ b/geosx_mesh_doctor/tests/test_generate_fractures.py @@ -18,19 +18,17 @@ VTK_QUAD, ) from vtkmodules.util.numpy_support import ( - numpy_to_vtk, -) + numpy_to_vtk, ) from checks.vtk_utils import ( - to_vtk_id_list, -) + to_vtk_id_list, ) from checks.check_fractures import format_collocated_nodes from checks.generate_cube import build_rectilinear_blocks_mesh, XYZ from checks.generate_fractures import __split_mesh_on_fracture, Options, FracturePolicy -@dataclass(frozen=True) +@dataclass( frozen=True ) class TestResult: __test__ = False main_mesh_num_points: int @@ -39,224 +37,252 @@ class TestResult: fracture_mesh_num_cells: int -@dataclass(frozen=True) +@dataclass( frozen=True ) class TestCase: __test__ = False input_mesh: vtkUnstructuredGrid options: Options - collocated_nodes: Sequence[Sequence[int]] + collocated_nodes: Sequence[ Sequence[ int ] ] result: TestResult -def __build_test_case(xs: Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray], - attribute: Iterable[int], - field_values: Iterable[int] = None, - policy: FracturePolicy = FracturePolicy.FIELD): - xyz = XYZ(*xs) +def __build_test_case( xs: Tuple[ numpy.ndarray, numpy.ndarray, numpy.ndarray ], + attribute: Iterable[ int ], + field_values: Iterable[ int ] = None, + policy: FracturePolicy = FracturePolicy.FIELD ): + xyz = XYZ( *xs ) - mesh: vtkUnstructuredGrid = build_rectilinear_blocks_mesh((xyz, )) + mesh: vtkUnstructuredGrid = build_rectilinear_blocks_mesh( ( xyz, ) ) - ref = numpy.array(attribute, dtype=int) + ref = numpy.array( attribute, dtype=int ) if policy == FracturePolicy.FIELD: - assert len(ref) == mesh.GetNumberOfCells() - attr = numpy_to_vtk(ref) - attr.SetName("attribute") - mesh.GetCellData().AddArray(attr) + assert len( ref ) == mesh.GetNumberOfCells() + attr = numpy_to_vtk( ref ) + attr.SetName( "attribute" ) + mesh.GetCellData().AddArray( attr ) if field_values is None: - fv = frozenset(attribute) + fv = frozenset( attribute ) else: - fv = frozenset(field_values) + fv = frozenset( field_values ) - options = Options(policy=policy, - field="attribute", - field_values=fv, - vtk_output=None, - vtk_fracture_output=None) + options = Options( policy=policy, field="attribute", field_values=fv, vtk_output=None, vtk_fracture_output=None ) return mesh, options # Utility class to generate the new indices of the newly created collocated nodes. class Incrementor: - def __init__(self, start): + + def __init__( self, start ): self.__val = start - def next(self, num: int) -> Iterable[int]: + def next( self, num: int ) -> Iterable[ int ]: self.__val += num - return range(self.__val - num, self.__val) + return range( self.__val - num, self.__val ) -def __generate_test_data() -> Iterator[TestCase]: - two_nodes = numpy.arange(2, dtype=float) - three_nodes = numpy.arange(3, dtype=float) - four_nodes = numpy.arange(4, dtype=float) +def __generate_test_data() -> Iterator[ TestCase ]: + two_nodes = numpy.arange( 2, dtype=float ) + three_nodes = numpy.arange( 3, dtype=float ) + four_nodes = numpy.arange( 4, dtype=float ) # Split in 2 - mesh, options = __build_test_case((three_nodes, three_nodes, three_nodes), (0, 1, 0, 1, 0, 1, 0, 1)) - yield TestCase(input_mesh=mesh, options=options, - collocated_nodes=tuple(map(lambda i: (1 + 3 * i, 27 + i), range(9))), - result=TestResult(9 * 4, 8, 9, 4)) + mesh, options = __build_test_case( ( three_nodes, three_nodes, three_nodes ), ( 0, 1, 0, 1, 0, 1, 0, 1 ) ) + yield TestCase( input_mesh=mesh, + options=options, + collocated_nodes=tuple( map( lambda i: ( 1 + 3 * i, 27 + i ), range( 9 ) ) ), + result=TestResult( 9 * 4, 8, 9, 4 ) ) # Split in 3 - inc = Incrementor(27) - collocated_nodes: Sequence[Sequence[int]] = ( - (1, *inc.next(1)), - (3, *inc.next(1)), - (4, *inc.next(2)), - (7, *inc.next(1)), - (1 + 9, *inc.next(1)), - (3 + 9, *inc.next(1)), - (4 + 9, *inc.next(2)), - (7 + 9, *inc.next(1)), - (1 + 18, *inc.next(1)), - (3 + 18, *inc.next(1)), - (4 + 18, *inc.next(2)), - (7 + 18, *inc.next(1)), + inc = Incrementor( 27 ) + collocated_nodes: Sequence[ Sequence[ int ] ] = ( + ( 1, *inc.next( 1 ) ), + ( 3, *inc.next( 1 ) ), + ( 4, *inc.next( 2 ) ), + ( 7, *inc.next( 1 ) ), + ( 1 + 9, *inc.next( 1 ) ), + ( 3 + 9, *inc.next( 1 ) ), + ( 4 + 9, *inc.next( 2 ) ), + ( 7 + 9, *inc.next( 1 ) ), + ( 1 + 18, *inc.next( 1 ) ), + ( 3 + 18, *inc.next( 1 ) ), + ( 4 + 18, *inc.next( 2 ) ), + ( 7 + 18, *inc.next( 1 ) ), ) - mesh, options = __build_test_case((three_nodes, three_nodes, three_nodes), (0, 1, 2, 1, 0, 1, 2, 1)) - yield TestCase(input_mesh=mesh, options=options, collocated_nodes=collocated_nodes, - result=TestResult(9 * 4 + 6, 8, 12, 6)) + mesh, options = __build_test_case( ( three_nodes, three_nodes, three_nodes ), ( 0, 1, 2, 1, 0, 1, 2, 1 ) ) + yield TestCase( input_mesh=mesh, + options=options, + collocated_nodes=collocated_nodes, + result=TestResult( 9 * 4 + 6, 8, 12, 6 ) ) # Split in 8 - inc = Incrementor(27) - collocated_nodes: Sequence[Sequence[int]] = ( - (1, *inc.next(1)), - (3, *inc.next(1)), - (4, *inc.next(3)), - (5, *inc.next(1)), - (7, *inc.next(1)), - (0 + 9, *inc.next(1)), - (1 + 9, *inc.next(3)), - (2 + 9, *inc.next(1)), - (3 + 9, *inc.next(3)), - (4 + 9, *inc.next(7)), - (5 + 9, *inc.next(3)), - (6 + 9, *inc.next(1)), - (7 + 9, *inc.next(3)), - (8 + 9, *inc.next(1)), - (1 + 18, *inc.next(1)), - (3 + 18, *inc.next(1)), - (4 + 18, *inc.next(3)), - (5 + 18, *inc.next(1)), - (7 + 18, *inc.next(1)), + inc = Incrementor( 27 ) + collocated_nodes: Sequence[ Sequence[ int ] ] = ( + ( 1, *inc.next( 1 ) ), + ( 3, *inc.next( 1 ) ), + ( 4, *inc.next( 3 ) ), + ( 5, *inc.next( 1 ) ), + ( 7, *inc.next( 1 ) ), + ( 0 + 9, *inc.next( 1 ) ), + ( 1 + 9, *inc.next( 3 ) ), + ( 2 + 9, *inc.next( 1 ) ), + ( 3 + 9, *inc.next( 3 ) ), + ( 4 + 9, *inc.next( 7 ) ), + ( 5 + 9, *inc.next( 3 ) ), + ( 6 + 9, *inc.next( 1 ) ), + ( 7 + 9, *inc.next( 3 ) ), + ( 8 + 9, *inc.next( 1 ) ), + ( 1 + 18, *inc.next( 1 ) ), + ( 3 + 18, *inc.next( 1 ) ), + ( 4 + 18, *inc.next( 3 ) ), + ( 5 + 18, *inc.next( 1 ) ), + ( 7 + 18, *inc.next( 1 ) ), ) - mesh, options = __build_test_case((three_nodes, three_nodes, three_nodes), range(8)) - yield TestCase(input_mesh=mesh, options=options, collocated_nodes=collocated_nodes, - result=TestResult(8 * 8, 8, 3 * 3 * 3 - 8, 12)) + mesh, options = __build_test_case( ( three_nodes, three_nodes, three_nodes ), range( 8 ) ) + yield TestCase( input_mesh=mesh, + options=options, + collocated_nodes=collocated_nodes, + result=TestResult( 8 * 8, 8, 3 * 3 * 3 - 8, 12 ) ) # Straight notch - inc = Incrementor(27) - collocated_nodes: Sequence[Sequence[int]] = ( - (1, *inc.next(1)), - (4,), - (1 + 9, *inc.next(1)), - (4 + 9,), - (1 + 18, *inc.next(1)), - (4 + 18,), + inc = Incrementor( 27 ) + collocated_nodes: Sequence[ Sequence[ int ] ] = ( + ( 1, *inc.next( 1 ) ), + ( 4, ), + ( 1 + 9, *inc.next( 1 ) ), + ( 4 + 9, ), + ( 1 + 18, *inc.next( 1 ) ), + ( 4 + 18, ), ) - mesh, options = __build_test_case((three_nodes, three_nodes, three_nodes), (0, 1, 2, 2, 0, 1, 2, 2), field_values=(0, 1)) - yield TestCase(input_mesh=mesh, options=options, collocated_nodes=collocated_nodes, - result=TestResult(3 * 3 * 3 + 3, 8, 6, 2)) + mesh, options = __build_test_case( ( three_nodes, three_nodes, three_nodes ), ( 0, 1, 2, 2, 0, 1, 2, 2 ), + field_values=( 0, 1 ) ) + yield TestCase( input_mesh=mesh, + options=options, + collocated_nodes=collocated_nodes, + result=TestResult( 3 * 3 * 3 + 3, 8, 6, 2 ) ) # L-shaped notch - inc = Incrementor(27) - collocated_nodes: Sequence[Sequence[int]] = ( - (1, *inc.next(1)), - (4, *inc.next(1)), - (7, *inc.next(1)), - (1 + 9, *inc.next(1)), - (4 + 9,), - (7 + 9,), - (1 + 18, *inc.next(1)), - (4 + 18,), + inc = Incrementor( 27 ) + collocated_nodes: Sequence[ Sequence[ int ] ] = ( + ( 1, *inc.next( 1 ) ), + ( 4, *inc.next( 1 ) ), + ( 7, *inc.next( 1 ) ), + ( 1 + 9, *inc.next( 1 ) ), + ( 4 + 9, ), + ( 7 + 9, ), + ( 1 + 18, *inc.next( 1 ) ), + ( 4 + 18, ), ) - mesh, options = __build_test_case((three_nodes, three_nodes, three_nodes), (0, 1, 0, 1, 0, 1, 2, 2), field_values=(0, 1)) - yield TestCase(input_mesh=mesh, options=options, collocated_nodes=collocated_nodes, - result=TestResult(3 * 3 * 3 + 5, 8, 8, 3)) + mesh, options = __build_test_case( ( three_nodes, three_nodes, three_nodes ), ( 0, 1, 0, 1, 0, 1, 2, 2 ), + field_values=( 0, 1 ) ) + yield TestCase( input_mesh=mesh, + options=options, + collocated_nodes=collocated_nodes, + result=TestResult( 3 * 3 * 3 + 5, 8, 8, 3 ) ) # 3x1x1 split - inc = Incrementor(2 * 2 * 4) - collocated_nodes: Sequence[Sequence[int]] = ( - (1, *inc.next(1)), - (2, *inc.next(1)), - (5, *inc.next(1)), - (6, *inc.next(1)), - (1 + 8, *inc.next(1)), - (2 + 8, *inc.next(1)), - (5 + 8, *inc.next(1)), - (6 + 8, *inc.next(1)), + inc = Incrementor( 2 * 2 * 4 ) + collocated_nodes: Sequence[ Sequence[ int ] ] = ( + ( 1, *inc.next( 1 ) ), + ( 2, *inc.next( 1 ) ), + ( 5, *inc.next( 1 ) ), + ( 6, *inc.next( 1 ) ), + ( 1 + 8, *inc.next( 1 ) ), + ( 2 + 8, *inc.next( 1 ) ), + ( 5 + 8, *inc.next( 1 ) ), + ( 6 + 8, *inc.next( 1 ) ), ) - mesh, options = __build_test_case((four_nodes, two_nodes, two_nodes), (0, 1, 2)) - yield TestCase(input_mesh=mesh, options=options, collocated_nodes=collocated_nodes, - result=TestResult(6 * 4, 3, 2 * 4, 2)) + mesh, options = __build_test_case( ( four_nodes, two_nodes, two_nodes ), ( 0, 1, 2 ) ) + yield TestCase( input_mesh=mesh, + options=options, + collocated_nodes=collocated_nodes, + result=TestResult( 6 * 4, 3, 2 * 4, 2 ) ) # Discarded fracture element if no node duplication. - collocated_nodes: Sequence[Sequence[int]] = () - mesh, options = __build_test_case((three_nodes, four_nodes, four_nodes), [0, ] * 8 + [1, 2] + [0, ] * 8, field_values=(1, 2)) - yield TestCase(input_mesh=mesh, options=options, collocated_nodes=collocated_nodes, - result=TestResult(3 * 4 * 4, 2 * 3 * 3, 0, 0)) + collocated_nodes: Sequence[ Sequence[ int ] ] = () + mesh, options = __build_test_case( ( three_nodes, four_nodes, four_nodes ), [ + 0, + ] * 8 + [ 1, 2 ] + [ + 0, + ] * 8, + field_values=( 1, 2 ) ) + yield TestCase( input_mesh=mesh, + options=options, + collocated_nodes=collocated_nodes, + result=TestResult( 3 * 4 * 4, 2 * 3 * 3, 0, 0 ) ) # Fracture on a corner - inc = Incrementor(3 * 4 * 4) - collocated_nodes: Sequence[Sequence[int]] = ( - (1 + 12,), - (4 + 12,), - (7 + 12,), - (1 + 12 * 2, *inc.next(1)), - (4 + 12 * 2, *inc.next(1)), - (7 + 12 * 2,), - (1 + 12 * 3, *inc.next(1)), - (4 + 12 * 3, *inc.next(1)), - (7 + 12 * 3,), + inc = Incrementor( 3 * 4 * 4 ) + collocated_nodes: Sequence[ Sequence[ int ] ] = ( + ( 1 + 12, ), + ( 4 + 12, ), + ( 7 + 12, ), + ( 1 + 12 * 2, *inc.next( 1 ) ), + ( 4 + 12 * 2, *inc.next( 1 ) ), + ( 7 + 12 * 2, ), + ( 1 + 12 * 3, *inc.next( 1 ) ), + ( 4 + 12 * 3, *inc.next( 1 ) ), + ( 7 + 12 * 3, ), ) - mesh, options = __build_test_case((three_nodes, four_nodes, four_nodes), [0, ] * 6 + [1, 2, 1, 2, 0, 0, 1, 2, 1, 2, 0, 0], field_values=(1, 2)) - yield TestCase(input_mesh=mesh, options=options, collocated_nodes=collocated_nodes, - result=TestResult(3 * 4 * 4 + 4, 2 * 3 * 3, 9, 4)) + mesh, options = __build_test_case( ( three_nodes, four_nodes, four_nodes ), [ + 0, + ] * 6 + [ 1, 2, 1, 2, 0, 0, 1, 2, 1, 2, 0, 0 ], + field_values=( 1, 2 ) ) + yield TestCase( input_mesh=mesh, + options=options, + collocated_nodes=collocated_nodes, + result=TestResult( 3 * 4 * 4 + 4, 2 * 3 * 3, 9, 4 ) ) # Generate mesh with 2 hexs, one being a standard hex, the other a 42 hex. - inc = Incrementor(3 * 2 * 2) - collocated_nodes: Sequence[Sequence[int]] = ( - (1, *inc.next(1)), - (1 + 3, *inc.next(1)), - (1 + 6, *inc.next(1)), - (1 + 9, *inc.next(1)), + inc = Incrementor( 3 * 2 * 2 ) + collocated_nodes: Sequence[ Sequence[ int ] ] = ( + ( 1, *inc.next( 1 ) ), + ( 1 + 3, *inc.next( 1 ) ), + ( 1 + 6, *inc.next( 1 ) ), + ( 1 + 9, *inc.next( 1 ) ), ) - mesh, options = __build_test_case((three_nodes, two_nodes, two_nodes), (0, 1)) + mesh, options = __build_test_case( ( three_nodes, two_nodes, two_nodes ), ( 0, 1 ) ) polyhedron_mesh = vtkUnstructuredGrid() - polyhedron_mesh.SetPoints(mesh.GetPoints()) - polyhedron_mesh.Allocate(2) - polyhedron_mesh.InsertNextCell(VTK_HEXAHEDRON, to_vtk_id_list((1, 2, 5, 4, 7, 8, 10, 11))) - poly = to_vtk_id_list([6] + [4, 0, 1, 7, 6] + [4, 1, 4, 10, 7] + [4, 4, 3, 9, 10] + [4, 3, 0, 6, 9] + [4, 6, 7, 10, 9] + [4, 1, 0, 3, 4]) - polyhedron_mesh.InsertNextCell(VTK_POLYHEDRON, poly) - polyhedron_mesh.GetCellData().AddArray(mesh.GetCellData().GetArray("attribute")) - - yield TestCase(input_mesh=polyhedron_mesh, options=options, collocated_nodes=collocated_nodes, - result=TestResult(4 * 4, 2, 4, 1)) + polyhedron_mesh.SetPoints( mesh.GetPoints() ) + polyhedron_mesh.Allocate( 2 ) + polyhedron_mesh.InsertNextCell( VTK_HEXAHEDRON, to_vtk_id_list( ( 1, 2, 5, 4, 7, 8, 10, 11 ) ) ) + poly = to_vtk_id_list( [ 6 ] + [ 4, 0, 1, 7, 6 ] + [ 4, 1, 4, 10, 7 ] + [ 4, 4, 3, 9, 10 ] + [ 4, 3, 0, 6, 9 ] + + [ 4, 6, 7, 10, 9 ] + [ 4, 1, 0, 3, 4 ] ) + polyhedron_mesh.InsertNextCell( VTK_POLYHEDRON, poly ) + polyhedron_mesh.GetCellData().AddArray( mesh.GetCellData().GetArray( "attribute" ) ) + + yield TestCase( input_mesh=polyhedron_mesh, + options=options, + collocated_nodes=collocated_nodes, + result=TestResult( 4 * 4, 2, 4, 1 ) ) # Split in 2 using the internal fracture description - inc = Incrementor(3 * 2 * 2) - collocated_nodes: Sequence[Sequence[int]] = ( - (1, *inc.next(1)), - (1 + 3, *inc.next(1)), - (1 + 6, *inc.next(1)), - (1 + 9, *inc.next(1)), + inc = Incrementor( 3 * 2 * 2 ) + collocated_nodes: Sequence[ Sequence[ int ] ] = ( + ( 1, *inc.next( 1 ) ), + ( 1 + 3, *inc.next( 1 ) ), + ( 1 + 6, *inc.next( 1 ) ), + ( 1 + 9, *inc.next( 1 ) ), ) - mesh, options = __build_test_case((three_nodes, two_nodes, two_nodes), attribute=(0, 0, 0), field_values=(0,), - policy=FracturePolicy.INTERNAL_SURFACES) - mesh.InsertNextCell(VTK_QUAD, to_vtk_id_list((1, 4, 7, 10))) # Add a fracture on the fly - yield TestCase(input_mesh=mesh, options=options, - collocated_nodes=collocated_nodes, - result=TestResult(4 * 4, 3, 4, 1)) - - -@pytest.mark.parametrize("test_case", __generate_test_data()) -def test_generate_fracture(test_case: TestCase): - main_mesh, fracture_mesh = __split_mesh_on_fracture(test_case.input_mesh, test_case.options) + mesh, options = __build_test_case( ( three_nodes, two_nodes, two_nodes ), + attribute=( 0, 0, 0 ), + field_values=( 0, ), + policy=FracturePolicy.INTERNAL_SURFACES ) + mesh.InsertNextCell( VTK_QUAD, to_vtk_id_list( ( 1, 4, 7, 10 ) ) ) # Add a fracture on the fly + yield TestCase( input_mesh=mesh, + options=options, + collocated_nodes=collocated_nodes, + result=TestResult( 4 * 4, 3, 4, 1 ) ) + + +@pytest.mark.parametrize( "test_case", __generate_test_data() ) +def test_generate_fracture( test_case: TestCase ): + main_mesh, fracture_mesh = __split_mesh_on_fracture( test_case.input_mesh, test_case.options ) assert main_mesh.GetNumberOfPoints() == test_case.result.main_mesh_num_points assert main_mesh.GetNumberOfCells() == test_case.result.main_mesh_num_cells assert fracture_mesh.GetNumberOfPoints() == test_case.result.fracture_mesh_num_points assert fracture_mesh.GetNumberOfCells() == test_case.result.fracture_mesh_num_cells - res = format_collocated_nodes(fracture_mesh) + res = format_collocated_nodes( fracture_mesh ) assert res == test_case.collocated_nodes - assert len(res) == test_case.result.fracture_mesh_num_points + assert len( res ) == test_case.result.fracture_mesh_num_points diff --git a/geosx_mesh_doctor/tests/test_generate_global_ids.py b/geosx_mesh_doctor/tests/test_generate_global_ids.py index 5dc7c1b..f2998d7 100644 --- a/geosx_mesh_doctor/tests/test_generate_global_ids.py +++ b/geosx_mesh_doctor/tests/test_generate_global_ids.py @@ -12,19 +12,19 @@ def test_generate_global_ids(): points = vtkPoints() - points.InsertNextPoint(0, 0, 0) + points.InsertNextPoint( 0, 0, 0 ) vertex = vtkVertex() - vertex.GetPointIds().SetId(0, 0) + vertex.GetPointIds().SetId( 0, 0 ) vertices = vtkCellArray() - vertices.InsertNextCell(vertex) + vertices.InsertNextCell( vertex ) mesh = vtkUnstructuredGrid() - mesh.SetPoints(points) - mesh.SetCells([VTK_VERTEX], vertices) + mesh.SetPoints( points ) + mesh.SetCells( [ VTK_VERTEX ], vertices ) - __build_global_ids(mesh, True, True) + __build_global_ids( mesh, True, True ) global_cell_ids = mesh.GetCellData().GetGlobalIds() global_point_ids = mesh.GetPointData().GetGlobalIds() diff --git a/geosx_mesh_doctor/tests/test_non_conformal.py b/geosx_mesh_doctor/tests/test_non_conformal.py index bcf60fe..0351ae9 100644 --- a/geosx_mesh_doctor/tests/test_non_conformal.py +++ b/geosx_mesh_doctor/tests/test_non_conformal.py @@ -9,59 +9,59 @@ def test_two_close_hexs(): delta = 1.e-6 - tmp = numpy.arange(2, dtype=float) - xyz0 = XYZ(tmp, tmp, tmp) - xyz1 = XYZ(tmp + 1 + delta, tmp, tmp) - mesh = build_rectilinear_blocks_mesh((xyz0, xyz1)) + tmp = numpy.arange( 2, dtype=float ) + xyz0 = XYZ( tmp, tmp, tmp ) + xyz1 = XYZ( tmp + 1 + delta, tmp, tmp ) + mesh = build_rectilinear_blocks_mesh( ( xyz0, xyz1 ) ) # Close enough, but points tolerance is too strict to consider the faces matching. - options = Options(angle_tolerance=1., point_tolerance=delta / 2, face_tolerance=delta * 2) - results = __check(mesh, options) - assert len(results.non_conformal_cells) == 1 - assert set(results.non_conformal_cells[0]) == {0, 1} + options = Options( angle_tolerance=1., point_tolerance=delta / 2, face_tolerance=delta * 2 ) + results = __check( mesh, options ) + assert len( results.non_conformal_cells ) == 1 + assert set( results.non_conformal_cells[ 0 ] ) == { 0, 1 } # Close enough, and points tolerance is loose enough to consider the faces matching. - options = Options(angle_tolerance=1., point_tolerance=delta * 2, face_tolerance=delta * 2) - results = __check(mesh, options) - assert len(results.non_conformal_cells) == 0 + options = Options( angle_tolerance=1., point_tolerance=delta * 2, face_tolerance=delta * 2 ) + results = __check( mesh, options ) + assert len( results.non_conformal_cells ) == 0 def test_two_distant_hexs(): delta = 1 - tmp = numpy.arange(2, dtype=float) - xyz0 = XYZ(tmp, tmp, tmp) - xyz1 = XYZ(tmp + 1 + delta, tmp, tmp) - mesh = build_rectilinear_blocks_mesh((xyz0, xyz1)) + tmp = numpy.arange( 2, dtype=float ) + xyz0 = XYZ( tmp, tmp, tmp ) + xyz1 = XYZ( tmp + 1 + delta, tmp, tmp ) + mesh = build_rectilinear_blocks_mesh( ( xyz0, xyz1 ) ) - options = Options(angle_tolerance=1., point_tolerance=delta / 2., face_tolerance=delta / 2.) + options = Options( angle_tolerance=1., point_tolerance=delta / 2., face_tolerance=delta / 2. ) - results = __check(mesh, options) - assert len(results.non_conformal_cells) == 0 + results = __check( mesh, options ) + assert len( results.non_conformal_cells ) == 0 def test_two_close_shifted_hexs(): delta_x, delta_y = 1.e-6, 0.5 - tmp = numpy.arange(2, dtype=float) - xyz0 = XYZ(tmp, tmp, tmp) - xyz1 = XYZ(tmp + 1 + delta_x, tmp + delta_y, tmp + delta_y) - mesh = build_rectilinear_blocks_mesh((xyz0, xyz1)) + tmp = numpy.arange( 2, dtype=float ) + xyz0 = XYZ( tmp, tmp, tmp ) + xyz1 = XYZ( tmp + 1 + delta_x, tmp + delta_y, tmp + delta_y ) + mesh = build_rectilinear_blocks_mesh( ( xyz0, xyz1 ) ) - options = Options(angle_tolerance=1., point_tolerance=delta_x * 2, face_tolerance=delta_x * 2) + options = Options( angle_tolerance=1., point_tolerance=delta_x * 2, face_tolerance=delta_x * 2 ) - results = __check(mesh, options) - assert len(results.non_conformal_cells) == 1 - assert set(results.non_conformal_cells[0]) == {0, 1} + results = __check( mesh, options ) + assert len( results.non_conformal_cells ) == 1 + assert set( results.non_conformal_cells[ 0 ] ) == { 0, 1 } def test_big_elem_next_to_small_elem(): delta = 1.e-6 - tmp = numpy.arange(2, dtype=float) - xyz0 = XYZ(tmp, tmp + 1, tmp + 1) - xyz1 = XYZ(3 * tmp + 1 + delta, 3 * tmp, 3 * tmp) - mesh = build_rectilinear_blocks_mesh((xyz0, xyz1)) + tmp = numpy.arange( 2, dtype=float ) + xyz0 = XYZ( tmp, tmp + 1, tmp + 1 ) + xyz1 = XYZ( 3 * tmp + 1 + delta, 3 * tmp, 3 * tmp ) + mesh = build_rectilinear_blocks_mesh( ( xyz0, xyz1 ) ) - options = Options(angle_tolerance=1., point_tolerance=delta * 2, face_tolerance=delta * 2) + options = Options( angle_tolerance=1., point_tolerance=delta * 2, face_tolerance=delta * 2 ) - results = __check(mesh, options) - assert len(results.non_conformal_cells) == 1 - assert set(results.non_conformal_cells[0]) == {0, 1} + results = __check( mesh, options ) + assert len( results.non_conformal_cells ) == 1 + assert set( results.non_conformal_cells[ 0 ] ) == { 0, 1 } diff --git a/geosx_mesh_doctor/tests/test_reorient_mesh.py b/geosx_mesh_doctor/tests/test_reorient_mesh.py index 1136bbb..f0ffd29 100644 --- a/geosx_mesh_doctor/tests/test_reorient_mesh.py +++ b/geosx_mesh_doctor/tests/test_reorient_mesh.py @@ -22,88 +22,81 @@ ) -@dataclass(frozen=True) +@dataclass( frozen=True ) class Expected: mesh: vtkUnstructuredGrid face_stream: FaceStream -def __build_test_meshes() -> Generator[Expected, None, None]: +def __build_test_meshes() -> Generator[ Expected, None, None ]: # Creating the support nodes for the polyhedron. # It has a C shape and is actually non-convex, non star-shaped. - front_nodes = numpy.array(( - (0, 0, 0), - (3, 0, 0), - (3, 1, 0), - (1, 1, 0), - (1, 2, 0), - (3, 2, 0), - (3, 3, 0), - (0, 3, 0), - ), dtype=float) - front_nodes = numpy.array(front_nodes, dtype=float) - back_nodes = front_nodes - (0., 0., 1.) - - n = len(front_nodes) + front_nodes = numpy.array( ( + ( 0, 0, 0 ), + ( 3, 0, 0 ), + ( 3, 1, 0 ), + ( 1, 1, 0 ), + ( 1, 2, 0 ), + ( 3, 2, 0 ), + ( 3, 3, 0 ), + ( 0, 3, 0 ), + ), + dtype=float ) + front_nodes = numpy.array( front_nodes, dtype=float ) + back_nodes = front_nodes - ( 0., 0., 1. ) + + n = len( front_nodes ) points = vtkPoints() - points.Allocate(2 * n) + points.Allocate( 2 * n ) for coords in front_nodes: - points.InsertNextPoint(coords) + points.InsertNextPoint( coords ) for coords in back_nodes: - points.InsertNextPoint(coords) + points.InsertNextPoint( coords ) # Creating the polyhedron with faces all directed outward. faces = [] # Creating the side faces - for i in range(n): - faces.append( - (i % n + n, (i + 1) % n + n, (i + 1) % n, i % n) - ) + for i in range( n ): + faces.append( ( i % n + n, ( i + 1 ) % n + n, ( i + 1 ) % n, i % n ) ) # Creating the front faces - faces.append(tuple(range(n))) - faces.append(tuple(reversed(range(n, 2 * n)))) - face_stream = FaceStream(faces) + faces.append( tuple( range( n ) ) ) + faces.append( tuple( reversed( range( n, 2 * n ) ) ) ) + face_stream = FaceStream( faces ) # Creating multiple meshes, each time with one unique polyhedron, # but with different "face flip status". # First case, no face is flipped. mesh = vtkUnstructuredGrid() - mesh.Allocate(1) - mesh.SetPoints(points) - mesh.InsertNextCell(VTK_POLYHEDRON, to_vtk_id_list( - face_stream.dump() - )) - yield Expected(mesh=mesh, face_stream=face_stream) + mesh.Allocate( 1 ) + mesh.SetPoints( points ) + mesh.InsertNextCell( VTK_POLYHEDRON, to_vtk_id_list( face_stream.dump() ) ) + yield Expected( mesh=mesh, face_stream=face_stream ) # Here, two faces are flipped. mesh = vtkUnstructuredGrid() - mesh.Allocate(1) - mesh.SetPoints(points) - mesh.InsertNextCell(VTK_POLYHEDRON, to_vtk_id_list( - face_stream.flip_faces((1, 2)).dump() - )) - yield Expected(mesh=mesh, face_stream=face_stream) + mesh.Allocate( 1 ) + mesh.SetPoints( points ) + mesh.InsertNextCell( VTK_POLYHEDRON, to_vtk_id_list( face_stream.flip_faces( ( 1, 2 ) ).dump() ) ) + yield Expected( mesh=mesh, face_stream=face_stream ) # Last, all faces are flipped. mesh = vtkUnstructuredGrid() - mesh.Allocate(1) - mesh.SetPoints(points) - mesh.InsertNextCell(VTK_POLYHEDRON, to_vtk_id_list( - face_stream.flip_faces(range(len(faces))).dump() - )) - yield Expected(mesh=mesh, face_stream=face_stream) + mesh.Allocate( 1 ) + mesh.SetPoints( points ) + mesh.InsertNextCell( VTK_POLYHEDRON, to_vtk_id_list( face_stream.flip_faces( range( len( faces ) ) ).dump() ) ) + yield Expected( mesh=mesh, face_stream=face_stream ) -@pytest.mark.parametrize("expected", __build_test_meshes()) -def test_reorient_polyhedron(expected: Expected): - output_mesh = reorient_mesh(expected.mesh, range(expected.mesh.GetNumberOfCells())) +@pytest.mark.parametrize( "expected", __build_test_meshes() ) +def test_reorient_polyhedron( expected: Expected ): + output_mesh = reorient_mesh( expected.mesh, range( expected.mesh.GetNumberOfCells() ) ) assert output_mesh.GetNumberOfCells() == 1 - assert output_mesh.GetCell(0).GetCellType() == VTK_POLYHEDRON + assert output_mesh.GetCell( 0 ).GetCellType() == VTK_POLYHEDRON face_stream_ids = vtkIdList() - output_mesh.GetFaceStream(0, face_stream_ids) + output_mesh.GetFaceStream( 0, face_stream_ids ) # Note that the following makes a raw (but simple) check. # But one may need to be more precise some day, # since triangular faces (0, 1, 2) and (1, 2, 0) should be considered as equivalent. # And the current simpler check does not consider this case. - assert tuple(vtk_iter(face_stream_ids)) == expected.face_stream.dump() + assert tuple( vtk_iter( face_stream_ids ) ) == expected.face_stream.dump() diff --git a/geosx_mesh_doctor/tests/test_self_intersecting_elements.py b/geosx_mesh_doctor/tests/test_self_intersecting_elements.py index 8993e68..2b5455a 100644 --- a/geosx_mesh_doctor/tests/test_self_intersecting_elements.py +++ b/geosx_mesh_doctor/tests/test_self_intersecting_elements.py @@ -1,6 +1,5 @@ from vtkmodules.vtkCommonCore import ( - vtkPoints, -) + vtkPoints, ) from vtkmodules.vtkCommonDataModel import ( VTK_HEXAHEDRON, vtkCellArray, @@ -8,43 +7,42 @@ vtkUnstructuredGrid, ) - from checks.self_intersecting_elements import Options, __check def test_jumbled_hex(): # creating a simple hexahedron points = vtkPoints() - points.SetNumberOfPoints(8) - points.SetPoint(0, (0, 0, 0)) - points.SetPoint(1, (1, 0, 0)) - points.SetPoint(2, (1, 1, 0)) - points.SetPoint(3, (0, 1, 0)) - points.SetPoint(4, (0, 0, 1)) - points.SetPoint(5, (1, 0, 1)) - points.SetPoint(6, (1, 1, 1)) - points.SetPoint(7, (0, 1, 1)) - - cell_types = [VTK_HEXAHEDRON] + points.SetNumberOfPoints( 8 ) + points.SetPoint( 0, ( 0, 0, 0 ) ) + points.SetPoint( 1, ( 1, 0, 0 ) ) + points.SetPoint( 2, ( 1, 1, 0 ) ) + points.SetPoint( 3, ( 0, 1, 0 ) ) + points.SetPoint( 4, ( 0, 0, 1 ) ) + points.SetPoint( 5, ( 1, 0, 1 ) ) + points.SetPoint( 6, ( 1, 1, 1 ) ) + points.SetPoint( 7, ( 0, 1, 1 ) ) + + cell_types = [ VTK_HEXAHEDRON ] cells = vtkCellArray() - cells.AllocateExact(1, 8) + cells.AllocateExact( 1, 8 ) hex = vtkHexahedron() - hex.GetPointIds().SetId(0, 0) - hex.GetPointIds().SetId(1, 1) - hex.GetPointIds().SetId(2, 3) # Intentionally wrong - hex.GetPointIds().SetId(3, 2) # Intentionally wrong - hex.GetPointIds().SetId(4, 4) - hex.GetPointIds().SetId(5, 5) - hex.GetPointIds().SetId(6, 6) - hex.GetPointIds().SetId(7, 7) - cells.InsertNextCell(hex) + hex.GetPointIds().SetId( 0, 0 ) + hex.GetPointIds().SetId( 1, 1 ) + hex.GetPointIds().SetId( 2, 3 ) # Intentionally wrong + hex.GetPointIds().SetId( 3, 2 ) # Intentionally wrong + hex.GetPointIds().SetId( 4, 4 ) + hex.GetPointIds().SetId( 5, 5 ) + hex.GetPointIds().SetId( 6, 6 ) + hex.GetPointIds().SetId( 7, 7 ) + cells.InsertNextCell( hex ) mesh = vtkUnstructuredGrid() - mesh.SetPoints(points) - mesh.SetCells(cell_types, cells) + mesh.SetPoints( points ) + mesh.SetCells( cell_types, cells ) - result = __check(mesh, Options(tolerance=0.)) + result = __check( mesh, Options( tolerance=0. ) ) - assert len(result.intersecting_faces_elements) == 1 - assert result.intersecting_faces_elements[0] == 0 + assert len( result.intersecting_faces_elements ) == 1 + assert result.intersecting_faces_elements[ 0 ] == 0 diff --git a/geosx_mesh_doctor/tests/test_supported_elements.py b/geosx_mesh_doctor/tests/test_supported_elements.py index 639d904..9d56932 100644 --- a/geosx_mesh_doctor/tests/test_supported_elements.py +++ b/geosx_mesh_doctor/tests/test_supported_elements.py @@ -15,31 +15,30 @@ from checks.supported_elements import Options, check, __check from checks.vtk_polyhedron import parse_face_stream, build_face_to_face_connectivity_through_edges, FaceStream from checks.vtk_utils import ( - to_vtk_id_list, -) + to_vtk_id_list, ) -@pytest.mark.parametrize("base_name", - ("supportedElements.vtk", "supportedElementsAsVTKPolyhedra.vtk")) -def test_supported_elements(base_name) -> None: +@pytest.mark.parametrize( "base_name", ( "supportedElements.vtk", "supportedElementsAsVTKPolyhedra.vtk" ) ) +def test_supported_elements( base_name ) -> None: """ Testing that the supported elements are properly detected as supported! :param base_name: Supported elements are provided as standard elements or polyhedron elements. """ - directory = os.path.dirname(os.path.realpath(__file__)) - supported_elements_file_name = os.path.join(directory, "../../../../unitTests/meshTests", base_name) - options = Options(chunk_size=1, num_proc=4) - result = check(supported_elements_file_name, options) + directory = os.path.dirname( os.path.realpath( __file__ ) ) + supported_elements_file_name = os.path.join( directory, "../../../../unitTests/meshTests", base_name ) + options = Options( chunk_size=1, num_proc=4 ) + result = check( supported_elements_file_name, options ) assert not result.unsupported_std_elements_types assert not result.unsupported_polyhedron_elements -def make_dodecahedron() -> Tuple[vtkPoints, vtkIdList]: +def make_dodecahedron() -> Tuple[ vtkPoints, vtkIdList ]: """ Returns the points and faces for a dodecahedron. This code was adapted from an official vtk example. :return: The tuple of points and faces (as vtk instances). """ + # yapf: disable points = ( (1.21412, 0, 1.58931), (0.375185, 1.1547, 1.58931), @@ -76,13 +75,14 @@ def make_dodecahedron() -> Tuple[vtkPoints, vtkIdList]: 5, 18, 13, 8, 12, 17, 5, 19, 14, 9, 13, 18, 5, 19, 18, 17, 16, 15) + # yapf: enable p = vtkPoints() - p.Allocate(len(points)) + p.Allocate( len( points ) ) for coords in points: - p.InsertNextPoint(coords) + p.InsertNextPoint( coords ) - f = to_vtk_id_list(faces) + f = to_vtk_id_list( faces ) return p, f @@ -93,18 +93,19 @@ def test_dodecahedron() -> None: """ points, faces = make_dodecahedron() mesh = vtkUnstructuredGrid() - mesh.Allocate(1) - mesh.SetPoints(points) - mesh.InsertNextCell(VTK_POLYHEDRON, faces) + mesh.Allocate( 1 ) + mesh.SetPoints( points ) + mesh.InsertNextCell( VTK_POLYHEDRON, faces ) - result = __check(mesh, Options(num_proc=1, chunk_size=1)) - assert set(result.unsupported_polyhedron_elements) == {0} + result = __check( mesh, Options( num_proc=1, chunk_size=1 ) ) + assert set( result.unsupported_polyhedron_elements ) == { 0 } assert not result.unsupported_std_elements_types def test_parse_face_stream() -> None: _, faces = make_dodecahedron() - result = parse_face_stream(faces) + result = parse_face_stream( faces ) + # yapf: disable expected = ( (0, 1, 2, 3, 4), (0, 5, 10, 6, 1), @@ -119,7 +120,8 @@ def test_parse_face_stream() -> None: (19, 14, 9, 13, 18), (19, 18, 17, 16, 15) ) + # yapf: enable assert result == expected - face_stream = FaceStream.build_from_vtk_id_list(faces) + face_stream = FaceStream.build_from_vtk_id_list( faces ) assert face_stream.num_faces == 12 assert face_stream.num_support_points == 20 diff --git a/geosx_mesh_doctor/tests/test_triangle_distance.py b/geosx_mesh_doctor/tests/test_triangle_distance.py index 605169b..f44b5e8 100644 --- a/geosx_mesh_doctor/tests/test_triangle_distance.py +++ b/geosx_mesh_doctor/tests/test_triangle_distance.py @@ -7,7 +7,7 @@ from checks.triangle_distance import distance_between_two_segments, distance_between_two_triangles -@dataclass(frozen=True) +@dataclass( frozen=True ) class ExpectedSeg: p0: numpy.array u0: numpy.array @@ -17,88 +17,82 @@ class ExpectedSeg: y: numpy.array @classmethod - def from_tuples(cls, p0, u0, p1, u1, x, y): - return cls( - numpy.array(p0), - numpy.array(u0), - numpy.array(p1), - numpy.array(u1), - numpy.array(x), - numpy.array(y) - ) + def from_tuples( cls, p0, u0, p1, u1, x, y ): + return cls( numpy.array( p0 ), numpy.array( u0 ), numpy.array( p1 ), numpy.array( u1 ), numpy.array( x ), + numpy.array( y ) ) def __get_segments_references(): # Node to node configuration. yield ExpectedSeg.from_tuples( - p0=(0., 0., 0.), - u0=(1., 0., 0.), - p1=(2., 0., 0.), - u1=(1., 0., 0.), - x=(1., 0., 0.), - y=(2., 0., 0.), + p0=( 0., 0., 0. ), + u0=( 1., 0., 0. ), + p1=( 2., 0., 0. ), + u1=( 1., 0., 0. ), + x=( 1., 0., 0. ), + y=( 2., 0., 0. ), ) # Node to edge configuration. yield ExpectedSeg.from_tuples( - p0=(0., 0., 0.), - u0=(1., 0., 0.), - p1=(2., -1., -1.), - u1=(0., 1., 1.), - x=(1., 0., 0.), - y=(2., 0., 0.), + p0=( 0., 0., 0. ), + u0=( 1., 0., 0. ), + p1=( 2., -1., -1. ), + u1=( 0., 1., 1. ), + x=( 1., 0., 0. ), + y=( 2., 0., 0. ), ) # Edge to edge configuration. yield ExpectedSeg.from_tuples( - p0=(0., 0., -1.), - u0=(0., 0., 2.), - p1=(1., -1., -1.), - u1=(0., 2., 2.), - x=(0., 0., 0.), - y=(1., 0., 0.), + p0=( 0., 0., -1. ), + u0=( 0., 0., 2. ), + p1=( 1., -1., -1. ), + u1=( 0., 2., 2. ), + x=( 0., 0., 0. ), + y=( 1., 0., 0. ), ) # Example from "On fast computation of distance between line segments" by Vladimir J. Lumelsky. # Information Processing Letters, Vol. 21, number 2, pages 55-61, 08/16/1985. # It's a node to edge configuration. yield ExpectedSeg.from_tuples( - p0=(0., 0., 0.), - u0=(1., 2., 1.), - p1=(1., 0., 0.), - u1=(1., 1., 0.), - x=(1./6., 2./6., 1./6.), - y=(1., 0., 0.), + p0=( 0., 0., 0. ), + u0=( 1., 2., 1. ), + p1=( 1., 0., 0. ), + u1=( 1., 1., 0. ), + x=( 1. / 6., 2. / 6., 1. / 6. ), + y=( 1., 0., 0. ), ) # Overlapping edges. yield ExpectedSeg.from_tuples( - p0=(0., 0., 0.), - u0=(2., 0., 0.), - p1=(1., 0., 0.), - u1=(2., 0., 0.), - x=(0., 0., 0.), - y=(0., 0., 0.), + p0=( 0., 0., 0. ), + u0=( 2., 0., 0. ), + p1=( 1., 0., 0. ), + u1=( 2., 0., 0. ), + x=( 0., 0., 0. ), + y=( 0., 0., 0. ), ) # Crossing edges. yield ExpectedSeg.from_tuples( - p0=(0., 0., 0.), - u0=(2., 0., 0.), - p1=(1., -1., 0.), - u1=(0., 2., 0.), - x=(0., 0., 0.), - y=(0., 0., 0.), + p0=( 0., 0., 0. ), + u0=( 2., 0., 0. ), + p1=( 1., -1., 0. ), + u1=( 0., 2., 0. ), + x=( 0., 0., 0. ), + y=( 0., 0., 0. ), ) -@pytest.mark.parametrize("expected", __get_segments_references()) -def test_segments(expected: ExpectedSeg): - eps = numpy.finfo(float).eps - x, y = distance_between_two_segments(expected.p0, expected.u0, expected.p1, expected.u1) - if norm(expected.x - expected.y) == 0: - assert norm(x - y) == 0. +@pytest.mark.parametrize( "expected", __get_segments_references() ) +def test_segments( expected: ExpectedSeg ): + eps = numpy.finfo( float ).eps + x, y = distance_between_two_segments( expected.p0, expected.u0, expected.p1, expected.u1 ) + if norm( expected.x - expected.y ) == 0: + assert norm( x - y ) == 0. else: - assert norm(expected.x - x) < eps - assert norm(expected.y - y) < eps + assert norm( expected.x - x ) < eps + assert norm( expected.y - y ) < eps -@dataclass(frozen=True) +@dataclass( frozen=True ) class ExpectedTri: t0: numpy.array t1: numpy.array @@ -107,72 +101,54 @@ class ExpectedTri: p1: numpy.array @classmethod - def from_tuples(cls, t0, t1, d, p0, p1): - return cls( - numpy.array(t0), - numpy.array(t1), - float(d), - numpy.array(p0), - numpy.array(p1) - ) + def from_tuples( cls, t0, t1, d, p0, p1 ): + return cls( numpy.array( t0 ), numpy.array( t1 ), float( d ), numpy.array( p0 ), numpy.array( p1 ) ) def __get_triangles_references(): # Node to node configuration. - yield ExpectedTri.from_tuples( - t0=((0., 0., 0.), (1., 0., 0.), (0., 1., 1.)), - t1=((2., 0., 0.), (3., 0., 0.), (2., 1., 1.)), - d=1., - p0=(1., 0., 0.), - p1=(2., 0., 0.) - ) + yield ExpectedTri.from_tuples( t0=( ( 0., 0., 0. ), ( 1., 0., 0. ), ( 0., 1., 1. ) ), + t1=( ( 2., 0., 0. ), ( 3., 0., 0. ), ( 2., 1., 1. ) ), + d=1., + p0=( 1., 0., 0. ), + p1=( 2., 0., 0. ) ) # Node to edge configuration. - yield ExpectedTri.from_tuples( - t0=((0., 0., 0.), (1., 0., 0.), (0., 1., 1.)), - t1=((2., -1., 0.), (3., 0., 0.), (2., 1., 0.)), - d=1., - p0=(1., 0., 0.), - p1=(2., 0., 0.) - ) + yield ExpectedTri.from_tuples( t0=( ( 0., 0., 0. ), ( 1., 0., 0. ), ( 0., 1., 1. ) ), + t1=( ( 2., -1., 0. ), ( 3., 0., 0. ), ( 2., 1., 0. ) ), + d=1., + p0=( 1., 0., 0. ), + p1=( 2., 0., 0. ) ) # Edge to edge configuration. - yield ExpectedTri.from_tuples( - t0=((0., 0., 0.), (1., 1., 1.), (1., -1., -1.)), - t1=((2., -1., 0.), (2., 1., 0.), (3., 0., 0.)), - d=1., - p0=(1., 0., 0.), - p1=(2., 0., 0.) - ) + yield ExpectedTri.from_tuples( t0=( ( 0., 0., 0. ), ( 1., 1., 1. ), ( 1., -1., -1. ) ), + t1=( ( 2., -1., 0. ), ( 2., 1., 0. ), ( 3., 0., 0. ) ), + d=1., + p0=( 1., 0., 0. ), + p1=( 2., 0., 0. ) ) # Point to face configuration. - yield ExpectedTri.from_tuples( - t0=((0., 0., 0.), (1., 0., 0.), (0., 1., 1.)), - t1=((2., -1., 0.), (2., 1., -1.), (2, 1., 1.)), - d=1., - p0=(1., 0., 0.), - p1=(2., 0., 0.) - ) + yield ExpectedTri.from_tuples( t0=( ( 0., 0., 0. ), ( 1., 0., 0. ), ( 0., 1., 1. ) ), + t1=( ( 2., -1., 0. ), ( 2., 1., -1. ), ( 2, 1., 1. ) ), + d=1., + p0=( 1., 0., 0. ), + p1=( 2., 0., 0. ) ) # Same triangles configuration. - yield ExpectedTri.from_tuples( - t0=((0., 0., 0.), (1., 0., 0.), (0., 1., 1.)), - t1=((0., 0., 0.), (1., 0., 0.), (0., 1., 1.)), - d=0., - p0=(0., 0., 0.), - p1=(0., 0., 0.) - ) + yield ExpectedTri.from_tuples( t0=( ( 0., 0., 0. ), ( 1., 0., 0. ), ( 0., 1., 1. ) ), + t1=( ( 0., 0., 0. ), ( 1., 0., 0. ), ( 0., 1., 1. ) ), + d=0., + p0=( 0., 0., 0. ), + p1=( 0., 0., 0. ) ) # Crossing triangles configuration. - yield ExpectedTri.from_tuples( - t0=((0., 0., 0.), (2., 0., 0.), (2., 0., 1.)), - t1=((1., -1., 0.), (1., 1., 0.), (1., 1., 1.)), - d=0., - p0=(0., 0., 0.), - p1=(0., 0., 0.) - ) - - -@pytest.mark.parametrize("expected", __get_triangles_references()) -def test_triangles(expected: ExpectedTri): - eps = numpy.finfo(float).eps - d, p0, p1 = distance_between_two_triangles(expected.t0, expected.t1) - assert abs(d - expected.d) < eps + yield ExpectedTri.from_tuples( t0=( ( 0., 0., 0. ), ( 2., 0., 0. ), ( 2., 0., 1. ) ), + t1=( ( 1., -1., 0. ), ( 1., 1., 0. ), ( 1., 1., 1. ) ), + d=0., + p0=( 0., 0., 0. ), + p1=( 0., 0., 0. ) ) + + +@pytest.mark.parametrize( "expected", __get_triangles_references() ) +def test_triangles( expected: ExpectedTri ): + eps = numpy.finfo( float ).eps + d, p0, p1 = distance_between_two_triangles( expected.t0, expected.t1 ) + assert abs( d - expected.d ) < eps if d != 0: - assert norm(p0 - expected.p0) < eps - assert norm(p1 - expected.p1) < eps + assert norm( p0 - expected.p0 ) < eps + assert norm( p1 - expected.p1 ) < eps diff --git a/geosx_mesh_tools_package/geosx_mesh_tools/abaqus_converter.py b/geosx_mesh_tools_package/geosx_mesh_tools/abaqus_converter.py index 14f6299..9c5e162 100644 --- a/geosx_mesh_tools_package/geosx_mesh_tools/abaqus_converter.py +++ b/geosx_mesh_tools_package/geosx_mesh_tools/abaqus_converter.py @@ -1,10 +1,10 @@ -import meshio # type: ignore[import] -from meshio._mesh import CellBlock # type: ignore[import] +import meshio # type: ignore[import] +from meshio._mesh import CellBlock # type: ignore[import] import numpy as np import logging -def convert_abaqus_to_gmsh(input_mesh: str, output_mesh: str, logger: logging.Logger = None) -> int: +def convert_abaqus_to_gmsh( input_mesh: str, output_mesh: str, logger: logging.Logger = None ) -> int: """ Convert an abaqus mesh to gmsh 2 format, preserving nodeset information. @@ -22,112 +22,113 @@ def convert_abaqus_to_gmsh(input_mesh: str, output_mesh: str, logger: logging.Lo """ # Initialize the logger if it is empty if not logger: - logging.basicConfig(level=logging.WARNING) - logger = logging.getLogger(__name__) + logging.basicConfig( level=logging.WARNING ) + logger = logging.getLogger( __name__ ) # Keep track of the number of warnings n_warnings = 0 # Load the mesh - logger.info('Reading abaqus mesh...') - mesh = meshio.read(input_mesh, file_format="abaqus") + logger.info( 'Reading abaqus mesh...' ) + mesh = meshio.read( input_mesh, file_format="abaqus" ) # Convert the element regions to tags - logger.info('Converting region tags...') - region_list = list(mesh.cell_sets.keys()) - n_regions = len(region_list) + logger.info( 'Converting region tags...' ) + region_list = list( mesh.cell_sets.keys() ) + n_regions = len( region_list ) cell_ids = [] - for block_id, block in enumerate(mesh.cells): - cell_ids.append(np.zeros(len(block[1]), dtype=int) - 1) - for region_id, region in enumerate(region_list): - mesh.field_data[region] = [region_id + 1, 3] - cell_ids[block_id][mesh.cell_sets[region][block_id]] = region_id + 1 + for block_id, block in enumerate( mesh.cells ): + cell_ids.append( np.zeros( len( block[ 1 ] ), dtype=int ) - 1 ) + for region_id, region in enumerate( region_list ): + mesh.field_data[ region ] = [ region_id + 1, 3 ] + cell_ids[ block_id ][ mesh.cell_sets[ region ][ block_id ] ] = region_id + 1 # Check for bad element region conversions - if (-1 in cell_ids[-1]): - logger.warning('Some element regions in block %i did not convert correctly to tags!' % (block_id)) - logger.warning('Note: These will be indicated by a -1 in the output file.') + if ( -1 in cell_ids[ -1 ] ): + logger.warning( 'Some element regions in block %i did not convert correctly to tags!' % ( block_id ) ) + logger.warning( 'Note: These will be indicated by a -1 in the output file.' ) n_warnings += 1 # Add to the meshio datastructure # Note: the copy here is required, so that later appends # do not break these dicts - mesh.cell_data['gmsh:physical'] = cell_ids.copy() - mesh.cell_data['gmsh:geometrical'] = cell_ids.copy() + mesh.cell_data[ 'gmsh:physical' ] = cell_ids.copy() + mesh.cell_data[ 'gmsh:geometrical' ] = cell_ids.copy() # Build the face elements - logger.info('Converting nodesets to face elements, tags...') + logger.info( 'Converting nodesets to face elements, tags...' ) new_tris, tri_nodeset, tri_region = [], [], [] new_quads, quad_nodeset, quad_region = [], [], [] - for nodeset_id, nodeset_name in enumerate(mesh.point_sets): - logger.info(' %s' % (nodeset_name)) - mesh.field_data[nodeset_name] = [nodeset_id + n_regions + 1, 2] - nodeset = mesh.point_sets[nodeset_name] + for nodeset_id, nodeset_name in enumerate( mesh.point_sets ): + logger.info( ' %s' % ( nodeset_name ) ) + mesh.field_data[ nodeset_name ] = [ nodeset_id + n_regions + 1, 2 ] + nodeset = mesh.point_sets[ nodeset_name ] # Search by block, then element - for block_id, block in enumerate(mesh.cells): - for element_id, element in enumerate(block[1]): + for block_id, block in enumerate( mesh.cells ): + for element_id, element in enumerate( block[ 1 ] ): # Find any matching nodes - matching_nodes = [x for x in element if x in nodeset] + matching_nodes = [ x for x in element if x in nodeset ] # Add a new face element if there are enough nodes - n_matching = len(matching_nodes) - if (n_matching >= 3): + n_matching = len( matching_nodes ) + if ( n_matching >= 3 ): # Find the region region_id = -1 for region in region_list: - if (element_id in mesh.cell_sets[region][block_id]): - region_id = mesh.field_data[region][block_id] + if ( element_id in mesh.cell_sets[ region ][ block_id ] ): + region_id = mesh.field_data[ region ][ block_id ] # Test to see if the element is a quad or triangle - tag_id = mesh.field_data[nodeset_name][0] - if (n_matching == 3): - new_tris.append(matching_nodes) - tri_nodeset.append(tag_id) - tri_region.append(region_id) + tag_id = mesh.field_data[ nodeset_name ][ 0 ] + if ( n_matching == 3 ): + new_tris.append( matching_nodes ) + tri_nodeset.append( tag_id ) + tri_region.append( region_id ) - elif (n_matching == 4): - new_quads.append(matching_nodes) - quad_nodeset.append(tag_id) - quad_region.append(region_id) + elif ( n_matching == 4 ): + new_quads.append( matching_nodes ) + quad_nodeset.append( tag_id ) + quad_region.append( region_id ) else: - logger.warning(' Discarding an element with an unexpected number of nodes') - logger.warning(' n_nodes=%i, element=%i, set=%s' % (n_matching, element_id, nodeset_name)) + logger.warning( ' Discarding an element with an unexpected number of nodes' ) + logger.warning( ' n_nodes=%i, element=%i, set=%s' % + ( n_matching, element_id, nodeset_name ) ) n_warnings += 1 # Add new tris if new_tris: - logger.info(' Adding %i new triangles...' % (len(new_tris))) - if (-1 in tri_region): - logger.warning('Triangles with empty region information found!') - logger.warning('Note: These will be indicated by a -1 in the output file.') + logger.info( ' Adding %i new triangles...' % ( len( new_tris ) ) ) + if ( -1 in tri_region ): + logger.warning( 'Triangles with empty region information found!' ) + logger.warning( 'Note: These will be indicated by a -1 in the output file.' ) n_warnings += 1 - mesh.cells.append(CellBlock('triangle', np.array(new_tris))) - mesh.cell_data['gmsh:geometrical'].append(np.array(tri_region)) - mesh.cell_data['gmsh:physical'].append(np.array(tri_nodeset)) + mesh.cells.append( CellBlock( 'triangle', np.array( new_tris ) ) ) + mesh.cell_data[ 'gmsh:geometrical' ].append( np.array( tri_region ) ) + mesh.cell_data[ 'gmsh:physical' ].append( np.array( tri_nodeset ) ) # Add new quads if new_quads: - logger.info(' Adding %i new quads...' % (len(new_quads))) - if (-1 in quad_region): - logger.warning('Quads with empty region information found!') - logger.warning('Note: These will be indicated by a -1 in the output file.') + logger.info( ' Adding %i new quads...' % ( len( new_quads ) ) ) + if ( -1 in quad_region ): + logger.warning( 'Quads with empty region information found!' ) + logger.warning( 'Note: These will be indicated by a -1 in the output file.' ) n_warnings += 1 - mesh.cells.append(CellBlock('quad', np.array(new_quads))) - mesh.cell_data['gmsh:geometrical'].append(np.array(quad_region)) - mesh.cell_data['gmsh:physical'].append(np.array(quad_nodeset)) + mesh.cells.append( CellBlock( 'quad', np.array( new_quads ) ) ) + mesh.cell_data[ 'gmsh:geometrical' ].append( np.array( quad_region ) ) + mesh.cell_data[ 'gmsh:physical' ].append( np.array( quad_nodeset ) ) # Write the final mesh - logger.info('Writing gmsh mesh...') - meshio.write(output_mesh, mesh, file_format="gmsh22", binary=False) - logger.info('Done!') + logger.info( 'Writing gmsh mesh...' ) + meshio.write( output_mesh, mesh, file_format="gmsh22", binary=False ) + logger.info( 'Done!' ) - return (n_warnings > 0) + return ( n_warnings > 0 ) -def convert_abaqus_to_vtu(input_mesh: str, output_mesh: str, logger: logging.Logger = None) -> int: +def convert_abaqus_to_vtu( input_mesh: str, output_mesh: str, logger: logging.Logger = None ) -> int: """ Convert an abaqus mesh to vtu format, preserving nodeset information. @@ -144,27 +145,27 @@ def convert_abaqus_to_vtu(input_mesh: str, output_mesh: str, logger: logging.Log """ # Initialize the logger if it is empty if not logger: - logging.basicConfig(level=logging.WARNING) - logger = logging.getLogger(__name__) + logging.basicConfig( level=logging.WARNING ) + logger = logging.getLogger( __name__ ) # Keep track of the number of warnings n_warnings = 0 # Load the mesh - logger.info('Reading abaqus mesh...') - mesh = meshio.read(input_mesh, file_format="abaqus") + logger.info( 'Reading abaqus mesh...' ) + mesh = meshio.read( input_mesh, file_format="abaqus" ) # Converting nodesets to binary masks for k, nodeset in mesh.point_sets.items(): - mesh.point_data[k] = np.zeros(len(mesh.points), dtype=int) - mesh.point_data[k][nodeset] = 1 + mesh.point_data[ k ] = np.zeros( len( mesh.points ), dtype=int ) + mesh.point_data[ k ][ nodeset ] = 1 # Overwrite point sets to suppress conversion warnings mesh.point_sets = {} # Write the final mesh - logger.info('Writing vtu mesh...') - meshio.write(output_mesh, mesh, file_format="vtu") - logger.info('Done!') + logger.info( 'Writing vtu mesh...' ) + meshio.write( output_mesh, mesh, file_format="vtu" ) + logger.info( 'Done!' ) - return (n_warnings > 0) + return ( n_warnings > 0 ) diff --git a/geosx_mesh_tools_package/geosx_mesh_tools/main.py b/geosx_mesh_tools_package/geosx_mesh_tools/main.py index 1637d07..90b048f 100644 --- a/geosx_mesh_tools_package/geosx_mesh_tools/main.py +++ b/geosx_mesh_tools_package/geosx_mesh_tools/main.py @@ -10,9 +10,9 @@ def build_abaqus_converter_input_parser() -> argparse.ArgumentParser: argparse.ArgumentParser: a parser instance """ parser = argparse.ArgumentParser() - parser.add_argument('input', type=str, help='Input abaqus mesh file name') - parser.add_argument('output', type=str, help='Output gmsh/vtu mesh file name') - parser.add_argument('-v', '--verbose', help='Increase verbosity level', action="store_true") + parser.add_argument( 'input', type=str, help='Input abaqus mesh file name' ) + parser.add_argument( 'output', type=str, help='Output gmsh/vtu mesh file name' ) + parser.add_argument( '-v', '--verbose', help='Increase verbosity level', action="store_true" ) return parser @@ -32,19 +32,19 @@ def main() -> None: args = parser.parse_args() # Set up a logger - logging.basicConfig(level=logging.WARNING) - logger = logging.getLogger(__name__) + logging.basicConfig( level=logging.WARNING ) + logger = logging.getLogger( __name__ ) if args.verbose: - logger.setLevel(logging.INFO) + logger.setLevel( logging.INFO ) # Call the converter err = 0 - if ('.msh' in args.output): - err = abaqus_converter.convert_abaqus_to_gmsh(args.input, args.output, logger) + if ( '.msh' in args.output ): + err = abaqus_converter.convert_abaqus_to_gmsh( args.input, args.output, logger ) else: - err = abaqus_converter.convert_abaqus_to_vtu(args.input, args.output, logger) + err = abaqus_converter.convert_abaqus_to_vtu( args.input, args.output, logger ) if err: - sys.exit('Warnings detected: check the output file for potential errors!') + sys.exit( 'Warnings detected: check the output file for potential errors!' ) if __name__ == '__main__': diff --git a/geosx_xml_tools_package/geosx_xml_tools/attribute_coverage.py b/geosx_xml_tools_package/geosx_xml_tools/attribute_coverage.py index 2b64df8..7f7af99 100644 --- a/geosx_xml_tools_package/geosx_xml_tools/attribute_coverage.py +++ b/geosx_xml_tools_package/geosx_xml_tools/attribute_coverage.py @@ -1,17 +1,17 @@ -from lxml import etree as ElementTree # type: ignore[import] +from lxml import etree as ElementTree # type: ignore[import] import os from pathlib import Path from typing import Any, Iterable, Dict from geosx_xml_tools import command_line_parsers -record_type = Dict[str, Dict[str, Any]] +record_type = Dict[ str, Dict[ str, Any ] ] -def parse_schema_element(root: ElementTree.Element, - node: ElementTree.Element, - xsd: str = '{http://www.w3.org/2001/XMLSchema}', - recursive_types: Iterable[str] = ['PeriodicEvent', 'SoloEvent', 'HaltEvent'], - folders: Iterable[str] = ['src', 'examples']) -> record_type: +def parse_schema_element( root: ElementTree.Element, + node: ElementTree.Element, + xsd: str = '{http://www.w3.org/2001/XMLSchema}', + recursive_types: Iterable[ str ] = [ 'PeriodicEvent', 'SoloEvent', 'HaltEvent' ], + folders: Iterable[ str ] = [ 'src', 'examples' ] ) -> record_type: """Parse the xml schema at the current level Args: @@ -25,30 +25,30 @@ def parse_schema_element(root: ElementTree.Element, dict: Dictionary of attributes and children for the current node """ - element_type = node.get('type') - element_name = node.get('name') - element_def = root.find("%scomplexType[@name='%s']" % (xsd, element_type)) - local_types: record_type = {'attributes': {}, 'children': {}} + element_type = node.get( 'type' ) + element_name = node.get( 'name' ) + element_def = root.find( "%scomplexType[@name='%s']" % ( xsd, element_type ) ) + local_types: record_type = { 'attributes': {}, 'children': {} } # Parse attributes - for attribute in element_def.findall('%sattribute' % (xsd)): - attribute_name = attribute.get('name') - local_types['attributes'][attribute_name] = {ka: [] for ka in folders} - if ('default' in attribute.attrib): - local_types['attributes'][attribute_name]['default'] = attribute.get('default') + for attribute in element_def.findall( '%sattribute' % ( xsd ) ): + attribute_name = attribute.get( 'name' ) + local_types[ 'attributes' ][ attribute_name ] = { ka: [] for ka in folders } + if ( 'default' in attribute.attrib ): + local_types[ 'attributes' ][ attribute_name ][ 'default' ] = attribute.get( 'default' ) # Parse children - choice_node = element_def.findall('%schoice' % (xsd)) + choice_node = element_def.findall( '%schoice' % ( xsd ) ) if choice_node: - for child in choice_node[0].findall('%selement' % (xsd)): - child_name = child.get('name') - if not ((child_name in recursive_types) and (element_name in recursive_types)): - local_types['children'][child_name] = parse_schema_element(root, child) + for child in choice_node[ 0 ].findall( '%selement' % ( xsd ) ): + child_name = child.get( 'name' ) + if not ( ( child_name in recursive_types ) and ( element_name in recursive_types ) ): + local_types[ 'children' ][ child_name ] = parse_schema_element( root, child ) return local_types -def parse_schema(fname: str) -> record_type: +def parse_schema( fname: str ) -> record_type: """Parse the schema file into the xml attribute usage dict Args: @@ -57,13 +57,13 @@ def parse_schema(fname: str) -> record_type: Returns: dict: Dictionary of attributes and children for the entire schema """ - xml_tree = ElementTree.parse(fname) + xml_tree = ElementTree.parse( fname ) xml_root = xml_tree.getroot() - problem_node = xml_root.find("{http://www.w3.org/2001/XMLSchema}element") - return {'Problem': parse_schema_element(xml_root, problem_node)} + problem_node = xml_root.find( "{http://www.w3.org/2001/XMLSchema}element" ) + return { 'Problem': parse_schema_element( xml_root, problem_node ) } -def collect_xml_attributes_level(local_types: record_type, node: ElementTree.Element, folder: str) -> None: +def collect_xml_attributes_level( local_types: record_type, node: ElementTree.Element, folder: str ) -> None: """Collect xml attribute usage at the current level Args: @@ -72,14 +72,14 @@ def collect_xml_attributes_level(local_types: record_type, node: ElementTree.Ele folder (str): the source folder for the current file """ for ka in node.attrib.keys(): - local_types['attributes'][ka][folder].append(node.get(ka)) + local_types[ 'attributes' ][ ka ][ folder ].append( node.get( ka ) ) for child in node: - if child.tag in local_types['children']: - collect_xml_attributes_level(local_types['children'][child.tag], child, folder) + if child.tag in local_types[ 'children' ]: + collect_xml_attributes_level( local_types[ 'children' ][ child.tag ], child, folder ) -def collect_xml_attributes(xml_types: record_type, fname: str, folder: str) -> None: +def collect_xml_attributes( xml_types: record_type, fname: str, folder: str ) -> None: """Collect xml attribute usage in a file Args: @@ -87,16 +87,16 @@ def collect_xml_attributes(xml_types: record_type, fname: str, folder: str) -> N fname (str): name of the target file folder (str): the source folder for the current file """ - parser = ElementTree.XMLParser(remove_comments=True, remove_blank_text=True) - xml_tree = ElementTree.parse(fname, parser=parser) + parser = ElementTree.XMLParser( remove_comments=True, remove_blank_text=True ) + xml_tree = ElementTree.parse( fname, parser=parser ) xml_root = xml_tree.getroot() - collect_xml_attributes_level(xml_types['Problem'], xml_root, folder) + collect_xml_attributes_level( xml_types[ 'Problem' ], xml_root, folder ) -def write_attribute_usage_xml_level(local_types: record_type, - node: ElementTree.Element, - folders: Iterable[str] = ['src', 'examples']) -> None: +def write_attribute_usage_xml_level( local_types: record_type, + node: ElementTree.Element, + folders: Iterable[ str ] = [ 'src', 'examples' ] ) -> None: """Write xml attribute usage file at a given level Args: @@ -105,44 +105,44 @@ def write_attribute_usage_xml_level(local_types: record_type, """ # Write attributes - for ka in local_types['attributes'].keys(): - attribute_node = ElementTree.Element(ka) - node.append(attribute_node) + for ka in local_types[ 'attributes' ].keys(): + attribute_node = ElementTree.Element( ka ) + node.append( attribute_node ) - if ('default' in local_types['attributes'][ka]): - attribute_node.set('default', local_types['attributes'][ka]['default']) + if ( 'default' in local_types[ 'attributes' ][ ka ] ): + attribute_node.set( 'default', local_types[ 'attributes' ][ ka ][ 'default' ] ) unique_values = [] for f in folders: - sub_values = list(set(local_types['attributes'][ka][f])) - unique_values.extend(sub_values) - attribute_node.set(f, ' | '.join(sub_values)) + sub_values = list( set( local_types[ 'attributes' ][ ka ][ f ] ) ) + unique_values.extend( sub_values ) + attribute_node.set( f, ' | '.join( sub_values ) ) - unique_length = len(set(unique_values)) - attribute_node.set('unique_values', str(unique_length)) + unique_length = len( set( unique_values ) ) + attribute_node.set( 'unique_values', str( unique_length ) ) # Write children - for ka in sorted(local_types['children']): - child = ElementTree.Element(ka) - node.append(child) - write_attribute_usage_xml_level(local_types['children'][ka], child) + for ka in sorted( local_types[ 'children' ] ): + child = ElementTree.Element( ka ) + node.append( child ) + write_attribute_usage_xml_level( local_types[ 'children' ][ ka ], child ) -def write_attribute_usage_xml(xml_types: record_type, fname: str) -> None: +def write_attribute_usage_xml( xml_types: record_type, fname: str ) -> None: """Write xml attribute usage file Args: xml_types (dict): dictionary containing attribute usage by xml type fname (str): output file name """ - xml_root = ElementTree.Element('Problem') - xml_tree = ElementTree.ElementTree(xml_root) + xml_root = ElementTree.Element( 'Problem' ) + xml_tree = ElementTree.ElementTree( xml_root ) - write_attribute_usage_xml_level(xml_types['Problem'], xml_root) - xml_tree.write(fname, pretty_print=True) + write_attribute_usage_xml_level( xml_types[ 'Problem' ], xml_root ) + xml_tree.write( fname, pretty_print=True ) -def process_xml_files(geosx_root: str, output_name: str) -> None: +def process_xml_files( geosx_root: str, output_name: str ) -> None: """Test for xml attribute usage Args: @@ -151,20 +151,20 @@ def process_xml_files(geosx_root: str, output_name: str) -> None: """ # Parse the schema - geosx_root = os.path.expanduser(geosx_root) - schema = '%ssrc/coreComponents/schema/schema.xsd' % (geosx_root) - xml_types = parse_schema(schema) + geosx_root = os.path.expanduser( geosx_root ) + schema = '%ssrc/coreComponents/schema/schema.xsd' % ( geosx_root ) + xml_types = parse_schema( schema ) # Find all xml files, collect their attributes - for folder in ['src', 'examples']: - print(folder) - xml_files = Path(os.path.join(geosx_root, folder)).rglob('*.xml') + for folder in [ 'src', 'examples' ]: + print( folder ) + xml_files = Path( os.path.join( geosx_root, folder ) ).rglob( '*.xml' ) for f in xml_files: - print(' %s' % (str(f))) - collect_xml_attributes(xml_types, str(f), folder) + print( ' %s' % ( str( f ) ) ) + collect_xml_attributes( xml_types, str( f ), folder ) # Consolidate attributes - write_attribute_usage_xml(xml_types, output_name) + write_attribute_usage_xml( xml_types, output_name ) def main() -> None: @@ -180,7 +180,7 @@ def main() -> None: args = parser.parse_args() # Parse the xml files - process_xml_files(args.root, args.output) + process_xml_files( args.root, args.output ) if __name__ == "__main__": diff --git a/geosx_xml_tools_package/geosx_xml_tools/command_line_parsers.py b/geosx_xml_tools_package/geosx_xml_tools/command_line_parsers.py index 4c07d11..48b126f 100644 --- a/geosx_xml_tools_package/geosx_xml_tools/command_line_parsers.py +++ b/geosx_xml_tools_package/geosx_xml_tools/command_line_parsers.py @@ -10,25 +10,25 @@ def build_preprocessor_input_parser() -> argparse.ArgumentParser: """ # Parse the user arguments parser = argparse.ArgumentParser() - parser.add_argument('-i', '--input', type=str, action='append', help='Input file name (multiple allowed)') - parser.add_argument('-c', - '--compiled-name', - type=str, - help='Compiled xml file name (otherwise, it is randomly genrated)', - default='') - parser.add_argument('-s', '--schema', type=str, help='GEOSX schema to use for validation', default='') - parser.add_argument('-v', '--verbose', type=int, help='Verbosity of outputs', default=0) - parser.add_argument('-p', - '--parameters', - nargs='+', - action='append', - help='Parameter overrides (name value, multiple allowed)', - default=[]) + parser.add_argument( '-i', '--input', type=str, action='append', help='Input file name (multiple allowed)' ) + parser.add_argument( '-c', + '--compiled-name', + type=str, + help='Compiled xml file name (otherwise, it is randomly genrated)', + default='' ) + parser.add_argument( '-s', '--schema', type=str, help='GEOSX schema to use for validation', default='' ) + parser.add_argument( '-v', '--verbose', type=int, help='Verbosity of outputs', default=0 ) + parser.add_argument( '-p', + '--parameters', + nargs='+', + action='append', + help='Parameter overrides (name value, multiple allowed)', + default=[] ) return parser -def parse_xml_preprocessor_arguments() -> Tuple[argparse.Namespace, Iterable[str]]: +def parse_xml_preprocessor_arguments() -> Tuple[ argparse.Namespace, Iterable[ str ] ]: """Parse user arguments Args: @@ -53,13 +53,13 @@ def build_xml_formatter_input_parser() -> argparse.ArgumentParser: """ parser = argparse.ArgumentParser() - parser.add_argument('input', type=str, help='Input file name') - parser.add_argument('-i', '--indent', type=int, help='Indent size', default=2) - parser.add_argument('-s', '--style', type=int, help='Indent style', default=0) - parser.add_argument('-d', '--depth', type=int, help='Block separation depth', default=2) - parser.add_argument('-a', '--alphebitize', type=int, help='Alphebetize attributes', default=0) - parser.add_argument('-c', '--close', type=int, help='Close tag style', default=0) - parser.add_argument('-n', '--namespace', type=int, help='Include namespace', default=0) + parser.add_argument( 'input', type=str, help='Input file name' ) + parser.add_argument( '-i', '--indent', type=int, help='Indent size', default=2 ) + parser.add_argument( '-s', '--style', type=int, help='Indent style', default=0 ) + parser.add_argument( '-d', '--depth', type=int, help='Block separation depth', default=2 ) + parser.add_argument( '-a', '--alphebitize', type=int, help='Alphebetize attributes', default=0 ) + parser.add_argument( '-c', '--close', type=int, help='Close tag style', default=0 ) + parser.add_argument( '-n', '--namespace', type=int, help='Include namespace', default=0 ) return parser @@ -71,8 +71,8 @@ def build_attribute_coverage_input_parser() -> argparse.ArgumentParser: """ parser = argparse.ArgumentParser() - parser.add_argument('-r', '--root', type=str, help='GEOSX root', default='') - parser.add_argument('-o', '--output', type=str, help='Output file name', default='attribute_test.xml') + parser.add_argument( '-r', '--root', type=str, help='GEOSX root', default='' ) + parser.add_argument( '-o', '--output', type=str, help='Output file name', default='attribute_test.xml' ) return parser @@ -84,5 +84,5 @@ def build_xml_redundancy_input_parser() -> argparse.ArgumentParser: """ parser = argparse.ArgumentParser() - parser.add_argument('-r', '--root', type=str, help='GEOSX root', default='') + parser.add_argument( '-r', '--root', type=str, help='GEOSX root', default='' ) return parser diff --git a/geosx_xml_tools_package/geosx_xml_tools/main.py b/geosx_xml_tools_package/geosx_xml_tools/main.py index b511028..c907cf7 100644 --- a/geosx_xml_tools_package/geosx_xml_tools/main.py +++ b/geosx_xml_tools_package/geosx_xml_tools/main.py @@ -15,19 +15,19 @@ def check_mpi_rank() -> int: int: MPI rank """ rank = 0 - mpi_rank_key_options = ['OMPI_COMM_WORLD_RANK', 'PMI_RANK'] + mpi_rank_key_options = [ 'OMPI_COMM_WORLD_RANK', 'PMI_RANK' ] for k in mpi_rank_key_options: if k in os.environ: - rank = int(os.environ[k]) + rank = int( os.environ[ k ] ) return rank -TFunc = Callable[..., Any] +TFunc = Callable[..., Any ] -def wait_for_file_write_rank_0(target_file_argument: Union[int, str] = 0, - max_wait_time: float = 100, - max_startup_delay: float = 1) -> Callable[[TFunc], TFunc]: +def wait_for_file_write_rank_0( target_file_argument: Union[ int, str ] = 0, + max_wait_time: float = 100, + max_startup_delay: float = 1 ) -> Callable[ [ TFunc ], TFunc ]: """Constructor for a function decorator that waits for a target file to be written on rank 0 Args: @@ -39,47 +39,47 @@ def wait_for_file_write_rank_0(target_file_argument: Union[int, str] = 0, Wrapped function """ - def wait_for_file_write_rank_0_inner(writer: TFunc) -> TFunc: + def wait_for_file_write_rank_0_inner( writer: TFunc ) -> TFunc: """Intermediate constructor for the function decorator Args: writer (typing.Callable): A function that writes to a file """ - def wait_for_file_write_rank_0_decorator(*args, **kwargs) -> Any: + def wait_for_file_write_rank_0_decorator( *args, **kwargs ) -> Any: """Apply the writer on rank 0, and wait for completion on other ranks """ # Check the target file status rank = check_mpi_rank() fname = '' - if isinstance(target_file_argument, int): - fname = args[target_file_argument] + if isinstance( target_file_argument, int ): + fname = args[ target_file_argument ] else: - fname = kwargs[target_file_argument] + fname = kwargs[ target_file_argument ] - target_file_exists = os.path.isfile(fname) + target_file_exists = os.path.isfile( fname ) target_file_edit_time = 0.0 if target_file_exists: - target_file_edit_time = os.path.getmtime(fname) + target_file_edit_time = os.path.getmtime( fname ) # Variations in thread startup times may mean the file has already been processed # If the last edit was done within the specified time, then allow the thread to proceed - if (abs(target_file_edit_time - time.time()) < max_startup_delay): + if ( abs( target_file_edit_time - time.time() ) < max_startup_delay ): target_file_edit_time = 0.0 # Go into the target process or wait for the expected file update - if (rank == 0): - return writer(*args, **kwargs) + if ( rank == 0 ): + return writer( *args, **kwargs ) else: ta = time.time() - while (time.time() - ta < max_wait_time): + while ( time.time() - ta < max_wait_time ): if target_file_exists: - if (os.path.getmtime(fname) > target_file_edit_time): + if ( os.path.getmtime( fname ) > target_file_edit_time ): break else: - if os.path.isfile(fname): + if os.path.isfile( fname ): break - time.sleep(0.1) + time.sleep( 0.1 ) return wait_for_file_write_rank_0_decorator @@ -99,13 +99,14 @@ def preprocess_serial() -> None: # If the rank detection fails, then it will preprocess the file on all ranks, which # sometimes cause a (seemingly harmless) file write conflict. # processor = xml_processor.process - processor = wait_for_file_write_rank_0(target_file_argument='outputFile', max_wait_time=100)(xml_processor.process) - - compiled_name = processor(args.input, - outputFile=args.compiled_name, - schema=args.schema, - verbose=args.verbose, - parameter_override=args.parameters) + processor = wait_for_file_write_rank_0( target_file_argument='outputFile', + max_wait_time=100 )( xml_processor.process ) + + compiled_name = processor( args.input, + outputFile=args.compiled_name, + schema=args.schema, + verbose=args.verbose, + parameter_override=args.parameters ) if not compiled_name: if args.compiled_name: compiled_name = args.compiled_name @@ -116,31 +117,31 @@ def preprocess_serial() -> None: # Note: the return value may be passed to sys.exit, and cause bash to report an error # return format_geosx_arguments(compiled_name, unknown_args) - print(compiled_name) + print( compiled_name ) -def preprocess_parallel() -> Iterable[str]: +def preprocess_parallel() -> Iterable[ str ]: """ MPI aware xml preprocesing """ # Process the xml file - from mpi4py import MPI # type: ignore[import] + from mpi4py import MPI # type: ignore[import] comm = MPI.COMM_WORLD rank = comm.Get_rank() args, unknown_args = command_line_parsers.parse_xml_preprocessor_arguments() compiled_name = '' - if (rank == 0): - compiled_name = xml_processor.process(args.input, - outputFile=args.compiled_name, - schema=args.schema, - verbose=args.verbose, - parameter_override=args.parameters) - compiled_name = comm.bcast(compiled_name, root=0) - return format_geosx_arguments(compiled_name, unknown_args) + if ( rank == 0 ): + compiled_name = xml_processor.process( args.input, + outputFile=args.compiled_name, + schema=args.schema, + verbose=args.verbose, + parameter_override=args.parameters ) + compiled_name = comm.bcast( compiled_name, root=0 ) + return format_geosx_arguments( compiled_name, unknown_args ) -def format_geosx_arguments(compiled_name: str, unknown_args: Iterable[str]) -> Iterable[str]: +def format_geosx_arguments( compiled_name: str, unknown_args: Iterable[ str ] ) -> Iterable[ str ]: """Format GEOSX arguments Args: @@ -150,12 +151,12 @@ def format_geosx_arguments(compiled_name: str, unknown_args: Iterable[str]) -> I Returns: list: List of arguments to pass to GEOSX """ - geosx_args = [sys.argv[0], '-i', compiled_name] + geosx_args = [ sys.argv[ 0 ], '-i', compiled_name ] if unknown_args: - geosx_args.extend(unknown_args) + geosx_args.extend( unknown_args ) # Print the output name for use in bash scripts - print(compiled_name) + print( compiled_name ) return geosx_args diff --git a/geosx_xml_tools_package/geosx_xml_tools/regex_tools.py b/geosx_xml_tools_package/geosx_xml_tools/regex_tools.py index ded5c1a..a7e586b 100644 --- a/geosx_xml_tools_package/geosx_xml_tools/regex_tools.py +++ b/geosx_xml_tools_package/geosx_xml_tools/regex_tools.py @@ -17,7 +17,7 @@ strip_trailing_b| 3.0000e0, 1.23e0 | Removes unnecessary float strings """ -patterns: Dict[str, str] = { +patterns: Dict[ str, str ] = { 'parameters': r"\$:?([a-zA-Z_0-9]*)\$?", 'units': r"([0-9]*?\.?[0-9]+(?:[eE][-+]?[0-9]*?)?)\ *?\[([-+.*/()a-zA-Z0-9]*)\]", 'units_b': r"([a-zA-Z]*)", @@ -31,21 +31,21 @@ symbolic_format = '%1.6e' -def SymbolicMathRegexHandler(match: re.Match) -> str: +def SymbolicMathRegexHandler( match: re.Match ) -> str: """Evaluate symbolic expressions that are identified using the regex_tools.patterns['symbolic']. Args: match (re.match): A matching string identified by the regex. """ - k = match.group(1) + k = match.group( 1 ) if k: # Sanitize the input - sanitized = re.sub(patterns['sanitize'], '', k).strip() - value = eval(sanitized, {'__builtins__': None}) + sanitized = re.sub( patterns[ 'sanitize' ], '', k ).strip() + value = eval( sanitized, { '__builtins__': None } ) # Format the string, removing any trailing zeros, decimals, etc. - str_value = re.sub(patterns['strip_trailing'], '', symbolic_format % (value)) - str_value = re.sub(patterns['strip_trailing_b'], '', str_value) + str_value = re.sub( patterns[ 'strip_trailing' ], '', symbolic_format % ( value ) ) + str_value = re.sub( patterns[ 'strip_trailing_b' ], '', str_value ) return str_value else: return '' @@ -54,25 +54,25 @@ def SymbolicMathRegexHandler(match: re.Match) -> str: class DictRegexHandler(): """This class is used to substitute matched values with those stored in a dict.""" - def __init__(self) -> None: + def __init__( self ) -> None: """Initialize the handler with an empty target list. The key/value pairs of self.target indicate which values to look for and the values they will replace with. """ - self.target: Dict[str, str] = {} + self.target: Dict[ str, str ] = {} - def __call__(self, match: re.Match) -> str: + def __call__( self, match: re.Match ) -> str: """Replace the matching strings with their target. Args: match (re.match): A matching string identified by the regex. """ - k = match.group(1) + k = match.group( 1 ) if k: - if (k not in self.target.keys()): - raise Exception('Error: Target (%s) is not defined in the regex handler' % k) - value = self.target[k] - return str(value) + if ( k not in self.target.keys() ): + raise Exception( 'Error: Target (%s) is not defined in the regex handler' % k ) + value = self.target[ k ] + return str( value ) else: return '' diff --git a/geosx_xml_tools_package/geosx_xml_tools/table_generator.py b/geosx_xml_tools_package/geosx_xml_tools/table_generator.py index bb4a63c..9fbe4d8 100644 --- a/geosx_xml_tools_package/geosx_xml_tools/table_generator.py +++ b/geosx_xml_tools_package/geosx_xml_tools/table_generator.py @@ -4,10 +4,10 @@ from typing import Tuple, Iterable, Dict -def write_GEOS_table(axes_values: Iterable[np.ndarray], - properties: Dict[str, np.ndarray], - axes_names: Iterable[str] = ['x', 'y', 'z', 't'], - string_format: str = '%1.5e') -> None: +def write_GEOS_table( axes_values: Iterable[ np.ndarray ], + properties: Dict[ str, np.ndarray ], + axes_names: Iterable[ str ] = [ 'x', 'y', 'z', 't' ], + string_format: str = '%1.5e' ) -> None: """Write an GEOS-compatible ascii table. Args: @@ -18,23 +18,23 @@ def write_GEOS_table(axes_values: Iterable[np.ndarray], """ # Check to make sure the axes/property files have the correct shape - axes_shape = tuple([len(x) for x in axes_values]) + axes_shape = tuple( [ len( x ) for x in axes_values ] ) for k in properties.keys(): - if (np.shape(properties[k]) != axes_shape): - raise Exception("Shape of parameter %s is incompatible with given axes" % (k)) + if ( np.shape( properties[ k ] ) != axes_shape ): + raise Exception( "Shape of parameter %s is incompatible with given axes" % ( k ) ) # Write axes files - for ka, x in zip(axes_names, axes_values): - np.savetxt('%s.geos' % (ka), x, fmt=string_format, delimiter=',') + for ka, x in zip( axes_names, axes_values ): + np.savetxt( '%s.geos' % ( ka ), x, fmt=string_format, delimiter=',' ) # Write property files for k in properties.keys(): - tmp = np.reshape(properties[k], (-1), order='F') - np.savetxt('%s.geos' % (k), tmp, fmt=string_format, delimiter=',') + tmp = np.reshape( properties[ k ], ( -1 ), order='F' ) + np.savetxt( '%s.geos' % ( k ), tmp, fmt=string_format, delimiter=',' ) -def read_GEOS_table(axes_files: Iterable[str], - property_files: Iterable[str]) -> Tuple[Iterable[np.ndarray], Dict[str, np.ndarray]]: +def read_GEOS_table( axes_files: Iterable[ str ], + property_files: Iterable[ str ] ) -> Tuple[ Iterable[ np.ndarray ], Dict[ str, np.ndarray ] ]: """Read an GEOS-compatible ascii table. Args: @@ -46,14 +46,14 @@ def read_GEOS_table(axes_files: Iterable[str], """ axes_values = [] for f in axes_files: - axes_values.append(np.loadtxt('%s.geos' % (f), unpack=True, delimiter=',')) - axes_shape = tuple([len(x) for x in axes_values]) + axes_values.append( np.loadtxt( '%s.geos' % ( f ), unpack=True, delimiter=',' ) ) + axes_shape = tuple( [ len( x ) for x in axes_values ] ) # Open property files properties = {} for f in property_files: - tmp = np.loadtxt('%s.geos' % (f), unpack=True, delimiter=',') - properties[f] = np.reshape(tmp, axes_shape, order='F') + tmp = np.loadtxt( '%s.geos' % ( f ), unpack=True, delimiter=',' ) + properties[ f ] = np.reshape( tmp, axes_shape, order='F' ) return axes_values, properties @@ -62,14 +62,14 @@ def write_read_GEOS_table_example() -> None: """Table read / write example.""" # Define table axes - a = np.array([0.0, 1.0]) - b = np.array([0.0, 0.5, 1.0]) - axes_values = [a, b] + a = np.array( [ 0.0, 1.0 ] ) + b = np.array( [ 0.0, 0.5, 1.0 ] ) + axes_values = [ a, b ] # Generate table values (note: the indexing argument is important) - A, B = np.meshgrid(a, b, indexing='ij') - properties = {'c': A + 2.0 * B} + A, B = np.meshgrid( a, b, indexing='ij' ) + properties = { 'c': A + 2.0 * B } # Write, then read tables - write_GEOS_table(axes_values, properties, axes_names=['a', 'b']) - axes_b, properties_b = read_GEOS_table(['a', 'b'], ['c']) + write_GEOS_table( axes_values, properties, axes_names=[ 'a', 'b' ] ) + axes_b, properties_b = read_GEOS_table( [ 'a', 'b' ], [ 'c' ] ) diff --git a/geosx_xml_tools_package/geosx_xml_tools/tests/generate_test_xml.py b/geosx_xml_tools_package/geosx_xml_tools/tests/generate_test_xml.py index 3c1b771..80cae43 100644 --- a/geosx_xml_tools_package/geosx_xml_tools/tests/generate_test_xml.py +++ b/geosx_xml_tools_package/geosx_xml_tools/tests/generate_test_xml.py @@ -5,7 +5,7 @@ from geosx_xml_tools import xml_formatter -def generate_test_xml_files(root_dir): +def generate_test_xml_files( root_dir ): """Build example input/output xml files, which can be used to test the parser. These are derived from a GEOSX integrated test xml. @@ -339,36 +339,36 @@ def generate_test_xml_files(root_dir): # Write the files, and apply pretty_print to targets for easy matches # No advanced features case - with open('%s/no_advanced_features_input.xml' % (root_dir), 'w') as f: - f.write(xml_header + xml_base_a + xml_base_b + field_string_base + xml_footer) - with open('%s/no_advanced_features_target.xml' % (root_dir), 'w') as f: - f.write(xml_header + xml_base_a + xml_base_b + field_string_base + xml_footer) - xml_formatter.format_file('%s/no_advanced_features_target.xml' % (root_dir)) + with open( '%s/no_advanced_features_input.xml' % ( root_dir ), 'w' ) as f: + f.write( xml_header + xml_base_a + xml_base_b + field_string_base + xml_footer ) + with open( '%s/no_advanced_features_target.xml' % ( root_dir ), 'w' ) as f: + f.write( xml_header + xml_base_a + xml_base_b + field_string_base + xml_footer ) + xml_formatter.format_file( '%s/no_advanced_features_target.xml' % ( root_dir ) ) # Parameters case - with open('%s/parameters_input.xml' % (root_dir), 'w') as f: - f.write(xml_header + xml_parameters + xml_base_a + xml_base_b + field_string_with_parameters + xml_footer) - with open('%s/parameters_target.xml' % (root_dir), 'w') as f: - f.write(xml_header + xml_base_a + xml_base_b + field_string_base + xml_footer) - xml_formatter.format_file('%s/parameters_target.xml' % (root_dir)) + with open( '%s/parameters_input.xml' % ( root_dir ), 'w' ) as f: + f.write( xml_header + xml_parameters + xml_base_a + xml_base_b + field_string_with_parameters + xml_footer ) + with open( '%s/parameters_target.xml' % ( root_dir ), 'w' ) as f: + f.write( xml_header + xml_base_a + xml_base_b + field_string_base + xml_footer ) + xml_formatter.format_file( '%s/parameters_target.xml' % ( root_dir ) ) # Symbolic + parameters case - with open('%s/symbolic_parameters_input.xml' % (root_dir), 'w') as f: - f.write(xml_header + xml_parameters + xml_base_a + xml_base_b + field_string_with_symbolic + xml_footer) - with open('%s/symbolic_parameters_target.xml' % (root_dir), 'w') as f: - f.write(xml_header + xml_base_a + xml_base_b + field_string_alt + xml_footer) - xml_formatter.format_file('%s/symbolic_parameters_target.xml' % (root_dir)) + with open( '%s/symbolic_parameters_input.xml' % ( root_dir ), 'w' ) as f: + f.write( xml_header + xml_parameters + xml_base_a + xml_base_b + field_string_with_symbolic + xml_footer ) + with open( '%s/symbolic_parameters_target.xml' % ( root_dir ), 'w' ) as f: + f.write( xml_header + xml_base_a + xml_base_b + field_string_alt + xml_footer ) + xml_formatter.format_file( '%s/symbolic_parameters_target.xml' % ( root_dir ) ) # Included case - os.makedirs('%s/included' % (root_dir), exist_ok=True) - with open('%s/included_input.xml' % (root_dir), 'w') as f: - f.write(xml_header + xml_includes + xml_footer) - with open('%s/included/included_a.xml' % (root_dir), 'w') as f: - f.write(xml_header + xml_base_a + xml_footer) - with open('%s/included/included_b.xml' % (root_dir), 'w') as f: - f.write(xml_header + xml_base_b + xml_footer) - with open('%s/included/included_c.xml' % (root_dir), 'w') as f: - f.write(xml_header + field_string_base + xml_footer) - with open('%s/included_target.xml' % (root_dir), 'w') as f: - f.write(xml_header + xml_base_a + xml_base_b + field_string_base + xml_footer) - xml_formatter.format_file('%s/included_target.xml' % (root_dir)) + os.makedirs( '%s/included' % ( root_dir ), exist_ok=True ) + with open( '%s/included_input.xml' % ( root_dir ), 'w' ) as f: + f.write( xml_header + xml_includes + xml_footer ) + with open( '%s/included/included_a.xml' % ( root_dir ), 'w' ) as f: + f.write( xml_header + xml_base_a + xml_footer ) + with open( '%s/included/included_b.xml' % ( root_dir ), 'w' ) as f: + f.write( xml_header + xml_base_b + xml_footer ) + with open( '%s/included/included_c.xml' % ( root_dir ), 'w' ) as f: + f.write( xml_header + field_string_base + xml_footer ) + with open( '%s/included_target.xml' % ( root_dir ), 'w' ) as f: + f.write( xml_header + xml_base_a + xml_base_b + field_string_base + xml_footer ) + xml_formatter.format_file( '%s/included_target.xml' % ( root_dir ) ) diff --git a/geosx_xml_tools_package/geosx_xml_tools/tests/test_manager.py b/geosx_xml_tools_package/geosx_xml_tools/tests/test_manager.py index 54e60a7..367d2dd 100644 --- a/geosx_xml_tools_package/geosx_xml_tools/tests/test_manager.py +++ b/geosx_xml_tools_package/geosx_xml_tools/tests/test_manager.py @@ -12,145 +12,148 @@ # Test the unit manager definitions -class TestUnitManager(unittest.TestCase): +class TestUnitManager( unittest.TestCase ): @classmethod - def setUpClass(cls): + def setUpClass( cls ): cls.tol = 1e-6 - def test_unit_dict(self): + def test_unit_dict( self ): unitManager.buildUnits() - self.assertTrue(bool(unitManager.units)) + self.assertTrue( bool( unitManager.units ) ) # Scale value tests - @parameterized.expand([['meter', '2', 2.0], ['meter', '1.234', 1.234], ['meter', '1.234e1', 12.34], - ['meter', '1.234E1', 12.34], ['meter', '1.234e+1', 12.34], ['meter', '1.234e-1', 0.1234], - ['mumeter', '1', 1.0e-6], ['micrometer', '1', 1.0e-6], ['kilometer', '1', 1.0e3], - ['ms', '1', 1.0e-3], ['millisecond', '1', 1.0e-3], ['Ms', '1', 1.0e6], ['m/s', '1', 1.0], - ['micrometer/s', '1', 1.0e-6], ['micrometer/ms', '1', 1.0e-3], - ['micrometer/microsecond', '1', 1.0], ['m**2', '1', 1.0], ['km**2', '1', 1.0e6], - ['kilometer**2', '1', 1.0e6], ['(km*mm)', '1', 1.0], ['(km*mm)**2', '1', 1.0], - ['km^2', '1', 1.0e6, True], ['bbl/day', '1', 0.000001840130728333], ['cP', '1', 0.001]]) - def test_units(self, unit, scale, expected_value, expect_fail=False): + @parameterized.expand( [ [ 'meter', '2', 2.0 ], [ 'meter', '1.234', 1.234 ], [ 'meter', '1.234e1', 12.34 ], + [ 'meter', '1.234E1', 12.34 ], [ 'meter', '1.234e+1', 12.34 ], + [ 'meter', '1.234e-1', 0.1234 ], [ 'mumeter', '1', 1.0e-6 ], [ 'micrometer', '1', 1.0e-6 ], + [ 'kilometer', '1', 1.0e3 ], [ 'ms', '1', 1.0e-3 ], [ 'millisecond', '1', 1.0e-3 ], + [ 'Ms', '1', 1.0e6 ], [ 'm/s', '1', 1.0 ], [ 'micrometer/s', '1', 1.0e-6 ], + [ 'micrometer/ms', '1', 1.0e-3 ], [ 'micrometer/microsecond', '1', 1.0 ], + [ 'm**2', '1', 1.0 ], [ 'km**2', '1', 1.0e6 ], [ 'kilometer**2', '1', 1.0e6 ], + [ '(km*mm)', '1', 1.0 ], [ '(km*mm)**2', '1', 1.0 ], [ 'km^2', '1', 1.0e6, True ], + [ 'bbl/day', '1', 0.000001840130728333 ], [ 'cP', '1', 0.001 ] ] ) + def test_units( self, unit, scale, expected_value, expect_fail=False ): try: - val = float(unitManager([scale, unit])) - self.assertTrue((abs(val - expected_value) < self.tol) != expect_fail) + val = float( unitManager( [ scale, unit ] ) ) + self.assertTrue( ( abs( val - expected_value ) < self.tol ) != expect_fail ) except TypeError: - self.assertTrue(expect_fail) + self.assertTrue( expect_fail ) # Test the behavior of the parameter regex -class TestParameterRegex(unittest.TestCase): +class TestParameterRegex( unittest.TestCase ): @classmethod - def setUpClass(cls): + def setUpClass( cls ): cls.regexHandler = regex_tools.DictRegexHandler() - cls.regexHandler.target['foo'] = '1.23' - cls.regexHandler.target['bar'] = '4.56e7' - - @parameterized.expand([['$:foo*1.234', '1.23*1.234'], ['$:foo*1.234/$:bar', '1.23*1.234/4.56e7'], - ['$:foo*1.234/($:bar*$:foo)', '1.23*1.234/(4.56e7*1.23)'], - ['$foo*1.234/$bar', '1.23*1.234/4.56e7'], ['$foo$*1.234/$bar', '1.23*1.234/4.56e7'], - ['$foo$*1.234/$bar$', '1.23*1.234/4.56e7'], - ['$blah$*1.234/$bar$', '1.23*1.234/4.56e7', True], - ['$foo$*1.234/$bar$', '4.56e7*1.234/4.56e7', True]]) - def test_parameter_regex(self, parameterInput, expectedValue, expect_fail=False): + cls.regexHandler.target[ 'foo' ] = '1.23' + cls.regexHandler.target[ 'bar' ] = '4.56e7' + + @parameterized.expand( [ [ '$:foo*1.234', '1.23*1.234' ], [ '$:foo*1.234/$:bar', '1.23*1.234/4.56e7' ], + [ '$:foo*1.234/($:bar*$:foo)', '1.23*1.234/(4.56e7*1.23)' ], + [ '$foo*1.234/$bar', '1.23*1.234/4.56e7' ], [ '$foo$*1.234/$bar', '1.23*1.234/4.56e7' ], + [ '$foo$*1.234/$bar$', '1.23*1.234/4.56e7' ], + [ '$blah$*1.234/$bar$', '1.23*1.234/4.56e7', True ], + [ '$foo$*1.234/$bar$', '4.56e7*1.234/4.56e7', True ] ] ) + def test_parameter_regex( self, parameterInput, expectedValue, expect_fail=False ): try: - result = re.sub(regex_tools.patterns['parameters'], self.regexHandler, parameterInput) - self.assertTrue((result == expectedValue) != expect_fail) + result = re.sub( regex_tools.patterns[ 'parameters' ], self.regexHandler, parameterInput ) + self.assertTrue( ( result == expectedValue ) != expect_fail ) except Exception: - self.assertTrue(expect_fail) + self.assertTrue( expect_fail ) # Test the behavior of the unit regex -class TestUnitsRegex(unittest.TestCase): +class TestUnitsRegex( unittest.TestCase ): @classmethod - def setUpClass(cls): + def setUpClass( cls ): cls.tol = 1e-6 - @parameterized.expand([['1.234[m**2/s]', '1.234'], ['1.234 [m**2/s]', '1.234'], ['1.234[m**2/s]*3.4', '1.234*3.4'], - ['1.234[m**2/s] + 5.678[mm/s]', '1.234 + 5.678e-3'], - ['1.234 [m**2/s] + 5.678 [mm/s]', '1.234 + 5.678e-3'], - ['(1.234[m**2/s])*5.678', '(1.234)*5.678']]) - def test_units_regex(self, unitInput, expectedValue, expect_fail=False): + @parameterized.expand( [ [ '1.234[m**2/s]', '1.234' ], [ '1.234 [m**2/s]', '1.234' ], + [ '1.234[m**2/s]*3.4', '1.234*3.4' ], + [ '1.234[m**2/s] + 5.678[mm/s]', '1.234 + 5.678e-3' ], + [ '1.234 [m**2/s] + 5.678 [mm/s]', '1.234 + 5.678e-3' ], + [ '(1.234[m**2/s])*5.678', '(1.234)*5.678' ] ] ) + def test_units_regex( self, unitInput, expectedValue, expect_fail=False ): try: - result = re.sub(regex_tools.patterns['units'], unitManager.regexHandler, unitInput) - self.assertTrue((result == expectedValue) != expect_fail) + result = re.sub( regex_tools.patterns[ 'units' ], unitManager.regexHandler, unitInput ) + self.assertTrue( ( result == expectedValue ) != expect_fail ) except Exception: - self.assertTrue(expect_fail) + self.assertTrue( expect_fail ) # Test the symbolic math regex -class TestSymbolicRegex(unittest.TestCase): +class TestSymbolicRegex( unittest.TestCase ): @classmethod - def setUpClass(cls): + def setUpClass( cls ): cls.tol = 1e-6 - @parameterized.expand([['`1.234`', '1.234'], ['`1.234*2.0`', '2.468'], ['`10`', '1e1'], ['`10*2`', '2e1'], - ['`1.0/2.0`', '5e-1'], ['`2.0**2`', '4'], ['`1.0 + 2.0**2`', '5'], ['`(1.0 + 2.0)**2`', '9'], - ['`((1.0 + 2.0)**2)**(0.5)`', '3'], ['`(1.2e3)*2`', '2.4e3'], ['`1.2e3*2`', '2.4e3'], - ['`2.0^2`', '4', True], ['`sqrt(4.0)`', '2', True]]) - def test_symbolic_regex(self, symbolicInput, expectedValue, expect_fail=False): + @parameterized.expand( [ [ '`1.234`', '1.234' ], [ '`1.234*2.0`', '2.468' ], [ '`10`', '1e1' ], [ '`10*2`', '2e1' ], + [ '`1.0/2.0`', '5e-1' ], [ '`2.0**2`', '4' ], [ '`1.0 + 2.0**2`', '5' ], + [ '`(1.0 + 2.0)**2`', '9' ], [ '`((1.0 + 2.0)**2)**(0.5)`', '3' ], + [ '`(1.2e3)*2`', '2.4e3' ], [ '`1.2e3*2`', '2.4e3' ], [ '`2.0^2`', '4', True ], + [ '`sqrt(4.0)`', '2', True ] ] ) + def test_symbolic_regex( self, symbolicInput, expectedValue, expect_fail=False ): try: - result = re.sub(regex_tools.patterns['symbolic'], regex_tools.SymbolicMathRegexHandler, symbolicInput) - self.assertTrue((result == expectedValue) != expect_fail) + result = re.sub( regex_tools.patterns[ 'symbolic' ], regex_tools.SymbolicMathRegexHandler, symbolicInput ) + self.assertTrue( ( result == expectedValue ) != expect_fail ) except Exception: - self.assertTrue(expect_fail) + self.assertTrue( expect_fail ) # Test the complete xml processor -class TestXMLProcessor(unittest.TestCase): +class TestXMLProcessor( unittest.TestCase ): @classmethod - def setUpClass(cls): - generate_test_xml.generate_test_xml_files('.') - - @parameterized.expand([['no_advanced_features_input.xml', 'no_advanced_features_target.xml'], - ['parameters_input.xml', 'parameters_target.xml'], - ['included_input.xml', 'included_target.xml'], - ['symbolic_parameters_input.xml', 'symbolic_parameters_target.xml']]) - def test_xml_processor(self, input_file, target_file, expect_fail=False): + def setUpClass( cls ): + generate_test_xml.generate_test_xml_files( '.' ) + + @parameterized.expand( [ [ 'no_advanced_features_input.xml', 'no_advanced_features_target.xml' ], + [ 'parameters_input.xml', 'parameters_target.xml' ], + [ 'included_input.xml', 'included_target.xml' ], + [ 'symbolic_parameters_input.xml', 'symbolic_parameters_target.xml' ] ] ) + def test_xml_processor( self, input_file, target_file, expect_fail=False ): try: - tmp = xml_processor.process(input_file, - outputFile=input_file + '.processed', - verbose=0, - keep_parameters=False, - keep_includes=False) - self.assertTrue(filecmp.cmp(tmp, target_file) != expect_fail) + tmp = xml_processor.process( input_file, + outputFile=input_file + '.processed', + verbose=0, + keep_parameters=False, + keep_includes=False ) + self.assertTrue( filecmp.cmp( tmp, target_file ) != expect_fail ) except Exception: - self.assertTrue(expect_fail) + self.assertTrue( expect_fail ) # Main entry point for the unit tests -def run_unit_tests(test_dir, verbose): +def run_unit_tests( test_dir, verbose ): # Create and move to the test directory pwd = os.getcwd() - os.makedirs(test_dir, exist_ok=True) - os.chdir(test_dir) + os.makedirs( test_dir, exist_ok=True ) + os.chdir( test_dir ) # Unit manager tests - suite = unittest.TestLoader().loadTestsFromTestCase(TestUnitManager) - unittest.TextTestRunner(verbosity=verbose).run(suite) + suite = unittest.TestLoader().loadTestsFromTestCase( TestUnitManager ) + unittest.TextTestRunner( verbosity=verbose ).run( suite ) # Parameter regex handler tests - suite = unittest.TestLoader().loadTestsFromTestCase(TestParameterRegex) - unittest.TextTestRunner(verbosity=verbose).run(suite) + suite = unittest.TestLoader().loadTestsFromTestCase( TestParameterRegex ) + unittest.TextTestRunner( verbosity=verbose ).run( suite ) # Regex handler tests - suite = unittest.TestLoader().loadTestsFromTestCase(TestUnitsRegex) - unittest.TextTestRunner(verbosity=verbose).run(suite) + suite = unittest.TestLoader().loadTestsFromTestCase( TestUnitsRegex ) + unittest.TextTestRunner( verbosity=verbose ).run( suite ) # Symbolic regex handler tests - suite = unittest.TestLoader().loadTestsFromTestCase(TestSymbolicRegex) - unittest.TextTestRunner(verbosity=verbose).run(suite) + suite = unittest.TestLoader().loadTestsFromTestCase( TestSymbolicRegex ) + unittest.TextTestRunner( verbosity=verbose ).run( suite ) # xml processor tests - suite = unittest.TestLoader().loadTestsFromTestCase(TestXMLProcessor) - unittest.TextTestRunner(verbosity=verbose).run(suite) + suite = unittest.TestLoader().loadTestsFromTestCase( TestXMLProcessor ) + unittest.TextTestRunner( verbosity=verbose ).run( suite ) - os.chdir(pwd) + os.chdir( pwd ) def main(): @@ -161,12 +164,12 @@ def main(): # Parse the user arguments parser = argparse.ArgumentParser() - parser.add_argument('-t', '--test_dir', type=str, help='Test output directory', default='./test_results') - parser.add_argument('-v', '--verbose', type=int, help='Verbosity level', default=2) + parser.add_argument( '-t', '--test_dir', type=str, help='Test output directory', default='./test_results' ) + parser.add_argument( '-v', '--verbose', type=int, help='Verbosity level', default=2 ) args = parser.parse_args() # Process the xml file - run_unit_tests(args.test_dir, args.verbose) + run_unit_tests( args.test_dir, args.verbose ) if __name__ == "__main__": diff --git a/geosx_xml_tools_package/geosx_xml_tools/unit_manager.py b/geosx_xml_tools_package/geosx_xml_tools/unit_manager.py index 44360b0..4f04fa8 100644 --- a/geosx_xml_tools_package/geosx_xml_tools/unit_manager.py +++ b/geosx_xml_tools_package/geosx_xml_tools/unit_manager.py @@ -8,13 +8,13 @@ class UnitManager(): """This class is used to manage unit definitions.""" - def __init__(self) -> None: + def __init__( self ) -> None: """Initialize the class by creating an instance of the dict regex handler, building units.""" - self.units: Dict[str, str] = {} + self.units: Dict[ str, str ] = {} self.unitMatcher = regex_tools.DictRegexHandler() self.buildUnits() - def __call__(self, unitStruct: List[Any]) -> str: + def __call__( self, unitStruct: List[ Any ] ) -> str: """Evaluate the symbolic expression for matched strings. Args: @@ -25,20 +25,20 @@ def __call__(self, unitStruct: List[Any]) -> str: """ # Replace all instances of units in the string with their scale defined in self.units - symbolicUnits = re.sub(regex_tools.patterns['units_b'], self.unitMatcher, unitStruct[1]) + symbolicUnits = re.sub( regex_tools.patterns[ 'units_b' ], self.unitMatcher, unitStruct[ 1 ] ) # Strip out any undesired characters and evaluate # Note: the only allowed alpha characters are e and E. This could be relaxed to allow # functions such as sin, cos, etc. - symbolicUnits_sanitized = re.sub(regex_tools.patterns['sanitize'], '', symbolicUnits).strip() - value = float(unitStruct[0]) * eval(symbolicUnits_sanitized, {'__builtins__': None}) + symbolicUnits_sanitized = re.sub( regex_tools.patterns[ 'sanitize' ], '', symbolicUnits ).strip() + value = float( unitStruct[ 0 ] ) * eval( symbolicUnits_sanitized, { '__builtins__': None } ) # Format the string, removing any trailing zeros, decimals, extraneous exponential formats - str_value = re.sub(regex_tools.patterns['strip_trailing'], '', regex_tools.symbolic_format % (value)) - str_value = re.sub(regex_tools.patterns['strip_trailing_b'], '', str_value) + str_value = re.sub( regex_tools.patterns[ 'strip_trailing' ], '', regex_tools.symbolic_format % ( value ) ) + str_value = re.sub( regex_tools.patterns[ 'strip_trailing_b' ], '', str_value ) return str_value - def regexHandler(self, match: re.Match) -> str: + def regexHandler( self, match: re.Match ) -> str: """Split the matched string into a scale and unit definition. Args: @@ -49,9 +49,9 @@ def regexHandler(self, match: re.Match) -> str: """ # The first matched group includes the scale of the value (e.g. 1.234) # The second matches the string inside the unit definition (e.g. m/s**2) - return self.__call__([match.group(1), match.group(2)]) + return self.__call__( [ match.group( 1 ), match.group( 2 ) ] ) - def buildUnits(self) -> None: + def buildUnits( self ) -> None: """Build the unit definitions.""" # yapf: disable @@ -117,35 +117,35 @@ def buildUnits(self) -> None: # yapf: enable # Combine the unit dicts - unit_defs.update(imp_defs) - unit_defs.update(other_defs) + unit_defs.update( imp_defs ) + unit_defs.update( other_defs ) # Use brute-force to generate a list of potential units, rather than trying to parse # unit strings on the fly. This is still quite fast, and allows us to do simple # checks for overlapping definitions # Expand prefix and alternate names - for p in list(prefixes.keys()): - if prefixes[p]['alt']: - prefixes[prefixes[p]['alt']] = {'value': prefixes[p]['value']} - for u in list(unit_defs.keys()): - for alt in unit_defs[u]['alt']: - unit_defs[alt] = {'value': unit_defs[u]['value'], 'usePrefix': unit_defs[u]['usePrefix']} + for p in list( prefixes.keys() ): + if prefixes[ p ][ 'alt' ]: + prefixes[ prefixes[ p ][ 'alt' ] ] = { 'value': prefixes[ p ][ 'value' ] } + for u in list( unit_defs.keys() ): + for alt in unit_defs[ u ][ 'alt' ]: + unit_defs[ alt ] = { 'value': unit_defs[ u ][ 'value' ], 'usePrefix': unit_defs[ u ][ 'usePrefix' ] } # Combine the results into the final dictionary for u in unit_defs.keys(): - if (unit_defs[u]['usePrefix']): + if ( unit_defs[ u ][ 'usePrefix' ] ): for p in prefixes.keys(): - self.units[p + u] = prefixes[p]['value'] * unit_defs[u]['value'] + self.units[ p + u ] = prefixes[ p ][ 'value' ] * unit_defs[ u ][ 'value' ] else: - self.units[u] = unit_defs[u]['value'] + self.units[ u ] = unit_defs[ u ][ 'value' ] # Test to make sure that there are no overlapping unit definitions from collections import Counter - tmp = list(self.units.keys()) - duplicates = [k for k, v in Counter(tmp).items() if v > 1] - if (duplicates): - print(duplicates) - raise Exception('Error: There are overlapping unit definitions in the UnitManager') + tmp = list( self.units.keys() ) + duplicates = [ k for k, v in Counter( tmp ).items() if v > 1 ] + if ( duplicates ): + print( duplicates ) + raise Exception( 'Error: There are overlapping unit definitions in the UnitManager' ) self.unitMatcher.target = self.units diff --git a/geosx_xml_tools_package/geosx_xml_tools/xml_formatter.py b/geosx_xml_tools_package/geosx_xml_tools/xml_formatter.py index eb745ab..4695e65 100644 --- a/geosx_xml_tools_package/geosx_xml_tools/xml_formatter.py +++ b/geosx_xml_tools_package/geosx_xml_tools/xml_formatter.py @@ -1,11 +1,11 @@ import os -from lxml import etree as ElementTree # type: ignore[import] +from lxml import etree as ElementTree # type: ignore[import] import re from typing import List, Any, TextIO from geosx_xml_tools import command_line_parsers -def format_attribute(attribute_indent: str, ka: str, attribute_value: str) -> str: +def format_attribute( attribute_indent: str, ka: str, attribute_value: str ) -> str: """Format xml attribute strings Args: @@ -17,37 +17,37 @@ def format_attribute(attribute_indent: str, ka: str, attribute_value: str) -> st str: Formatted attribute value """ # Make sure that a space follows commas - attribute_value = re.sub(r",\s*", ", ", attribute_value) + attribute_value = re.sub( r",\s*", ", ", attribute_value ) # Handle external brackets - attribute_value = re.sub(r"{\s*", "{ ", attribute_value) - attribute_value = re.sub(r"\s*}", " }", attribute_value) + attribute_value = re.sub( r"{\s*", "{ ", attribute_value ) + attribute_value = re.sub( r"\s*}", " }", attribute_value ) # Consolidate whitespace - attribute_value = re.sub(r"\s+", " ", attribute_value) + attribute_value = re.sub( r"\s+", " ", attribute_value ) # Identify and split multi-line attributes - if re.match(r"\s*{\s*({[-+.,0-9a-zA-Z\s]*},?\s*)*\s*}", attribute_value): - split_positions: List[Any] = [match.end() for match in re.finditer(r"}\s*,", attribute_value)] - newline_indent = '\n%s' % (' ' * (len(attribute_indent) + len(ka) + 4)) + if re.match( r"\s*{\s*({[-+.,0-9a-zA-Z\s]*},?\s*)*\s*}", attribute_value ): + split_positions: List[ Any ] = [ match.end() for match in re.finditer( r"}\s*,", attribute_value ) ] + newline_indent = '\n%s' % ( ' ' * ( len( attribute_indent ) + len( ka ) + 4 ) ) new_values = [] - for a, b in zip([None] + split_positions, split_positions + [None]): - new_values.append(attribute_value[a:b].strip()) + for a, b in zip( [ None ] + split_positions, split_positions + [ None ] ): + new_values.append( attribute_value[ a:b ].strip() ) if new_values: - attribute_value = newline_indent.join(new_values) + attribute_value = newline_indent.join( new_values ) return attribute_value -def format_xml_level(output: TextIO, - node: ElementTree.Element, - level: int, - indent: str = ' ' * 2, - block_separation_max_depth: int = 2, - modify_attribute_indent: bool = False, - sort_attributes: bool = False, - close_tag_newline: bool = False, - include_namespace: bool = False) -> None: +def format_xml_level( output: TextIO, + node: ElementTree.Element, + level: int, + indent: str = ' ' * 2, + block_separation_max_depth: int = 2, + modify_attribute_indent: bool = False, + sort_attributes: bool = False, + close_tag_newline: bool = False, + include_namespace: bool = False ) -> None: """Iteratively format the xml file Args: @@ -64,76 +64,77 @@ def format_xml_level(output: TextIO, # Handle comments if node.tag is ElementTree.Comment: - output.write('\n%s' % (indent * level, node.text)) + output.write( '\n%s' % ( indent * level, node.text ) ) else: # Write opening line - opening_line = '\n%s<%s' % (indent * level, node.tag) - output.write(opening_line) + opening_line = '\n%s<%s' % ( indent * level, node.tag ) + output.write( opening_line ) # Write attributes - if (len(node.attrib) > 0): + if ( len( node.attrib ) > 0 ): # Choose indentation - attribute_indent = '%s' % (indent * (level + 1)) + attribute_indent = '%s' % ( indent * ( level + 1 ) ) if modify_attribute_indent: - attribute_indent = ' ' * (len(opening_line)) + attribute_indent = ' ' * ( len( opening_line ) ) # Get a copy of the attributes attribute_dict = {} - if ((level == 0) & include_namespace): + if ( ( level == 0 ) & include_namespace ): # Handle the optional namespace information at the root level # Note: preferably, this would point to a schema we host online - attribute_dict['xmlns:xsi'] = 'http://www.w3.org/2001/XMLSchema-instance' - attribute_dict['xsi:noNamespaceSchemaLocation'] = '/usr/gapps/GEOS/schema/schema.xsd' - elif (level > 0): + attribute_dict[ 'xmlns:xsi' ] = 'http://www.w3.org/2001/XMLSchema-instance' + attribute_dict[ 'xsi:noNamespaceSchemaLocation' ] = '/usr/gapps/GEOS/schema/schema.xsd' + elif ( level > 0 ): attribute_dict = node.attrib # Sort attribute names - akeys = list(attribute_dict.keys()) + akeys = list( attribute_dict.keys() ) if sort_attributes: - akeys = sorted(akeys) + akeys = sorted( akeys ) # Format attributes for ka in akeys: # Avoid formatting mathpresso expressions - if not (node.tag in ["SymbolicFunction", "CompositeFunction"] and ka == "expression"): - attribute_dict[ka] = format_attribute(attribute_indent, ka, attribute_dict[ka]) + if not ( node.tag in [ "SymbolicFunction", "CompositeFunction" ] and ka == "expression" ): + attribute_dict[ ka ] = format_attribute( attribute_indent, ka, attribute_dict[ ka ] ) - for ii in range(0, len(akeys)): - k = akeys[ii] - if ((ii == 0) & modify_attribute_indent): - output.write(' %s=\"%s\"' % (k, attribute_dict[k])) + for ii in range( 0, len( akeys ) ): + k = akeys[ ii ] + if ( ( ii == 0 ) & modify_attribute_indent ): + output.write( ' %s=\"%s\"' % ( k, attribute_dict[ k ] ) ) else: - output.write('\n%s%s=\"%s\"' % (attribute_indent, k, attribute_dict[k])) + output.write( '\n%s%s=\"%s\"' % ( attribute_indent, k, attribute_dict[ k ] ) ) # Write children - if len(node): - output.write('>') - Nc = len(node) - for ii, child in zip(range(Nc), node): - format_xml_level(output, child, level + 1, indent, block_separation_max_depth, modify_attribute_indent, - sort_attributes, close_tag_newline, include_namespace) + if len( node ): + output.write( '>' ) + Nc = len( node ) + for ii, child in zip( range( Nc ), node ): + format_xml_level( output, child, level + 1, indent, block_separation_max_depth, modify_attribute_indent, + sort_attributes, close_tag_newline, include_namespace ) # Add space between blocks - if ((level < block_separation_max_depth) & (ii < Nc - 1) & (child.tag is not ElementTree.Comment)): - output.write('\n') + if ( ( level < block_separation_max_depth ) & ( ii < Nc - 1 ) & + ( child.tag is not ElementTree.Comment ) ): + output.write( '\n' ) # Write the end tag - output.write('\n%s' % (indent * level, node.tag)) + output.write( '\n%s' % ( indent * level, node.tag ) ) else: if close_tag_newline: - output.write('\n%s/>' % (indent * level)) + output.write( '\n%s/>' % ( indent * level ) ) else: - output.write('/>') + output.write( '/>' ) -def format_file(input_fname: str, - indent_size: int = 2, - indent_style: bool = False, - block_separation_max_depth: int = 2, - alphebitize_attributes: bool = False, - close_style: bool = False, - namespace: bool = False) -> None: +def format_file( input_fname: str, + indent_size: int = 2, + indent_style: bool = False, + block_separation_max_depth: int = 2, + alphebitize_attributes: bool = False, + close_style: bool = False, + namespace: bool = False ) -> None: """Script to format xml files Args: @@ -145,37 +146,37 @@ def format_file(input_fname: str, close_style (bool): Style of close tag (0=same line, 1=new line) namespace (bool): Insert this namespace in the xml description """ - fname = os.path.expanduser(input_fname) + fname = os.path.expanduser( input_fname ) try: - tree = ElementTree.parse(fname) + tree = ElementTree.parse( fname ) root = tree.getroot() - prologue_comments = [tmp.text for tmp in root.itersiblings(preceding=True)] - epilog_comments = [tmp.text for tmp in root.itersiblings()] + prologue_comments = [ tmp.text for tmp in root.itersiblings( preceding=True ) ] + epilog_comments = [ tmp.text for tmp in root.itersiblings() ] - with open(fname, 'w') as f: - f.write('\n') + with open( fname, 'w' ) as f: + f.write( '\n' ) - for comment in reversed(prologue_comments): - f.write('\n' % (comment)) + for comment in reversed( prologue_comments ): + f.write( '\n' % ( comment ) ) - format_xml_level(f, - root, - 0, - indent=' ' * indent_size, - block_separation_max_depth=block_separation_max_depth, - modify_attribute_indent=indent_style, - sort_attributes=alphebitize_attributes, - close_tag_newline=close_style, - include_namespace=namespace) + format_xml_level( f, + root, + 0, + indent=' ' * indent_size, + block_separation_max_depth=block_separation_max_depth, + modify_attribute_indent=indent_style, + sort_attributes=alphebitize_attributes, + close_tag_newline=close_style, + include_namespace=namespace ) for comment in epilog_comments: - f.write('\n' % (comment)) - f.write('\n') + f.write( '\n' % ( comment ) ) + f.write( '\n' ) except ElementTree.ParseError as err: - print('\nCould not load file: %s' % (fname)) - print(err.msg) - raise Exception('\nCheck input file!') + print( '\nCould not load file: %s' % ( fname ) ) + print( err.msg ) + raise Exception( '\nCheck input file!' ) def main() -> None: @@ -192,13 +193,13 @@ def main() -> None: """ parser = command_line_parsers.build_xml_formatter_input_parser() args = parser.parse_args() - format_file(args.input, - indent_size=args.indent, - indent_style=args.style, - block_separation_max_depth=args.depth, - alphebitize_attributes=args.alphebitize, - close_style=args.close, - namespace=args.namespace) + format_file( args.input, + indent_size=args.indent, + indent_style=args.style, + block_separation_max_depth=args.depth, + alphebitize_attributes=args.alphebitize, + close_style=args.close, + namespace=args.namespace ) if __name__ == "__main__": diff --git a/geosx_xml_tools_package/geosx_xml_tools/xml_processor.py b/geosx_xml_tools_package/geosx_xml_tools/xml_processor.py index 757d025..d5a34c3 100644 --- a/geosx_xml_tools_package/geosx_xml_tools/xml_processor.py +++ b/geosx_xml_tools_package/geosx_xml_tools/xml_processor.py @@ -1,7 +1,7 @@ """Tools for processing xml files in GEOSX""" -from lxml import etree as ElementTree # type: ignore[import] -from lxml.etree import XMLSyntaxError # type: ignore[import] +from lxml import etree as ElementTree # type: ignore[import] +from lxml.etree import XMLSyntaxError # type: ignore[import] import re import os from geosx_xml_tools import regex_tools, unit_manager @@ -13,7 +13,7 @@ parameterHandler = regex_tools.DictRegexHandler() -def merge_xml_nodes(existingNode: ElementTree.Element, targetNode: ElementTree.Element, level: int) -> None: +def merge_xml_nodes( existingNode: ElementTree.Element, targetNode: ElementTree.Element, level: int ) -> None: """Merge nodes in an included file into the current structure level by level. Args: @@ -24,7 +24,7 @@ def merge_xml_nodes(existingNode: ElementTree.Element, targetNode: ElementTree.E # Copy attributes on the current level for tk in targetNode.attrib.keys(): - existingNode.set(tk, targetNode.get(tk)) + existingNode.set( tk, targetNode.get( tk ) ) # Copy target children into the xml structure currentTag = '' @@ -35,32 +35,32 @@ def merge_xml_nodes(existingNode: ElementTree.Element, targetNode: ElementTree.E # Check to see if a node with the appropriate type # exists at this level - if (currentTag != target.tag): + if ( currentTag != target.tag ): currentTag = target.tag - matchingSubNodes = existingNode.findall(target.tag) + matchingSubNodes = existingNode.findall( target.tag ) - if (matchingSubNodes): - targetName = target.get('name') + if ( matchingSubNodes ): + targetName = target.get( 'name' ) # Special case for the root Problem node (which may be unnamed) - if (level == 0): + if ( level == 0 ): insertCurrentLevel = False - merge_xml_nodes(matchingSubNodes[0], target, level + 1) + merge_xml_nodes( matchingSubNodes[ 0 ], target, level + 1 ) # Handle named xml nodes - elif (targetName and (currentTag not in ['Nodeset'])): + elif ( targetName and ( currentTag not in [ 'Nodeset' ] ) ): for match in matchingSubNodes: - if (match.get('name') == targetName): + if ( match.get( 'name' ) == targetName ): insertCurrentLevel = False - merge_xml_nodes(match, target, level + 1) + merge_xml_nodes( match, target, level + 1 ) # Insert any unnamed nodes or named nodes that aren't present # in the current xml structure - if (insertCurrentLevel): - existingNode.insert(-1, target) + if ( insertCurrentLevel ): + existingNode.insert( -1, target ) -def merge_included_xml_files(root: ElementTree.Element, fname: str, includeCount: int, maxInclude: int = 100) -> None: +def merge_included_xml_files( root: ElementTree.Element, fname: str, includeCount: int, maxInclude: int = 100 ) -> None: """Recursively merge included files into the current structure. Args: @@ -72,40 +72,40 @@ def merge_included_xml_files(root: ElementTree.Element, fname: str, includeCount # Expand the input path pwd = os.getcwd() - includePath, fname = os.path.split(os.path.abspath(os.path.expanduser(fname))) - os.chdir(includePath) + includePath, fname = os.path.split( os.path.abspath( os.path.expanduser( fname ) ) ) + os.chdir( includePath ) # Check to see if the code has fallen into a loop includeCount += 1 - if (includeCount > maxInclude): - raise Exception('Reached maximum recursive includes... Is there an include loop?') + if ( includeCount > maxInclude ): + raise Exception( 'Reached maximum recursive includes... Is there an include loop?' ) # Check to make sure the file exists - if (not os.path.isfile(fname)): - print('Included file does not exist: %s' % (fname)) - raise Exception('Check included file path!') + if ( not os.path.isfile( fname ) ): + print( 'Included file does not exist: %s' % ( fname ) ) + raise Exception( 'Check included file path!' ) # Load target xml try: - parser = ElementTree.XMLParser(remove_comments=True, remove_blank_text=True) - includeTree = ElementTree.parse(fname, parser) + parser = ElementTree.XMLParser( remove_comments=True, remove_blank_text=True ) + includeTree = ElementTree.parse( fname, parser ) includeRoot = includeTree.getroot() except XMLSyntaxError as err: - print('\nCould not load included file: %s' % (fname)) - print(err.msg) - raise Exception('\nCheck included file!') + print( '\nCould not load included file: %s' % ( fname ) ) + print( err.msg ) + raise Exception( '\nCheck included file!' ) # Recursively add the includes: - for includeNode in includeRoot.findall('Included'): - for f in includeNode.findall('File'): - merge_included_xml_files(root, f.get('name'), includeCount) + for includeNode in includeRoot.findall( 'Included' ): + for f in includeNode.findall( 'File' ): + merge_included_xml_files( root, f.get( 'name' ), includeCount ) # Merge the results into the xml tree - merge_xml_nodes(root, includeRoot, 0) - os.chdir(pwd) + merge_xml_nodes( root, includeRoot, 0 ) + os.chdir( pwd ) -def apply_regex_to_node(node: ElementTree.Element) -> None: +def apply_regex_to_node( node: ElementTree.Element ) -> None: """Apply regexes that handle parameters, units, and symbolic math to each xml attribute in the structure. @@ -114,35 +114,35 @@ def apply_regex_to_node(node: ElementTree.Element) -> None: """ for k in node.attrib.keys(): - value = node.get(k) + value = node.get( k ) # Parameter format: $Parameter or $:Parameter ii = 0 - while ('$' in value): - value = re.sub(regex_tools.patterns['parameters'], parameterHandler, value) + while ( '$' in value ): + value = re.sub( regex_tools.patterns[ 'parameters' ], parameterHandler, value ) ii += 1 - if (ii > 100): - raise Exception('Reached maximum parameter expands (Node=%s, value=%s)' % (node.tag, value)) + if ( ii > 100 ): + raise Exception( 'Reached maximum parameter expands (Node=%s, value=%s)' % ( node.tag, value ) ) # Unit format: 9.81[m**2/s] or 1.0 [bbl/day] - if ('[' in value): - value = re.sub(regex_tools.patterns['units'], unitManager.regexHandler, value) + if ( '[' in value ): + value = re.sub( regex_tools.patterns[ 'units' ], unitManager.regexHandler, value ) # Symbolic format: `1 + 2.34e5*2 * ...` ii = 0 - while ('`' in value): - value = re.sub(regex_tools.patterns['symbolic'], regex_tools.SymbolicMathRegexHandler, value) + while ( '`' in value ): + value = re.sub( regex_tools.patterns[ 'symbolic' ], regex_tools.SymbolicMathRegexHandler, value ) ii += 1 - if (ii > 100): - raise Exception('Reached maximum symbolic expands (Node=%s, value=%s)' % (node.tag, value)) + if ( ii > 100 ): + raise Exception( 'Reached maximum symbolic expands (Node=%s, value=%s)' % ( node.tag, value ) ) - node.set(k, value) + node.set( k, value ) for subNode in node.getchildren(): - apply_regex_to_node(subNode) + apply_regex_to_node( subNode ) -def generate_random_name(prefix: str = '', suffix: str = '.xml') -> str: +def generate_random_name( prefix: str = '', suffix: str = '.xml' ) -> str: """If the target name is not specified, generate a random name for the compiled xml Args: @@ -156,17 +156,17 @@ def generate_random_name(prefix: str = '', suffix: str = '.xml') -> str: from time import time from os import getpid - tmp = str(time()) + str(getpid()) - return '%s%s%s' % (prefix, md5(tmp.encode('utf-8')).hexdigest(), suffix) + tmp = str( time() ) + str( getpid() ) + return '%s%s%s' % ( prefix, md5( tmp.encode( 'utf-8' ) ).hexdigest(), suffix ) -def process(inputFiles: Iterable[str], - outputFile: str = '', - schema: str = '', - verbose: int = 0, - parameter_override: List[Tuple[str, str]] = [], - keep_parameters: bool = True, - keep_includes: bool = True) -> str: +def process( inputFiles: Iterable[ str ], + outputFile: str = '', + schema: str = '', + verbose: int = 0, + parameter_override: List[ Tuple[ str, str ] ] = [], + keep_parameters: bool = True, + keep_includes: bool = True ) -> str: """Process an xml file Args: @@ -182,120 +182,120 @@ def process(inputFiles: Iterable[str], str: Output file name """ if verbose: - print('\nReading input xml parameters and parsing symbolic math...') + print( '\nReading input xml parameters and parsing symbolic math...' ) # Check the type of inputFiles - if isinstance(inputFiles, str): - inputFiles = [inputFiles] + if isinstance( inputFiles, str ): + inputFiles = [ inputFiles ] # Expand the input path pwd = os.getcwd() - expanded_files = [os.path.abspath(os.path.expanduser(f)) for f in inputFiles] - single_path, single_input = os.path.split(expanded_files[0]) - os.chdir(single_path) + expanded_files = [ os.path.abspath( os.path.expanduser( f ) ) for f in inputFiles ] + single_path, single_input = os.path.split( expanded_files[ 0 ] ) + os.chdir( single_path ) # Handle single vs. multiple command line inputs - root = ElementTree.Element("Problem") + root = ElementTree.Element( "Problem" ) tree = ElementTree.ElementTree() - if (len(expanded_files) == 1): + if ( len( expanded_files ) == 1 ): # Load single files directly try: - parser = ElementTree.XMLParser(remove_comments=True, remove_blank_text=True) - tree = ElementTree.parse(single_input, parser=parser) + parser = ElementTree.XMLParser( remove_comments=True, remove_blank_text=True ) + tree = ElementTree.parse( single_input, parser=parser ) root = tree.getroot() except XMLSyntaxError as err: - print('\nCould not load input file: %s' % (single_input)) - print(err.msg) - raise Exception('\nCheck input file!') + print( '\nCould not load input file: %s' % ( single_input ) ) + print( err.msg ) + raise Exception( '\nCheck input file!' ) else: # For multiple inputs, create a simple xml structure to hold # the included files. These will be saved as comments in the compiled file - root = ElementTree.Element('Problem') - tree = ElementTree.ElementTree(root) - included_node = ElementTree.Element("Included") - root.append(included_node) + root = ElementTree.Element( 'Problem' ) + tree = ElementTree.ElementTree( root ) + included_node = ElementTree.Element( "Included" ) + root.append( included_node ) for f in expanded_files: - included_file = ElementTree.Element("File") - included_file.set('name', f) - included_node.append(included_file) + included_file = ElementTree.Element( "File" ) + included_file.set( 'name', f ) + included_node.append( included_file ) # Add the included files to the xml structure # Note: doing this first assumes that parameters aren't used in Included block includeCount = 0 - for includeNode in root.findall('Included'): - for f in includeNode.findall('File'): - merge_included_xml_files(root, f.get('name'), includeCount) # type: ignore[attr-defined] - os.chdir(pwd) + for includeNode in root.findall( 'Included' ): + for f in includeNode.findall( 'File' ): + merge_included_xml_files( root, f.get( 'name' ), includeCount ) # type: ignore[attr-defined] + os.chdir( pwd ) # Build the parameter map Pmap = {} - for parameters in root.findall('Parameters'): - for p in parameters.findall('Parameter'): - Pmap[p.get('name')] = p.get('value') + for parameters in root.findall( 'Parameters' ): + for p in parameters.findall( 'Parameter' ): + Pmap[ p.get( 'name' ) ] = p.get( 'value' ) # Apply any parameter overrides - if len(parameter_override): + if len( parameter_override ): # Save overriden values to a new xml element - command_override_node = ElementTree.Element("CommandLineOverride") - root.append(command_override_node) - for ii in range(len(parameter_override)): - pname = parameter_override[ii][0] - pval = ' '.join(parameter_override[ii][1:]) - Pmap[pname] = pval - override_parameter = ElementTree.Element("Parameter") - override_parameter.set('name', pname) - override_parameter.set('value', pval) - command_override_node.append(override_parameter) + command_override_node = ElementTree.Element( "CommandLineOverride" ) + root.append( command_override_node ) + for ii in range( len( parameter_override ) ): + pname = parameter_override[ ii ][ 0 ] + pval = ' '.join( parameter_override[ ii ][ 1: ] ) + Pmap[ pname ] = pval + override_parameter = ElementTree.Element( "Parameter" ) + override_parameter.set( 'name', pname ) + override_parameter.set( 'value', pval ) + command_override_node.append( override_parameter ) # Add the parameter map to the handler parameterHandler.target = Pmap # Process any parameters, units, and symbolic math in the xml - apply_regex_to_node(root) + apply_regex_to_node( root ) # Comment out or remove the Parameter, Included nodes - for includeNode in root.findall('Included'): + for includeNode in root.findall( 'Included' ): if keep_includes: - root.insert(-1, ElementTree.Comment(ElementTree.tostring(includeNode))) - root.remove(includeNode) - for parameterNode in root.findall('Parameters'): + root.insert( -1, ElementTree.Comment( ElementTree.tostring( includeNode ) ) ) + root.remove( includeNode ) + for parameterNode in root.findall( 'Parameters' ): if keep_parameters: - root.insert(-1, ElementTree.Comment(ElementTree.tostring(parameterNode))) - root.remove(parameterNode) - for overrideNode in root.findall('CommandLineOverride'): + root.insert( -1, ElementTree.Comment( ElementTree.tostring( parameterNode ) ) ) + root.remove( parameterNode ) + for overrideNode in root.findall( 'CommandLineOverride' ): if keep_parameters: - root.insert(-1, ElementTree.Comment(ElementTree.tostring(overrideNode))) - root.remove(overrideNode) + root.insert( -1, ElementTree.Comment( ElementTree.tostring( overrideNode ) ) ) + root.remove( overrideNode ) # Generate a random output name if not specified if not outputFile: - outputFile = generate_random_name(prefix='prep_') + outputFile = generate_random_name( prefix='prep_' ) # Write the output file - tree.write(outputFile, pretty_print=True) + tree.write( outputFile, pretty_print=True ) # Check for un-matched special characters - with open(outputFile, 'r') as ofile: + with open( outputFile, 'r' ) as ofile: for line in ofile: - if any([sc in line for sc in ['$', '[', ']', '`']]): + if any( [ sc in line for sc in [ '$', '[', ']', '`' ] ] ): raise Exception( 'Found un-matched special characters in the pre-processed input file on line:\n%s\n Check your input xml for errors!' - % (line)) + % ( line ) ) # Apply formatting to the file - xml_formatter.format_file(outputFile) + xml_formatter.format_file( outputFile ) if verbose: - print('Preprocessed xml file stored in %s' % (outputFile)) + print( 'Preprocessed xml file stored in %s' % ( outputFile ) ) if schema: - validate_xml(outputFile, schema, verbose) + validate_xml( outputFile, schema, verbose ) return outputFile -def validate_xml(fname: str, schema: str, verbose: int) -> None: +def validate_xml( fname: str, schema: str, verbose: int ) -> None: """Validate an xml file, and parse the warnings. Args: @@ -304,18 +304,18 @@ def validate_xml(fname: str, schema: str, verbose: int) -> None: verbose (int): Verbosity level. """ if verbose: - print('Validating the xml against the schema...') + print( 'Validating the xml against the schema...' ) try: - ofile = ElementTree.parse(fname) - sfile = ElementTree.XMLSchema(ElementTree.parse(os.path.expanduser(schema))) - sfile.assertValid(ofile) + ofile = ElementTree.parse( fname ) + sfile = ElementTree.XMLSchema( ElementTree.parse( os.path.expanduser( schema ) ) ) + sfile.assertValid( ofile ) except ElementTree.DocumentInvalid as err: - print(err) - print('\nWarning: input XML contains potentially invalid input parameters:') - print('-' * 20 + '\n') - print(sfile.error_log) - print('\n' + '-' * 20) - print('(Total schema warnings: %i)\n' % (len(sfile.error_log))) + print( err ) + print( '\nWarning: input XML contains potentially invalid input parameters:' ) + print( '-' * 20 + '\n' ) + print( sfile.error_log ) + print( '\n' + '-' * 20 ) + print( '(Total schema warnings: %i)\n' % ( len( sfile.error_log ) ) ) if verbose: - print('Done!') + print( 'Done!' ) diff --git a/geosx_xml_tools_package/geosx_xml_tools/xml_redundancy_check.py b/geosx_xml_tools_package/geosx_xml_tools/xml_redundancy_check.py index 251e1df..f970a6d 100644 --- a/geosx_xml_tools_package/geosx_xml_tools/xml_redundancy_check.py +++ b/geosx_xml_tools_package/geosx_xml_tools/xml_redundancy_check.py @@ -1,15 +1,15 @@ from geosx_xml_tools.attribute_coverage import parse_schema from geosx_xml_tools.xml_formatter import format_file -from lxml import etree as ElementTree # type: ignore[import] +from lxml import etree as ElementTree # type: ignore[import] import os from pathlib import Path from geosx_xml_tools import command_line_parsers from typing import Iterable, Dict, Any -def check_redundancy_level(local_schema: Dict[str, Any], - node: ElementTree.Element, - whitelist: Iterable[str] = ['component']) -> int: +def check_redundancy_level( local_schema: Dict[ str, Any ], + node: ElementTree.Element, + whitelist: Iterable[ str ] = [ 'component' ] ) -> int: """Check xml redundancy at the current level Args: @@ -22,43 +22,43 @@ def check_redundancy_level(local_schema: Dict[str, Any], """ node_is_required = 0 for ka in node.attrib.keys(): - if (ka in whitelist): + if ( ka in whitelist ): node_is_required += 1 - elif (ka not in local_schema['attributes']): + elif ( ka not in local_schema[ 'attributes' ] ): node_is_required += 1 - elif ('default' not in local_schema['attributes'][ka]): + elif ( 'default' not in local_schema[ 'attributes' ][ ka ] ): node_is_required += 1 - elif (node.get(ka) != local_schema['attributes'][ka]['default']): + elif ( node.get( ka ) != local_schema[ 'attributes' ][ ka ][ 'default' ] ): node_is_required += 1 else: - node.attrib.pop(ka) + node.attrib.pop( ka ) for child in node: # Comments will not appear in the schema - if child.tag in local_schema['children']: - child_is_required = check_redundancy_level(local_schema['children'][child.tag], child) + if child.tag in local_schema[ 'children' ]: + child_is_required = check_redundancy_level( local_schema[ 'children' ][ child.tag ], child ) node_is_required += child_is_required if not child_is_required: - node.remove(child) + node.remove( child ) return node_is_required -def check_xml_redundancy(schema: Dict[str, Any], fname: str) -> None: +def check_xml_redundancy( schema: Dict[ str, Any ], fname: str ) -> None: """Check redundancy in an xml file Args: schema (dict): Schema definitions fname (str): Name of the target file """ - xml_tree = ElementTree.parse(fname) + xml_tree = ElementTree.parse( fname ) xml_root = xml_tree.getroot() - check_redundancy_level(schema['Problem'], xml_root) - xml_tree.write(fname) - format_file(fname) + check_redundancy_level( schema[ 'Problem' ], xml_root ) + xml_tree.write( fname ) + format_file( fname ) -def process_xml_files(geosx_root: str) -> None: +def process_xml_files( geosx_root: str ) -> None: """Test for xml redundancy Args: @@ -66,17 +66,17 @@ def process_xml_files(geosx_root: str) -> None: """ # Parse the schema - geosx_root = os.path.expanduser(geosx_root) - schema_fname = '%ssrc/coreComponents/schema/schema.xsd' % (geosx_root) - schema = parse_schema(schema_fname) + geosx_root = os.path.expanduser( geosx_root ) + schema_fname = '%ssrc/coreComponents/schema/schema.xsd' % ( geosx_root ) + schema = parse_schema( schema_fname ) # Find all xml files, collect their attributes - for folder in ['src', 'examples']: - print(folder) - xml_files = Path(os.path.join(geosx_root, folder)).rglob('*.xml') + for folder in [ 'src', 'examples' ]: + print( folder ) + xml_files = Path( os.path.join( geosx_root, folder ) ).rglob( '*.xml' ) for f in xml_files: - print(' %s' % (str(f))) - check_xml_redundancy(schema, str(f)) + print( ' %s' % ( str( f ) ) ) + check_xml_redundancy( schema, str( f ) ) def main() -> None: @@ -91,7 +91,7 @@ def main() -> None: args = parser.parse_args() # Parse the xml files - process_xml_files(args.root) + process_xml_files( args.root ) if __name__ == "__main__": diff --git a/hdf5_wrapper_package/hdf5_wrapper/use_example.py b/hdf5_wrapper_package/hdf5_wrapper/use_example.py index 5bbdbe3..9243451 100644 --- a/hdf5_wrapper_package/hdf5_wrapper/use_example.py +++ b/hdf5_wrapper_package/hdf5_wrapper/use_example.py @@ -3,7 +3,7 @@ from typing import Union, Dict -def print_database_iterative(database: hdf5_wrapper.hdf5_wrapper, level: int = 0) -> None: +def print_database_iterative( database: hdf5_wrapper.hdf5_wrapper, level: int = 0 ) -> None: """ Print the database targets iteratively by level @@ -13,14 +13,14 @@ def print_database_iterative(database: hdf5_wrapper.hdf5_wrapper, level: int = 0 """ # Note: you can also iterate over the hdf5_wrapper object directly for k in database.keys(): - print('%s%s' % (' ' * level, k)) + print( '%s%s' % ( ' ' * level, k ) ) - if isinstance(database[k], hdf5_wrapper.hdf5_wrapper): + if isinstance( database[ k ], hdf5_wrapper.hdf5_wrapper ): # This is a group, so continue iterating downward - print_database_iterative(database[k], level + 1) + print_database_iterative( database[ k ], level + 1 ) else: # This is likely to be an array - print(database[k]) + print( database[ k ] ) print() @@ -32,19 +32,19 @@ def read_write_hdf5_database_example() -> None: # ------------------------ # Generate test data # ------------------------ - nested_dict_type = Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]] + nested_dict_type = Dict[ str, Union[ np.ndarray, Dict[ str, np.ndarray ] ] ] source_a: nested_dict_type = { - '1D_double_array': np.random.randn(10), - 'string_array': np.array(['a', 'list', 'of', 'strings']), + '1D_double_array': np.random.randn( 10 ), + 'string_array': np.array( [ 'a', 'list', 'of', 'strings' ] ), 'child_a': { - '2D_double_array': np.random.randn(2, 3) + '2D_double_array': np.random.randn( 2, 3 ) } } source_b: nested_dict_type = { - '1D_integer_array': np.random.randint(0, 100, 5), + '1D_integer_array': np.random.randint( 0, 100, 5 ), 'child_b': { - '3D_double_array': np.random.randn(4, 5, 2) + '3D_double_array': np.random.randn( 4, 5, 2 ) } } @@ -53,36 +53,36 @@ def read_write_hdf5_database_example() -> None: # ------------------------ # Write the first piece-by-piece to an hdf5_file # Note: when you exit the following scope, the database is automatically closed - with hdf5_wrapper.hdf5_wrapper('database_a.hdf5', mode='a') as database_a: + with hdf5_wrapper.hdf5_wrapper( 'database_a.hdf5', mode='a' ) as database_a: # Assign the two array objects to this level - database_a['1D_double_array'] = source_a['1D_double_array'] - database_a['string_array'] = source_a['string_array'] + database_a[ '1D_double_array' ] = source_a[ '1D_double_array' ] + database_a[ 'string_array' ] = source_a[ 'string_array' ] # Create a child group and assign the final array - child_a = database_a['child_a'] - child_a['2D_double_array'] = source_a['child_a']['2D_double_array'] + child_a = database_a[ 'child_a' ] + child_a[ '2D_double_array' ] = source_a[ 'child_a' ][ '2D_double_array' ] # Automatically write the second source to a second database - with hdf5_wrapper.hdf5_wrapper('database_b.hdf5', mode='a') as database_b: - database_b['/'] = source_b + with hdf5_wrapper.hdf5_wrapper( 'database_b.hdf5', mode='a' ) as database_b: + database_b[ '/' ] = source_b # Create a third database that links the either two - with hdf5_wrapper.hdf5_wrapper('database_c.hdf5', mode='a') as database_c: - database_c.link('database_a', 'database_a.hdf5') - database_c.link('database_b', 'database_b.hdf5') + with hdf5_wrapper.hdf5_wrapper( 'database_c.hdf5', mode='a' ) as database_c: + database_c.link( 'database_a', 'database_a.hdf5' ) + database_c.link( 'database_b', 'database_b.hdf5' ) # --------------------------------------- # Read the databases from the filesystem # --------------------------------------- - print('Database contents:') - with hdf5_wrapper.hdf5_wrapper('database_c.hdf5') as database_c: + print( 'Database contents:' ) + with hdf5_wrapper.hdf5_wrapper( 'database_c.hdf5' ) as database_c: # Iteratively print the database contents - print_database_iterative(database_c, 1) + print_database_iterative( database_c, 1 ) # As a final note, you can also access low-level h5py functionality # by interacting directly with the database target, e.g.: - print('Database attributes:') - print(' ', database_c.target.attrs) + print( 'Database attributes:' ) + print( ' ', database_c.target.attrs ) if __name__ == "__main__": diff --git a/hdf5_wrapper_package/hdf5_wrapper/wrapper.py b/hdf5_wrapper_package/hdf5_wrapper/wrapper.py index a8a11b8..0895c24 100644 --- a/hdf5_wrapper_package/hdf5_wrapper/wrapper.py +++ b/hdf5_wrapper_package/hdf5_wrapper/wrapper.py @@ -1,13 +1,13 @@ -import h5py # type: ignore[import] +import h5py # type: ignore[import] import numpy as np from numpy.core.defchararray import encode, decode from typing import Union, Dict, Any, Iterable, Optional, Tuple # Note: I would like to replace Any here with str, float, int, np.ndarray, etc. # However, this heterogeneous pattern causes issues with mypy indexing -hdf5_get_types = Union['hdf5_wrapper', Any] -nested_dict_type = Dict[str, Any] -hdf5_set_types = Union['hdf5_wrapper', nested_dict_type, Any] +hdf5_get_types = Union[ 'hdf5_wrapper', Any ] +nested_dict_type = Dict[ str, Any ] +hdf5_set_types = Union[ 'hdf5_wrapper', nested_dict_type, Any ] class hdf5_wrapper(): @@ -15,7 +15,7 @@ class hdf5_wrapper(): A class for reading/writing hdf5 files, which behaves similar to a native dict """ - def __init__(self, fname: str = '', target: Optional[h5py.File] = None, mode: str = 'r') -> None: + def __init__( self, fname: str = '', target: Optional[ h5py.File ] = None, mode: str = 'r' ) -> None: """ Initialize the hdf5_wrapper class @@ -37,9 +37,9 @@ def __init__(self, fname: str = '', target: Optional[h5py.File] = None, mode: st self.mode: str = mode self.target: h5py.File = target if fname: - self.target = h5py.File(fname, self.mode) + self.target = h5py.File( fname, self.mode ) - def __getitem__(self, k: str) -> hdf5_get_types: + def __getitem__( self, k: str ) -> hdf5_get_types: """ Get a target from the database @@ -53,32 +53,32 @@ def __getitem__(self, k: str) -> hdf5_get_types: Returns: hdf5_wrapper/np.ndarray: The returned value """ - if (k not in self.target): - if (self.mode in ['w', 'a']): - self.target.create_group(k) + if ( k not in self.target ): + if ( self.mode in [ 'w', 'a' ] ): + self.target.create_group( k ) else: - raise ValueError('Entry does not exist in database: %s' % (k)) + raise ValueError( 'Entry does not exist in database: %s' % ( k ) ) - tmp = self.target[k] + tmp = self.target[ k ] - if isinstance(tmp, h5py._hl.group.Group): - return hdf5_wrapper(target=tmp, mode=self.mode) - elif isinstance(tmp, h5py._hl.dataset.Dataset): - tmp = np.array(tmp) + if isinstance( tmp, h5py._hl.group.Group ): + return hdf5_wrapper( target=tmp, mode=self.mode ) + elif isinstance( tmp, h5py._hl.dataset.Dataset ): + tmp = np.array( tmp ) # Decode any string types - if (tmp.dtype.kind in ['S', 'U', 'O']): - tmp = decode(tmp) + if ( tmp.dtype.kind in [ 'S', 'U', 'O' ] ): + tmp = decode( tmp ) # Convert any 0-length arrays to native types if not tmp.shape: - tmp = tmp[()] + tmp = tmp[ () ] return tmp else: return tmp - def __setitem__(self, k: str, value: hdf5_set_types): + def __setitem__( self, k: str, value: hdf5_set_types ): """ Write an object to the database if write-mode is enabled @@ -86,30 +86,30 @@ def __setitem__(self, k: str, value: hdf5_set_types): k (str): the name of the object value (dict, np.ndarray, float, int, str): the object to be written """ - if (self.mode in ['w', 'a']): - if isinstance(value, (dict, hdf5_wrapper)): + if ( self.mode in [ 'w', 'a' ] ): + if isinstance( value, ( dict, hdf5_wrapper ) ): # Recursively add groups and their children - if (k not in self.target): - self.target.create_group(k) - new_group = self[k] + if ( k not in self.target ): + self.target.create_group( k ) + new_group = self[ k ] for kb, x in value.items(): - new_group[kb] = x + new_group[ kb ] = x else: # Delete the old copy if necessary - if (k in self.target): - del (self.target[k]) + if ( k in self.target ): + del ( self.target[ k ] ) # Add everything else as an ndarray - tmp = np.array(value) - if (tmp.dtype.kind in ['S', 'U', 'O']): - tmp = encode(tmp) - self.target[k] = tmp + tmp = np.array( value ) + if ( tmp.dtype.kind in [ 'S', 'U', 'O' ] ): + tmp = encode( tmp ) + self.target[ k ] = tmp else: raise ValueError( 'Cannot write to an hdf5 opened in read-only mode! This can be changed by overriding the default mode argument for the wrapper.' ) - def link(self, k: str, target: str) -> None: + def link( self, k: str, target: str ) -> None: """ Link an external hdf5 file to this location in the database @@ -117,76 +117,76 @@ def link(self, k: str, target: str) -> None: k (str): the name of the new link in the database target (str): the path to the external database """ - self.target[k] = h5py.ExternalLink(target, '/') + self.target[ k ] = h5py.ExternalLink( target, '/' ) - def keys(self) -> Iterable[str]: + def keys( self ) -> Iterable[ str ]: """ Get a list of groups and arrays located at the current level Returns: list: a list of key names pointing to objects at the current level """ - if isinstance(self.target, h5py._hl.group.Group): - return list(self.target) + if isinstance( self.target, h5py._hl.group.Group ): + return list( self.target ) else: - raise ValueError('Object not a group!') + raise ValueError( 'Object not a group!' ) - def values(self) -> Iterable[hdf5_get_types]: + def values( self ) -> Iterable[ hdf5_get_types ]: """ Get a list of values located on the current level """ - return [self[k] for k in self.keys()] + return [ self[ k ] for k in self.keys() ] - def items(self) -> Iterable[Tuple[str, hdf5_get_types]]: - return zip(self.keys(), self.values()) + def items( self ) -> Iterable[ Tuple[ str, hdf5_get_types ] ]: + return zip( self.keys(), self.values() ) - def __enter__(self): + def __enter__( self ): """ Entry point for an iterator """ return self - def __exit__(self, type, value, traceback) -> None: + def __exit__( self, type, value, traceback ) -> None: """ End point for an iterator """ self.target.close() - def __del__(self) -> None: + def __del__( self ) -> None: """ Closes the database on wrapper deletion """ try: - if isinstance(self.target, h5py._hl.files.File): + if isinstance( self.target, h5py._hl.files.File ): self.target.close() except: pass - def close(self) -> None: + def close( self ) -> None: """ Closes the database """ - if isinstance(self.target, h5py._hl.files.File): + if isinstance( self.target, h5py._hl.files.File ): self.target.close() - def get_copy(self) -> nested_dict_type: + def get_copy( self ) -> nested_dict_type: """ Copy the entire database into memory Returns: dict: a dictionary holding the database contents """ - result: Dict[Union[str, int], Any] = {} + result: Dict[ Union[ str, int ], Any ] = {} for k in self.keys(): - tmp = self[k] - if isinstance(tmp, hdf5_wrapper): - result[k] = tmp.get_copy() + tmp = self[ k ] + if isinstance( tmp, hdf5_wrapper ): + result[ k ] = tmp.get_copy() else: - result[k] = tmp + result[ k ] = tmp return result - def copy(self) -> nested_dict_type: + def copy( self ) -> nested_dict_type: """ Copy the entire database into memory @@ -195,7 +195,7 @@ def copy(self) -> nested_dict_type: """ return self.get_copy() - def insert(self, x: Union[nested_dict_type, 'hdf5_wrapper']) -> None: + def insert( self, x: Union[ nested_dict_type, 'hdf5_wrapper' ] ) -> None: """ Insert the contents of the target object to the current location @@ -203,4 +203,4 @@ def insert(self, x: Union[nested_dict_type, 'hdf5_wrapper']) -> None: x (dict, hdf5_wrapper): the dictionary to insert """ for k, v in x.items(): - self[k] = v + self[ k ] = v diff --git a/hdf5_wrapper_package/hdf5_wrapper/wrapper_tests.py b/hdf5_wrapper_package/hdf5_wrapper/wrapper_tests.py index 7e496ee..8f963df 100644 --- a/hdf5_wrapper_package/hdf5_wrapper/wrapper_tests.py +++ b/hdf5_wrapper_package/hdf5_wrapper/wrapper_tests.py @@ -7,100 +7,100 @@ import hdf5_wrapper -def random_string(N): - return ''.join(random.choices(string.ascii_uppercase + string.ascii_lowercase + string.digits, k=N)) +def random_string( N ): + return ''.join( random.choices( string.ascii_uppercase + string.ascii_lowercase + string.digits, k=N ) ) -def build_test_dict(depth=0, max_depth=3): - r = [np.random.randint(2, 20) for x in range(5)] +def build_test_dict( depth=0, max_depth=3 ): + r = [ np.random.randint( 2, 20 ) for x in range( 5 ) ] test = { - 'int': np.random.randint(-1000000, 1000000), + 'int': np.random.randint( -1000000, 1000000 ), 'float': np.random.random(), - '1d_array': np.random.randn(r[0]), - '3d_array': np.random.randn(r[1], r[2], r[3]), - 'string': random_string(10), - 'string_array': np.array([random_string(x + 10) for x in range(r[4])]) + '1d_array': np.random.randn( r[ 0 ] ), + '3d_array': np.random.randn( r[ 1 ], r[ 2 ], r[ 3 ] ), + 'string': random_string( 10 ), + 'string_array': np.array( [ random_string( x + 10 ) for x in range( r[ 4 ] ) ] ) } - if (depth < max_depth): - test['child_a'] = build_test_dict(depth + 1, max_depth) - test['child_b'] = build_test_dict(depth + 1, max_depth) - test['child_c'] = build_test_dict(depth + 1, max_depth) + if ( depth < max_depth ): + test[ 'child_a' ] = build_test_dict( depth + 1, max_depth ) + test[ 'child_b' ] = build_test_dict( depth + 1, max_depth ) + test[ 'child_c' ] = build_test_dict( depth + 1, max_depth ) return test # Test the unit manager definitions -class TestHDF5Wrapper(unittest.TestCase): +class TestHDF5Wrapper( unittest.TestCase ): @classmethod - def setUpClass(cls): + def setUpClass( cls ): cls.test_dir = 'wrapper_tests' - os.makedirs(cls.test_dir, exist_ok=True) + os.makedirs( cls.test_dir, exist_ok=True ) cls.test_dict = build_test_dict() - def compare_wrapper_dict(self, x, y): + def compare_wrapper_dict( self, x, y ): kx = x.keys() ky = y.keys() for k in kx: if k not in ky: - raise Exception('y key not in x object (%s)' % (k)) + raise Exception( 'y key not in x object (%s)' % ( k ) ) for k in ky: if k not in kx: - raise Exception('x key not in y object (%s)' % (k)) + raise Exception( 'x key not in y object (%s)' % ( k ) ) - vx, vy = x[k], y[k] - tx, ty = type(vx), type(vy) - if ((tx != ty) and not (isinstance(vx, (dict, hdf5_wrapper.hdf5_wrapper)) - and isinstance(vy, (dict, hdf5_wrapper.hdf5_wrapper)))): - self.assertTrue(np.issubdtype(tx, ty)) + vx, vy = x[ k ], y[ k ] + tx, ty = type( vx ), type( vy ) + if ( ( tx != ty ) and not ( isinstance( vx, ( dict, hdf5_wrapper.hdf5_wrapper ) ) + and isinstance( vy, ( dict, hdf5_wrapper.hdf5_wrapper ) ) ) ): + self.assertTrue( np.issubdtype( tx, ty ) ) - if isinstance(vx, (dict, hdf5_wrapper.hdf5_wrapper)): - self.compare_wrapper_dict(vx, vy) + if isinstance( vx, ( dict, hdf5_wrapper.hdf5_wrapper ) ): + self.compare_wrapper_dict( vx, vy ) else: - if isinstance(vx, np.ndarray): - self.assertTrue(np.shape(vx) == np.shape(vy)) - self.assertTrue((vx == vy).all()) + if isinstance( vx, np.ndarray ): + self.assertTrue( np.shape( vx ) == np.shape( vy ) ) + self.assertTrue( ( vx == vy ).all() ) else: - self.assertTrue(vx == vy) + self.assertTrue( vx == vy ) - def test_a_insert_write(self): - data = hdf5_wrapper.hdf5_wrapper(os.path.join(self.test_dir, 'test_insert.hdf5'), mode='w') - data.insert(self.test_dict) + def test_a_insert_write( self ): + data = hdf5_wrapper.hdf5_wrapper( os.path.join( self.test_dir, 'test_insert.hdf5' ), mode='w' ) + data.insert( self.test_dict ) - def test_b_manual_write(self): - data = hdf5_wrapper.hdf5_wrapper(os.path.join(self.test_dir, 'test_manual.hdf5'), mode='w') + def test_b_manual_write( self ): + data = hdf5_wrapper.hdf5_wrapper( os.path.join( self.test_dir, 'test_manual.hdf5' ), mode='w' ) for k, v in self.test_dict.items(): - data[k] = v + data[ k ] = v - def test_c_link_write(self): - data = hdf5_wrapper.hdf5_wrapper(os.path.join(self.test_dir, 'test_linked.hdf5'), mode='w') + def test_c_link_write( self ): + data = hdf5_wrapper.hdf5_wrapper( os.path.join( self.test_dir, 'test_linked.hdf5' ), mode='w' ) for k, v in self.test_dict.items(): - if ('child' in k): - child_path = os.path.join(self.test_dir, 'test_%s.hdf5' % (k)) - data_child = hdf5_wrapper.hdf5_wrapper(child_path, mode='w') - data_child.insert(v) - data.link(k, child_path) + if ( 'child' in k ): + child_path = os.path.join( self.test_dir, 'test_%s.hdf5' % ( k ) ) + data_child = hdf5_wrapper.hdf5_wrapper( child_path, mode='w' ) + data_child.insert( v ) + data.link( k, child_path ) else: - data[k] = v + data[ k ] = v - def test_d_compare_wrapper(self): - data = hdf5_wrapper.hdf5_wrapper(os.path.join(self.test_dir, 'test_insert.hdf5')) - self.compare_wrapper_dict(self.test_dict, data) + def test_d_compare_wrapper( self ): + data = hdf5_wrapper.hdf5_wrapper( os.path.join( self.test_dir, 'test_insert.hdf5' ) ) + self.compare_wrapper_dict( self.test_dict, data ) - def test_e_compare_wrapper_copy(self): - data = hdf5_wrapper.hdf5_wrapper(os.path.join(self.test_dir, 'test_insert.hdf5')) + def test_e_compare_wrapper_copy( self ): + data = hdf5_wrapper.hdf5_wrapper( os.path.join( self.test_dir, 'test_insert.hdf5' ) ) tmp = data.copy() - self.compare_wrapper_dict(self.test_dict, tmp) + self.compare_wrapper_dict( self.test_dict, tmp ) - def test_f_compare_wrapper(self): - data = hdf5_wrapper.hdf5_wrapper(os.path.join(self.test_dir, 'test_manual.hdf5')) - self.compare_wrapper_dict(self.test_dict, data) + def test_f_compare_wrapper( self ): + data = hdf5_wrapper.hdf5_wrapper( os.path.join( self.test_dir, 'test_manual.hdf5' ) ) + self.compare_wrapper_dict( self.test_dict, data ) - def test_g_compare_wrapper(self): - data = hdf5_wrapper.hdf5_wrapper(os.path.join(self.test_dir, 'test_linked.hdf5')) - self.compare_wrapper_dict(self.test_dict, data) + def test_g_compare_wrapper( self ): + data = hdf5_wrapper.hdf5_wrapper( os.path.join( self.test_dir, 'test_linked.hdf5' ) ) + self.compare_wrapper_dict( self.test_dict, data ) def main(): @@ -112,12 +112,12 @@ def main(): # Parse the user arguments parser = argparse.ArgumentParser() - parser.add_argument('-v', '--verbose', type=int, help='Verbosity level', default=2) + parser.add_argument( '-v', '--verbose', type=int, help='Verbosity level', default=2 ) args = parser.parse_args() # Unit manager tests - suite = unittest.TestLoader().loadTestsFromTestCase(TestHDF5Wrapper) - unittest.TextTestRunner(verbosity=args.verbose).run(suite) + suite = unittest.TestLoader().loadTestsFromTestCase( TestHDF5Wrapper ) + unittest.TextTestRunner( verbosity=args.verbose ).run( suite ) if __name__ == "__main__": diff --git a/pygeosx_tools_package/pygeosx_tools/file_io.py b/pygeosx_tools_package/pygeosx_tools/file_io.py index 12093e8..a8a517c 100644 --- a/pygeosx_tools_package/pygeosx_tools/file_io.py +++ b/pygeosx_tools_package/pygeosx_tools/file_io.py @@ -3,10 +3,10 @@ from typing import Dict, Iterable, List, Tuple -def save_tables(axes: Iterable[np.ndarray], - properties: Dict[str, np.ndarray], - table_root: str = './tables', - axes_names: List[str] = []) -> None: +def save_tables( axes: Iterable[ np.ndarray ], + properties: Dict[ str, np.ndarray ], + table_root: str = './tables', + axes_names: List[ str ] = [] ) -> None: """ Saves a set of tables in GEOSX format @@ -22,45 +22,45 @@ def save_tables(axes: Iterable[np.ndarray], axes_names (list): A list of names for each potential axis (optional) """ # Check to see if the axes, properties have consistent shapes - axes_size = tuple([len(x) for x in axes]) - axes_dimension = len(axes_size) + axes_size = tuple( [ len( x ) for x in axes ] ) + axes_dimension = len( axes_size ) for k, p in properties.items(): - property_size = np.shape(p) - if (property_size != axes_size): - print('Property:', k) - print('Grid size:', axes_size) - print('Property size', property_size) - raise Exception('Table dimensions do not match proprerties') + property_size = np.shape( p ) + if ( property_size != axes_size ): + print( 'Property:', k ) + print( 'Grid size:', axes_size ) + print( 'Property size', property_size ) + raise Exception( 'Table dimensions do not match proprerties' ) # Check the axes names if axes_names: - if (axes_dimension != len(axes_names)): - print('Axes dimensions:', axes_dimension) - print('Number of axis names provided:', len(axes_names)) - raise Exception('The grid dimensions and axes names do not match') + if ( axes_dimension != len( axes_names ) ): + print( 'Axes dimensions:', axes_dimension ) + print( 'Number of axis names provided:', len( axes_names ) ) + raise Exception( 'The grid dimensions and axes names do not match' ) else: - if (axes_dimension == 1): - axes_names = ['t'] - elif (axes_dimension == 3): - axes_names = ['x', 'y', 'z'] - elif (axes_dimension == 4): - axes_names = ['x', 'y', 'z', 't'] + if ( axes_dimension == 1 ): + axes_names = [ 't' ] + elif ( axes_dimension == 3 ): + axes_names = [ 'x', 'y', 'z' ] + elif ( axes_dimension == 4 ): + axes_names = [ 'x', 'y', 'z', 't' ] else: - axes_names = ['x%i' % (ii) for ii in range(axes_dimension)] + axes_names = [ 'x%i' % ( ii ) for ii in range( axes_dimension ) ] # Write the axes - os.makedirs(table_root, exist_ok=True) - for g, a in zip(axes, axes_names): - np.savetxt('%s/%s.csv' % (table_root, a), g, fmt='%1.5f', delimiter=',') + os.makedirs( table_root, exist_ok=True ) + for g, a in zip( axes, axes_names ): + np.savetxt( '%s/%s.csv' % ( table_root, a ), g, fmt='%1.5f', delimiter=',' ) for k, p in properties.items(): - np.savetxt('%s/%s.csv' % (table_root, k), np.reshape(p, (-1), order='F'), fmt='%1.5e', delimiter=',') + np.savetxt( '%s/%s.csv' % ( table_root, k ), np.reshape( p, ( -1 ), order='F' ), fmt='%1.5e', delimiter=',' ) -def load_tables(axes_names: Iterable[str], - property_names: Iterable[str], - table_root: str = './tables', - extension: str = 'csv') -> Tuple[Iterable[np.ndarray], Dict[str, np.ndarray]]: +def load_tables( axes_names: Iterable[ str ], + property_names: Iterable[ str ], + table_root: str = './tables', + extension: str = 'csv' ) -> Tuple[ Iterable[ np.ndarray ], Dict[ str, np.ndarray ] ]: """ Load a set of tables in GEOSX format @@ -74,12 +74,12 @@ def load_tables(axes_names: Iterable[str], tuple: List of axes values, and dictionary of table values """ # Load axes - axes = [np.loadtxt('%s/%s.%s' % (table_root, axis, extension), unpack=True) for axis in axes_names] - N = tuple([len(x) for x in axes]) + axes = [ np.loadtxt( '%s/%s.%s' % ( table_root, axis, extension ), unpack=True ) for axis in axes_names ] + N = tuple( [ len( x ) for x in axes ] ) # Load properties properties = { - p: np.reshape(np.loadtxt('%s/%s.%s' % (table_root, p, extension)), N, order='F') + p: np.reshape( np.loadtxt( '%s/%s.%s' % ( table_root, p, extension ) ), N, order='F' ) for p in property_names } diff --git a/pygeosx_tools_package/pygeosx_tools/mesh_interpolation.py b/pygeosx_tools_package/pygeosx_tools/mesh_interpolation.py index 49b2011..5b73330 100644 --- a/pygeosx_tools_package/pygeosx_tools/mesh_interpolation.py +++ b/pygeosx_tools_package/pygeosx_tools/mesh_interpolation.py @@ -1,13 +1,13 @@ import numpy as np -from scipy import stats # type: ignore[import] +from scipy import stats # type: ignore[import] from typing import Dict, Iterable, List, Tuple, Callable, Union -def apply_to_bins(fn: Callable[[Union[float, np.ndarray]], float], - position: np.ndarray, - value: np.ndarray, - bins: np.ndarray, - collapse_edges: bool = True): +def apply_to_bins( fn: Callable[ [ Union[ float, np.ndarray ] ], float ], + position: np.ndarray, + value: np.ndarray, + bins: np.ndarray, + collapse_edges: bool = True ): """ Apply a function to values that are located within a series of bins Note: if a bin is empty, this function will fill a nan value @@ -23,29 +23,29 @@ def apply_to_bins(fn: Callable[[Union[float, np.ndarray]], float], np.ndarray: an array of function results for each bin """ # Sort values into bins - Nr = len(bins) + 1 - Ibin = np.digitize(position, bins) + Nr = len( bins ) + 1 + Ibin = np.digitize( position, bins ) if collapse_edges: Nr -= 2 Ibin -= 1 - Ibin[Ibin == -1] = 0 - Ibin[Ibin == Nr] = Nr - 1 + Ibin[ Ibin == -1 ] = 0 + Ibin[ Ibin == Nr ] = Nr - 1 # Apply functions to bins - binned_values = np.zeros(Nr) - for ii in range(Nr): - tmp = (Ibin == ii) - if np.sum(tmp): - binned_values[ii] = fn(value[tmp]) + binned_values = np.zeros( Nr ) + for ii in range( Nr ): + tmp = ( Ibin == ii ) + if np.sum( tmp ): + binned_values[ ii ] = fn( value[ tmp ] ) else: # Empty bin - binned_values[ii] = np.NaN + binned_values[ ii ] = np.NaN return binned_values -def extrapolate_nan_values(x, y, slope_scale=0.0): +def extrapolate_nan_values( x, y, slope_scale=0.0 ): """ Fill in any nan values in two 1D arrays by extrapolating @@ -57,13 +57,13 @@ def extrapolate_nan_values(x, y, slope_scale=0.0): Returns: np.ndarray: The input array with nan values replaced by extrapolated data """ - Inan = np.isnan(y) - reg = stats.linregress(x[~Inan], y[~Inan]) - y[Inan] = reg[0] * x[Inan] * slope_scale + reg[1] + Inan = np.isnan( y ) + reg = stats.linregress( x[ ~Inan ], y[ ~Inan ] ) + y[ Inan ] = reg[ 0 ] * x[ Inan ] * slope_scale + reg[ 1 ] return y -def get_random_realization(x, bins, value, rand_fill=0, rand_scale=0, slope_scale=0): +def get_random_realization( x, bins, value, rand_fill=0, rand_scale=0, slope_scale=0 ): """ Get a random realization for a noisy signal with a set of bins @@ -78,20 +78,20 @@ def get_random_realization(x, bins, value, rand_fill=0, rand_scale=0, slope_scal Returns: np.ndarray: An array containing the random realization """ - y_mean = apply_to_bins(np.mean, x, value, bins) - y_std = apply_to_bins(np.std, x, value, bins) + y_mean = apply_to_bins( np.mean, x, value, bins ) + y_std = apply_to_bins( np.std, x, value, bins ) # Extrapolate to fill the upper/lower bounds - x_mid = bins[:-1] + 0.5 * (bins[1] - bins[0]) - y_mean = extrapolate_nan_values(x_mid, y_mean, slope_scale) - y_std[np.isnan(y_std)] = rand_fill + x_mid = bins[ :-1 ] + 0.5 * ( bins[ 1 ] - bins[ 0 ] ) + y_mean = extrapolate_nan_values( x_mid, y_mean, slope_scale ) + y_std[ np.isnan( y_std ) ] = rand_fill # Add a random perturbation to the target value to match missing high/lows - y_final = y_mean + (rand_scale * y_std * np.random.randn(len(y_mean))) + y_final = y_mean + ( rand_scale * y_std * np.random.randn( len( y_mean ) ) ) return y_final -def get_realizations(x, bins, targets): +def get_realizations( x, bins, targets ): """ Get random realizations for noisy signals on target bins @@ -106,5 +106,5 @@ def get_realizations(x, bins, targets): """ results = {} for k, t in targets.items(): - results[k] = get_random_realization(x, bins, **t) + results[ k ] = get_random_realization( x, bins, **t ) return results diff --git a/pygeosx_tools_package/pygeosx_tools/well_log.py b/pygeosx_tools_package/pygeosx_tools/well_log.py index 58453f0..92ad1e7 100644 --- a/pygeosx_tools_package/pygeosx_tools/well_log.py +++ b/pygeosx_tools_package/pygeosx_tools/well_log.py @@ -2,7 +2,7 @@ import re -def parse_las(fname, variable_start='~C', body_start='~A'): +def parse_las( fname, variable_start='~C', body_start='~A' ): """ Parse an las format log file @@ -19,51 +19,51 @@ def parse_las(fname, variable_start='~C', body_start='~A'): # The expected format of the varible definition block is: # name.units code:description - variable_regex = re.compile('\s*([^\.^\s]*)\s*(\.[^ ]*) ([^:]*):(.*)') + variable_regex = re.compile( '\s*([^\.^\s]*)\s*(\.[^ ]*) ([^:]*):(.*)' ) - with open(fname) as f: + with open( fname ) as f: file_location = 0 for line in f: - line = line.split('#')[0] + line = line.split( '#' )[ 0 ] if line: # Preamble - if (file_location == 0): + if ( file_location == 0 ): if variable_start in line: file_location += 1 # Variable definitions - elif (file_location == 1): + elif ( file_location == 1 ): # This is not a comment line if body_start in line: file_location += 1 else: - match = variable_regex.match(line) + match = variable_regex.match( line ) if match: - variable_order.append(match[1]) - results[match[1]] = { - 'units': match[2][0:], - 'code': match[3], - 'description': match[4], + variable_order.append( match[ 1 ] ) + results[ match[ 1 ] ] = { + 'units': match[ 2 ][ 0: ], + 'code': match[ 3 ], + 'description': match[ 4 ], 'values': [] } else: # As a fall-back use the full line - variable_order.append(line[:-1]) - results[line[:-1]] = {'units': '', 'code': '', 'description': '', 'values': []} + variable_order.append( line[ :-1 ] ) + results[ line[ :-1 ] ] = { 'units': '', 'code': '', 'description': '', 'values': [] } # Body else: - for k, v in zip(variable_order, line.split()): - results[k]['values'].append(float(v)) + for k, v in zip( variable_order, line.split() ): + results[ k ][ 'values' ].append( float( v ) ) # Convert values to numpy arrays for k in results: - results[k]['values'] = np.array(results[k]['values']) + results[ k ][ 'values' ] = np.array( results[ k ][ 'values' ] ) return results -def convert_E_nu_to_K_G(E, nu): +def convert_E_nu_to_K_G( E, nu ): """ Convert young's modulus and poisson's ratio to bulk and shear modulus @@ -74,12 +74,12 @@ def convert_E_nu_to_K_G(E, nu): Returns: tuple: bulk modulus, shear modulus with same size as inputs """ - K = E / (3.0 * (1 - 2.0 * nu)) - G = E / (2.0 * (1 + nu)) + K = E / ( 3.0 * ( 1 - 2.0 * nu ) ) + G = E / ( 2.0 * ( 1 + nu ) ) return K, G -def estimate_shmin(z, rho, nu): +def estimate_shmin( z, rho, nu ): """ Estimate the minimum horizontal stress using the poisson's ratio @@ -91,6 +91,6 @@ def estimate_shmin(z, rho, nu): Returns: float: minimum horizontal stress """ - k = nu / (1.0 - nu) + k = nu / ( 1.0 - nu ) sigma_h = k * rho * 9.81 * z return sigma_h diff --git a/pygeosx_tools_package/pygeosx_tools/wrapper.py b/pygeosx_tools_package/pygeosx_tools/wrapper.py index 7c46e74..0c154c9 100644 --- a/pygeosx_tools_package/pygeosx_tools/wrapper.py +++ b/pygeosx_tools_package/pygeosx_tools/wrapper.py @@ -10,7 +10,7 @@ rank = comm.Get_rank() -def get_wrapper(problem, target_key, write_flag=False): +def get_wrapper( problem, target_key, write_flag=False ): """ Get a local copy of a wrapper as a numpy ndarray @@ -23,22 +23,22 @@ def get_wrapper(problem, target_key, write_flag=False): Returns: np.ndarray: The wrapper as a numpy ndarray """ - local_values = problem.get_wrapper(target_key).value() + local_values = problem.get_wrapper( target_key ).value() - if hasattr(local_values, "set_access_level"): + if hasattr( local_values, "set_access_level" ): # Array types will have the set_access_level method # These require additional manipulation before use if write_flag: - local_values.set_access_level(pylvarray.MODIFIABLE, pylvarray.CPU) + local_values.set_access_level( pylvarray.MODIFIABLE, pylvarray.CPU ) else: - local_values.set_access_level(pylvarray.CPU) + local_values.set_access_level( pylvarray.CPU ) - if hasattr(local_values, "to_numpy"): + if hasattr( local_values, "to_numpy" ): local_values = local_values.to_numpy() return local_values -def get_wrapper_par(problem, target_key, allgather=False, ghost_key=''): +def get_wrapper_par( problem, target_key, allgather=False, ghost_key='' ): """ Get a global copy of a wrapper as a numpy ndarray. Note: if ghost_key is set, it will try to remove any ghost elements @@ -52,62 +52,62 @@ def get_wrapper_par(problem, target_key, allgather=False, ghost_key=''): Returns: np.ndarray: The wrapper as a numpy ndarray """ - if (comm.size == 1): + if ( comm.size == 1 ): # This is a serial problem - return get_wrapper(problem, target_key) + return get_wrapper( problem, target_key ) else: # This is a parallel problem # Get the local wrapper size, shape - local_values = get_wrapper(problem, target_key) + local_values = get_wrapper( problem, target_key ) # Filter out ghost ranks if requested if ghost_key: - ghost_values = get_wrapper(problem, ghost_key) - local_values = local_values[ghost_values < -0.5] + ghost_values = get_wrapper( problem, ghost_key ) + local_values = local_values[ ghost_values < -0.5 ] # Find buffer size - N = np.shape(local_values) - M = np.prod(N) + N = np.shape( local_values ) + M = np.prod( N ) all_M = [] max_M = 0 if allgather: - all_M = comm.allgather(M) - max_M = np.amax(all_M) + all_M = comm.allgather( M ) + max_M = np.amax( all_M ) else: - all_M = comm.gather(M, root=0) - if (rank == 0): - max_M = np.amax(all_M) - max_M = comm.bcast(max_M, root=0) + all_M = comm.gather( M, root=0 ) + if ( rank == 0 ): + max_M = np.amax( all_M ) + max_M = comm.bcast( max_M, root=0 ) # Pack the array into a buffer - send_buff = np.zeros(max_M) - send_buff[:M] = np.reshape(local_values, (-1)) - receive_buff = np.zeros((comm.size, max_M)) + send_buff = np.zeros( max_M ) + send_buff[ :M ] = np.reshape( local_values, ( -1 ) ) + receive_buff = np.zeros( ( comm.size, max_M ) ) # Gather the buffers if allgather: - comm.Allgather([send_buff, MPI.DOUBLE], [receive_buff, MPI.DOUBLE]) + comm.Allgather( [ send_buff, MPI.DOUBLE ], [ receive_buff, MPI.DOUBLE ] ) else: - comm.Gather([send_buff, MPI.DOUBLE], [receive_buff, MPI.DOUBLE], root=0) + comm.Gather( [ send_buff, MPI.DOUBLE ], [ receive_buff, MPI.DOUBLE ], root=0 ) # Unpack the buffers all_values = [] - R = list(N) - R[0] = -1 - if ((rank == 0) | allgather): + R = list( N ) + R[ 0 ] = -1 + if ( ( rank == 0 ) | allgather ): # Reshape each rank's contribution - for ii in range(comm.size): - if (all_M[ii] > 0): - tmp = np.reshape(receive_buff[ii, :all_M[ii]], R) - all_values.append(tmp) + for ii in range( comm.size ): + if ( all_M[ ii ] > 0 ): + tmp = np.reshape( receive_buff[ ii, :all_M[ ii ] ], R ) + all_values.append( tmp ) # Concatenate into a single array - all_values = np.concatenate(all_values, axis=0) + all_values = np.concatenate( all_values, axis=0 ) return all_values -def gather_wrapper(problem, key, ghost_key=''): +def gather_wrapper( problem, key, ghost_key='' ): """ Get a global copy of a wrapper as a numpy ndarray on rank 0 @@ -118,10 +118,10 @@ def gather_wrapper(problem, key, ghost_key=''): Returns: np.ndarray: The wrapper as a numpy ndarray """ - return get_wrapper_par(problem, key, ghost_key=ghost_key) + return get_wrapper_par( problem, key, ghost_key=ghost_key ) -def allgather_wrapper(problem, key, ghost_key=''): +def allgather_wrapper( problem, key, ghost_key='' ): """ Get a global copy of a wrapper as a numpy ndarray on all ranks @@ -132,10 +132,10 @@ def allgather_wrapper(problem, key, ghost_key=''): Returns: np.ndarray: The wrapper as a numpy ndarray """ - return get_wrapper_par(problem, key, allgather=True, ghost_key=ghost_key) + return get_wrapper_par( problem, key, allgather=True, ghost_key=ghost_key ) -def get_global_value_range(problem, key): +def get_global_value_range( problem, key ): """ Get the range of a target value across all processes @@ -146,38 +146,38 @@ def get_global_value_range(problem, key): Returns: tuple: The global min/max of the target """ - local_values = get_wrapper(problem, key) + local_values = get_wrapper( problem, key ) # 1D arrays will return a scalar, ND arrays an array - N = np.shape(local_values) + N = np.shape( local_values ) local_min = 1e100 local_max = -1e100 - if (len(N) > 1): - local_min = np.zeros(N[1]) + 1e100 - local_max = np.zeros(N[1]) - 1e100 + if ( len( N ) > 1 ): + local_min = np.zeros( N[ 1 ] ) + 1e100 + local_max = np.zeros( N[ 1 ] ) - 1e100 # For >1D arrays, keep the last dimension query_axis = 0 - if (len(N) > 2): - query_axis = tuple([ii for ii in range(0, len(N) - 1)]) + if ( len( N ) > 2 ): + query_axis = tuple( [ ii for ii in range( 0, len( N ) - 1 ) ] ) # Ignore zero-length results - if len(local_values): - local_min = np.amin(local_values, axis=query_axis) - local_max = np.amax(local_values, axis=query_axis) + if len( local_values ): + local_min = np.amin( local_values, axis=query_axis ) + local_max = np.amax( local_values, axis=query_axis ) # Gather the results onto rank 0 - all_min = comm.gather(local_min, root=0) - all_max = comm.gather(local_max, root=0) + all_min = comm.gather( local_min, root=0 ) + all_max = comm.gather( local_max, root=0 ) global_min = 1e100 global_max = -1e100 - if (rank == 0): - global_min = np.amin(np.array(all_min), axis=0) - global_max = np.amax(np.array(all_max), axis=0) + if ( rank == 0 ): + global_min = np.amin( np.array( all_min ), axis=0 ) + global_max = np.amax( np.array( all_max ), axis=0 ) return global_min, global_max -def print_global_value_range(problem, key, header, scale=1.0, precision='%1.4f'): +def print_global_value_range( problem, key, header, scale=1.0, precision='%1.4f' ): """ Print the range of a target value across all processes @@ -191,25 +191,25 @@ def print_global_value_range(problem, key, header, scale=1.0, precision='%1.4f') Returns: tuple: The global min/max of the target """ - global_min, global_max = get_global_value_range(problem, key) + global_min, global_max = get_global_value_range( problem, key ) global_min *= scale global_max *= scale - if (rank == 0): - if isinstance(global_min, np.ndarray): - min_str = ', '.join([precision % (x) for x in global_min]) - max_str = ', '.join([precision % (x) for x in global_max]) - print('%s: min=[%s], max=[%s]' % (header, min_str, max_str)) + if ( rank == 0 ): + if isinstance( global_min, np.ndarray ): + min_str = ', '.join( [ precision % ( x ) for x in global_min ] ) + max_str = ', '.join( [ precision % ( x ) for x in global_max ] ) + print( '%s: min=[%s], max=[%s]' % ( header, min_str, max_str ) ) else: - min_str = precision % (global_min) - max_str = precision % (global_max) - print('%s: min=%s, max=%s' % (header, min_str, max_str)) + min_str = precision % ( global_min ) + max_str = precision % ( global_max ) + print( '%s: min=%s, max=%s' % ( header, min_str, max_str ) ) # Return a copy of the min/max in case we want to use them return global_min, global_max -def set_wrapper_to_value(problem, key, value): +def set_wrapper_to_value( problem, key, value ): """ Set the value of a wrapper @@ -218,11 +218,11 @@ def set_wrapper_to_value(problem, key, value): target_key (str): Key for the target wrapper value (float): Value to set the wrapper """ - local_values = get_wrapper(problem, key, write_flag=True) + local_values = get_wrapper( problem, key, write_flag=True ) local_values[...] = value -def set_wrapper_with_function(problem, target_key, input_keys, fn, target_index=-1): +def set_wrapper_with_function( problem, target_key, input_keys, fn, target_index=-1 ): """ Set the value of a wrapper using a function @@ -233,44 +233,44 @@ def set_wrapper_with_function(problem, target_key, input_keys, fn, target_index= fn (function): Vectorized function used to calculate target values target_index (int): Target index to write the output (default = all) """ - if isinstance(input_keys, str): - input_keys = [input_keys] - local_target = get_wrapper(problem, target_key, write_flag=True) - local_inputs = [get_wrapper(problem, k) for k in input_keys] + if isinstance( input_keys, str ): + input_keys = [ input_keys ] + local_target = get_wrapper( problem, target_key, write_flag=True ) + local_inputs = [ get_wrapper( problem, k ) for k in input_keys ] # Run the function, check the shape of outputs/target - fn_output = fn(*local_inputs) - N = np.shape(local_target) - M = np.shape(fn_output) + fn_output = fn( *local_inputs ) + N = np.shape( local_target ) + M = np.shape( fn_output ) - if (target_index < 0): - if (N == M): + if ( target_index < 0 ): + if ( N == M ): # Function output, target shapes are the same local_target[...] = fn_output - elif (len(M) == 1): + elif ( len( M ) == 1 ): # Apply the function output across each of the target dims - local_target[...] = np.tile(np.expand_dims(fn_output, axis=1), (1, N[1])) + local_target[...] = np.tile( np.expand_dims( fn_output, axis=1 ), ( 1, N[ 1 ] ) ) else: - raise Exception('Shape of function output %s is not compatible with target %s' % (str(M), str(N))) - elif (len(M) == 1): - if (len(N) == 2): + raise Exception( 'Shape of function output %s is not compatible with target %s' % ( str( M ), str( N ) ) ) + elif ( len( M ) == 1 ): + if ( len( N ) == 2 ): # 2D target, with 1D output applied to a given index - local_target[:, target_index] = fn_output + local_target[ :, target_index ] = fn_output else: # ND target, with 1D output tiled across intermediate indices - expand_axes = tuple([ii for ii in range(1, len(N) - 1)]) - tile_axes = tuple([1] + [ii for ii in N[1:-1]]) - local_target[..., target_index] = np.tile(np.expand_dims(fn_output, axis=expand_axes), tile_axes) + expand_axes = tuple( [ ii for ii in range( 1, len( N ) - 1 ) ] ) + tile_axes = tuple( [ 1 ] + [ ii for ii in N[ 1:-1 ] ] ) + local_target[..., target_index ] = np.tile( np.expand_dims( fn_output, axis=expand_axes ), tile_axes ) else: - raise Exception('Shape of function output %s is not compatible with target %s (target axis=%i)' % - (str(M), str(N), target_index)) + raise Exception( 'Shape of function output %s is not compatible with target %s (target axis=%i)' % + ( str( M ), str( N ), target_index ) ) -def search_datastructure_wrappers_recursive(group, filters, matching_paths, level=0, group_path=[]): +def search_datastructure_wrappers_recursive( group, filters, matching_paths, level=0, group_path=[] ): """ Recursively search the group and its children for wrappers that match the filters @@ -280,21 +280,21 @@ def search_datastructure_wrappers_recursive(group, filters, matching_paths, leve matching_paths (list): a list of matching values """ for wrapper in group.wrappers(): - wrapper_path = str(wrapper).split()[0] - wrapper_test = group_path + [wrapper_path[wrapper_path.rfind('/') + 1:]] - if all(f in wrapper_test for f in filters): - matching_paths.append('/'.join(wrapper_test)) + wrapper_path = str( wrapper ).split()[ 0 ] + wrapper_test = group_path + [ wrapper_path[ wrapper_path.rfind( '/' ) + 1: ] ] + if all( f in wrapper_test for f in filters ): + matching_paths.append( '/'.join( wrapper_test ) ) for sub_group in group.groups(): - sub_group_name = str(sub_group).split()[0].split('/')[-1] - search_datastructure_wrappers_recursive(sub_group, - filters, - matching_paths, - level=level + 1, - group_path=group_path + [sub_group_name]) + sub_group_name = str( sub_group ).split()[ 0 ].split( '/' )[ -1 ] + search_datastructure_wrappers_recursive( sub_group, + filters, + matching_paths, + level=level + 1, + group_path=group_path + [ sub_group_name ] ) -def get_matching_wrapper_path(problem, filters): +def get_matching_wrapper_path( problem, filters ): """ Recursively search the group and its children for wrappers that match the filters A successful match is identified if the wrapper path contains all of the @@ -310,22 +310,22 @@ def get_matching_wrapper_path(problem, filters): str: Key of the matching wrapper """ matching_paths = [] - search_datastructure_wrappers_recursive(problem, filters, matching_paths) + search_datastructure_wrappers_recursive( problem, filters, matching_paths ) - if (len(matching_paths) == 1): - if (rank == 0): - print('Found matching wrapper: %s' % (matching_paths[0])) - return matching_paths[0] + if ( len( matching_paths ) == 1 ): + if ( rank == 0 ): + print( 'Found matching wrapper: %s' % ( matching_paths[ 0 ] ) ) + return matching_paths[ 0 ] else: - if (rank == 0): - print('Error occured while looking for wrappers:') - print('Filters: [%s]' % (', '.join(filters))) - print('Matching wrappers: [%s]' % (', '.join(matching_paths))) - raise Exception('Search resulted in 0 or >1 wrappers mathching filters') + if ( rank == 0 ): + print( 'Error occured while looking for wrappers:' ) + print( 'Filters: [%s]' % ( ', '.join( filters ) ) ) + print( 'Matching wrappers: [%s]' % ( ', '.join( matching_paths ) ) ) + raise Exception( 'Search resulted in 0 or >1 wrappers mathching filters' ) -def run_queries(problem, records): +def run_queries( problem, records ): """ Query the current GEOSX datastructure Note: The expected record request format is as follows. @@ -340,16 +340,16 @@ def run_queries(problem, records): records (dict): A dict of dicts that specifies the queries to run """ for k in records.keys(): - if (k == 'time'): - current_time = get_wrapper(problem, "Events/time") - records[k]['history'].append(current_time * records[k]['scale']) + if ( k == 'time' ): + current_time = get_wrapper( problem, "Events/time" ) + records[ k ][ 'history' ].append( current_time * records[ k ][ 'scale' ] ) else: - tmp = print_global_value_range(problem, k, records[k]['label'], scale=records[k]['scale']) - records[k]['history'].append(tmp) + tmp = print_global_value_range( problem, k, records[ k ][ 'label' ], scale=records[ k ][ 'scale' ] ) + records[ k ][ 'history' ].append( tmp ) sys.stdout.flush() -def plot_history(records, output_root='.', save_figures=True, show_figures=True): +def plot_history( records, output_root='.', save_figures=True, show_figures=True ): """ Plot the time-histories for the records structure. Note: If figures are shown, the GEOSX process will be blocked until they are closed @@ -360,47 +360,47 @@ def plot_history(records, output_root='.', save_figures=True, show_figures=True) save_figures (bool): Flag to indicate whether figures should be saved (default = True) show_figures (bool): Flag to indicate whether figures should be drawn (default = False) """ - if (rank == 0): + if ( rank == 0 ): for k in records.keys(): - if (k != 'time'): + if ( k != 'time' ): # Set the active figure - fa = plt.figure(records[k]['fhandle'].number) + fa = plt.figure( records[ k ][ 'fhandle' ].number ) # Assemble values to plot - t = np.array(records['time']['history']) - x = np.array(records[k]['history']) - N = np.shape(x) # (time, min/max, dimensions) + t = np.array( records[ 'time' ][ 'history' ] ) + x = np.array( records[ k ][ 'history' ] ) + N = np.shape( x ) # (time, min/max, dimensions) # Add plots - if (len(N) == 2): + if ( len( N ) == 2 ): # This is a 1D field plt.gca().cla() - plt.plot(t, x[:, 0], label='min') - plt.plot(t, x[:, 1], label='max') - plt.xlabel(records['time']['label']) - plt.ylabel(records[k]['label']) + plt.plot( t, x[ :, 0 ], label='min' ) + plt.plot( t, x[ :, 1 ], label='max' ) + plt.xlabel( records[ 'time' ][ 'label' ] ) + plt.ylabel( records[ k ][ 'label' ] ) else: # This is a 2D field columns = 2 - rows = int(np.ceil(N[2] / float(columns))) + rows = int( np.ceil( N[ 2 ] / float( columns ) ) ) # Setup axes - if (('axes' not in records[k]) or (len(fa.axes) == 0)): - records[k]['axes'] = [plt.subplot(rows, columns, ii + 1) for ii in range(0, N[2])] + if ( ( 'axes' not in records[ k ] ) or ( len( fa.axes ) == 0 ) ): + records[ k ][ 'axes' ] = [ plt.subplot( rows, columns, ii + 1 ) for ii in range( 0, N[ 2 ] ) ] - for ii in range(0, N[2]): - ax = records[k]['axes'][ii] + for ii in range( 0, N[ 2 ] ): + ax = records[ k ][ 'axes' ][ ii ] ax.cla() - ax.plot(t, x[:, 0, ii], label='min') - ax.plot(t, x[:, 1, ii], label='max') - ax.set_xlabel(records['time']['label']) - ax.set_ylabel('%s (dim=%i)' % (records[k]['label'], ii)) - plt.legend(loc=2) - records[k]['fhandle'].tight_layout(pad=1.5) + ax.plot( t, x[ :, 0, ii ], label='min' ) + ax.plot( t, x[ :, 1, ii ], label='max' ) + ax.set_xlabel( records[ 'time' ][ 'label' ] ) + ax.set_ylabel( '%s (dim=%i)' % ( records[ k ][ 'label' ], ii ) ) + plt.legend( loc=2 ) + records[ k ][ 'fhandle' ].tight_layout( pad=1.5 ) if save_figures: - fname = k[k.rfind('/') + 1:] - plt.savefig('%s/%s.png' % (output_root, fname), format='png') + fname = k[ k.rfind( '/' ) + 1: ] + plt.savefig( '%s/%s.png' % ( output_root, fname ), format='png' ) if show_figures: plt.show() diff --git a/timehistory_package/timehistory/plot_time_history.py b/timehistory_package/timehistory/plot_time_history.py index 839288c..512c4b1 100644 --- a/timehistory_package/timehistory/plot_time_history.py +++ b/timehistory_package/timehistory/plot_time_history.py @@ -9,15 +9,15 @@ import re -def isiterable(obj): +def isiterable( obj ): try: - it = iter(obj) + it = iter( obj ) except TypeError: return False return True -def getHistorySeries(database, variable, setname, indices=None, components=None): +def getHistorySeries( database, variable, setname, indices=None, components=None ): """ Retrieve a series of time history structures suitable for plotting in addition to the specific set index and component for the time series @@ -32,123 +32,123 @@ def getHistorySeries(database, variable, setname, indices=None, components=None) list: list of (time, data, idx, comp) timeseries tuples for each time history data component """ - set_regex = re.compile(variable + '(.*?)', re.IGNORECASE) + set_regex = re.compile( variable + '(.*?)', re.IGNORECASE ) if setname is not None: - set_regex = re.compile(variable + '\s*' + str(setname), re.IGNORECASE) - time_regex = re.compile('Time', re.IGNORECASE) # need to make this per-set, thought that was in already? + set_regex = re.compile( variable + '\s*' + str( setname ), re.IGNORECASE ) + time_regex = re.compile( 'Time', re.IGNORECASE ) # need to make this per-set, thought that was in already? - set_match = list(filter(set_regex.match, database.keys())) - time_match = list(filter(time_regex.match, database.keys())) + set_match = list( filter( set_regex.match, database.keys() ) ) + time_match = list( filter( time_regex.match, database.keys() ) ) - if len(set_match) == 0: - print(f"Error: can't locate time history data for variable/set described by regex {set_regex.pattern}") + if len( set_match ) == 0: + print( f"Error: can't locate time history data for variable/set described by regex {set_regex.pattern}" ) return None - if len(time_match) == 0: - print(f"Error: can't locate time history data for set time variable described by regex {time_regex.pattern}") + if len( time_match ) == 0: + print( f"Error: can't locate time history data for set time variable described by regex {time_regex.pattern}" ) return None - if len(set_match) > 1: - print(f"Warning: variable/set specification matches multiple datasets: {', '.join(set_match)}") - if len(time_match) > 1: - print(f"Warning: set specification matches multiple time datasets: {', '.join(time_match)}") + if len( set_match ) > 1: + print( f"Warning: variable/set specification matches multiple datasets: {', '.join(set_match)}" ) + if len( time_match ) > 1: + print( f"Warning: set specification matches multiple time datasets: {', '.join(time_match)}" ) - set_match = set_match[0] - time_match = time_match[0] + set_match = set_match[ 0 ] + time_match = time_match[ 0 ] - data_series = database[set_match] - time_series = database[time_match] + data_series = database[ set_match ] + time_series = database[ time_match ] - if time_series.shape[0] != data_series.shape[0]: + if time_series.shape[ 0 ] != data_series.shape[ 0 ]: print( f"Error: The length of the time-series {time_match} and data-series {set_match} do not match: {time_series.shape} and {data_series.shape} !" ) if indices is not None: - if type(indices) is int: - indices = list(indices) - if isiterable(indices): - oob_idxs = list(filter(lambda idx: not 0 <= idx < data_series.shape[1], indices)) - if len(oob_idxs) > 0: - print(f"Error: The specified indices: ({', '.join(oob_idxs)}) " + "\n\t" + - f" are out of the dataset index range: [0,{data_series.shape[1]})") - indices = list(set(indices) - set(oob_idxs)) + if type( indices ) is int: + indices = list( indices ) + if isiterable( indices ): + oob_idxs = list( filter( lambda idx: not 0 <= idx < data_series.shape[ 1 ], indices ) ) + if len( oob_idxs ) > 0: + print( f"Error: The specified indices: ({', '.join(oob_idxs)}) " + "\n\t" + + f" are out of the dataset index range: [0,{data_series.shape[1]})" ) + indices = list( set( indices ) - set( oob_idxs ) ) else: - print(f"Error: unsupported indices type: {type(indices)}") + print( f"Error: unsupported indices type: {type(indices)}" ) else: - indices = range(data_series.shape[1]) + indices = range( data_series.shape[ 1 ] ) if components is not None: - if type(components) is int: - components = list(components) - if isiterable(components): - oob_comps = list(filter(lambda comp: not 0 <= comp < data_series.shape[2], components)) - if len(oob_comps) > 0: - print(f"Error: The specified components: ({', '.join(oob_comps)}) " + "\n\t" + - " is out of the dataset component range: [0,{data_series.shape[1]})") - components = list(set(components) - set(oob_comps)) + if type( components ) is int: + components = list( components ) + if isiterable( components ): + oob_comps = list( filter( lambda comp: not 0 <= comp < data_series.shape[ 2 ], components ) ) + if len( oob_comps ) > 0: + print( f"Error: The specified components: ({', '.join(oob_comps)}) " + "\n\t" + + " is out of the dataset component range: [0,{data_series.shape[1]})" ) + components = list( set( components ) - set( oob_comps ) ) else: - print(f"Error: unsupported components type: {type(components)}") + print( f"Error: unsupported components type: {type(components)}" ) else: - components = range(data_series.shape[2]) + components = range( data_series.shape[ 2 ] ) - return [(time_series[:, 0], data_series[:, idx, comp], idx, comp) for idx in indices for comp in components] + return [ ( time_series[ :, 0 ], data_series[ :, idx, comp ], idx, comp ) for idx in indices for comp in components ] def commandLinePlotGen(): parser = argparse.ArgumentParser( description="A script that parses geosx HDF5 time-history files and produces time-history plots using matplotlib" ) - parser.add_argument("filename", metavar="history_file", type=str, help="The time history file to parse") + parser.add_argument( "filename", metavar="history_file", type=str, help="The time history file to parse" ) - parser.add_argument("variable", - metavar="variable_name", - type=str, - help="Which time-history variable collected by GEOSX to generate a plot file for.") + parser.add_argument( "variable", + metavar="variable_name", + type=str, + help="Which time-history variable collected by GEOSX to generate a plot file for." ) parser.add_argument( "--sets", metavar="name", type=str, action='append', - default=[None], + default=[ None ], nargs="+", help= "Which index set of time-history data collected by GEOSX to generate a plot file for, may be specified multiple times with different indices/components for each set." ) - parser.add_argument("--indices", - metavar="index", - type=int, - default=[], - nargs="+", - help="An optional list of specific indices in the most-recently specified set.") + parser.add_argument( "--indices", + metavar="index", + type=int, + default=[], + nargs="+", + help="An optional list of specific indices in the most-recently specified set." ) - parser.add_argument("--components", - metavar="int", - type=int, - default=[], - nargs="+", - help="An optional list of specific variable components") + parser.add_argument( "--components", + metavar="int", + type=int, + default=[], + nargs="+", + help="An optional list of specific variable components" ) args = parser.parse_args() result = 0 - if not os.path.isfile(args.filename): - print(f"Error: file '{args.filename}' not found.") + if not os.path.isfile( args.filename ): + print( f"Error: file '{args.filename}' not found." ) result = -1 else: - with h5w(args.filename, mode='r') as database: + with h5w( args.filename, mode='r' ) as database: for setname in args.sets: - ds = getHistorySeries(database, args.variable, setname, args.indices, args.components) + ds = getHistorySeries( database, args.variable, setname, args.indices, args.components ) if ds is None: result = -1 break - figname = args.variable + ("_" + setname if setname is not None else "") + figname = args.variable + ( "_" + setname if setname is not None else "" ) fig, ax = plt.subplots() - ax.set_title(figname) + ax.set_title( figname ) for d in ds: - ax.plot(d[0], d[1]) - fig.savefig(figname + "_history.png") + ax.plot( d[ 0 ], d[ 1 ] ) + fig.savefig( figname + "_history.png" ) return result
SUMMARY
{self.displayName[status]}