Skip to content

Commit

Permalink
minor update
Browse files Browse the repository at this point in the history
  • Loading branch information
wi-re committed Jul 23, 2024
1 parent 79f2837 commit 4143102
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
3 changes: 2 additions & 1 deletion src/BasisConvolution/util/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,6 @@
parser.add_argument('--shiftLoss', type = bool, default = argparse.SUPPRESS, action=argparse.BooleanOptionalAction, help='Shifting the loop')
parser.add_argument('--skipLastShift', type = bool, default = argparse.SUPPRESS, action=argparse.BooleanOptionalAction, help='Shifting the loop')
parser.add_argument('--scaleShiftLoss', type = bool, default = argparse.SUPPRESS, action=argparse.BooleanOptionalAction, help='Shifting the loop')
parser.add_argument('--integrationScheme', type = str, default = argparse.SUPPRESS, help='Integration scheme')

parser.add_argument('--exportPath', type = str, default = argparse.SUPPRESS, help='Export path')
parser.add_argument('--exportPath', type = str, default = argparse.SUPPRESS, help='Export path')
9 changes: 6 additions & 3 deletions src/BasisConvolution/util/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def defaultHyperParameters():
'inputEncoder': None,
'outputDecoder': None,
'edgeMLP': None,
'vertexMLP': None
'vertexMLP': None,
'integrationScheme': 'semiImplicitEuler',
}
return hyperParameterDict

Expand Down Expand Up @@ -135,6 +136,7 @@ def parseArguments(args, hyperParameterDict):
hyperParameterDict['scaleShiftLoss'] = args.scaleShiftLoss if hasattr(args, 'scaleShiftLoss') else hyperParameterDict['scaleShiftLoss']
hyperParameterDict['activation'] = args.activation if hasattr(args, 'activation') else hyperParameterDict['activation']
hyperParameterDict['exportPath'] = args.exportPath if hasattr(args, 'exportPath') else hyperParameterDict['exportPath']
hyperParameterDict['integrationScheme'] = args.integrationScheme if hasattr(args, 'integrationScheme') else hyperParameterDict['integrationScheme']

hyperParameterDict['device'] = args.device if hasattr(args, 'device') else hyperParameterDict['device']
# hyperParameterDict['dtype'] = torch.
Expand Down Expand Up @@ -162,7 +164,7 @@ def parseArguments(args, hyperParameterDict):
'gain': 1,
'norm': True,
'layout': [32],
'output': 1,
# 'output': 1,
'preNorm': False,
'postNorm': True,
'noLinear': True,
Expand Down Expand Up @@ -263,6 +265,7 @@ def parseConfig(config, hyperParameterDict):
parseEntry(cfg, 'shifting', 'networkType', hyperParameterDict, 'networkType')
parseEntry(cfg, 'shifting', 'shiftLoss', hyperParameterDict, 'shiftLoss')
parseEntry(cfg, 'shifting', 'scaleShiftLoss', hyperParameterDict, 'scaleShiftLoss')
parseEntry(cfg, 'shifting', 'integrationScheme', hyperParameterDict, 'integrationScheme')
parseEntry(cfg, 'dataset', 'dataIndex', hyperParameterDict, 'dataIndex')
parseEntry(cfg, 'shifting', 'skipLastShift', hyperParameterDict, 'skipLastShift')
parseEntry(cfg, 'loss', 'dxdtLossScaling', hyperParameterDict, 'dxdtLossScaling')
Expand Down Expand Up @@ -428,7 +431,7 @@ def toPandaDict(hyperParameterDict):
'skipLastShift': hyperParameterDict['skipLastShift'],
'dxdtLossScaling': hyperParameterDict['dxdtLossScaling'],
'scaleShiftLoss': hyperParameterDict['scaleShiftLoss'] if 'scaleShiftLoss' in hyperParameterDict else False,

'integrationScheme': hyperParameterDict['integrationScheme'],
'inputEncoder': True if hyperParameterDict['inputEncoder'] is not None else False,
'outputDecoder': True if hyperParameterDict['outputDecoder'] is not None else False,

Expand Down
9 changes: 7 additions & 2 deletions src/BasisConvolution/util/testcases.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,12 +522,13 @@ def loadFrame_newFormat(inFile, fileName, key, fileData, fileIndex, fileOffset,
iPriorKey = int(key) - hyperParameterDict['frameDistance']

priorState = None
if buildPriorState:
if buildPriorState or hyperParameterDict['adjustForFrameDistance']:
if iPriorKey < 0 or hyperParameterDict['frameDistance'] == 0:
priorState = copy.deepcopy(state)
else:
priorState = loadGroup_newFormat(inFile, inFile['simulationExport']['%05d' % iPriorKey], staticBoundaryData, fileName, iPriorKey, fileData, fileIndex, fileOffset, dataset, hyperParameterDict, unrollLength = unrollLength, device = device, dtype = dtype, additionalData = additionalData, buildPriorState = False, buildNextState = False)



nextStates = []
if buildNextState:
if unrollLength == 0 and hyperParameterDict['frameDistance'] == 0:
Expand All @@ -543,6 +544,10 @@ def loadFrame_newFormat(inFile, fileName, key, fileData, fileIndex, fileOffset,
nextState = loadGroup_newFormat(inFile, inFile['simulationExport']['%05d' % unrollKey], staticBoundaryData, fileName, iPriorKey, fileData, fileIndex, fileOffset, dataset, hyperParameterDict, unrollLength = unrollLength, device = device, dtype = dtype, additionalData = additionalData, buildPriorState = False, buildNextState = False)
nextStates.append(nextState)

# if hyperParameterDict['adjustForFrameDistance']:






Expand Down

0 comments on commit 4143102

Please sign in to comment.