Skip to content

Commit

Permalink
Fix issue #127
Browse files Browse the repository at this point in the history
  • Loading branch information
robclewley committed Jan 3, 2021
1 parent 45fc2f1 commit 308e8ea
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 54 deletions.
62 changes: 9 additions & 53 deletions PyDSTool/core/codegenerators/c.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
# -*- coding: utf-8 -*-


from PyDSTool.common import invertMap, intersect, concatStrDict, sortedDictItems, isUniqueSeq
from PyDSTool.parseUtils import convertPowers, parseMatrixStrToDictStr, addArgToCalls, wrapArgInCall, splitargs, findEndBrace
from PyDSTool.common import invertMap, intersect, concatStrDict, \
sortedDictItems, isUniqueSeq
from PyDSTool.parseUtils import convertPowers, parseMatrixStrToDictStr, \
addArgToCalls, wrapArgInCall, splitargs, findEndBrace, \
add_arglen_to_fn_names, paren_contents
from PyDSTool.Symbolic import QuantSpec
from PyDSTool.utils import compareList, info

Expand Down Expand Up @@ -653,57 +656,10 @@ def _processSpecialC(self, specStr):
'True': 1, 'False': 0, 'if': '__rhs_if',
'max': '__maxof', 'min': '__minof'})
qtoks = qspec.parser.tokenized
# default value
new_specStr = str(qspec)
# NOTE: This simple iterative parsing of the arguments means that
# user cannot nest calls to min() or max() with eachother
if '__minof' in qtoks:
new_specStr = ""
num = qtoks.count('__minof')
n_ix = -1
ix_continue = 0
for _ in range(num):
n_ix = qtoks[n_ix + 1:].index('__minof') + n_ix + 1
new_specStr += "".join(qtoks[ix_continue:n_ix])
rbrace_ix = findEndBrace(qtoks[n_ix + 1:]) + n_ix + 1
ix_continue = rbrace_ix + 1
#assert qtoks[n_ix+2] == '[', "Error in min() syntax"
#assert qtoks[rbrace_ix-1] == ']', "Error in min() syntax"
#new_specStr += "".join(qtoks[n_ix+3:rbrace_ix-1]) + ")"
num_args = qtoks[n_ix + 2:ix_continue].count(',') + 1
if num_args > 4:
raise NotImplementedError(
"Max of more than 4 arguments not currently supported in C")
new_specStr += '__minof%s(' % str(num_args)
new_specStr += "".join(
[q for q in qtoks[n_ix + 2:ix_continue] if q not in ('[', ']')])
new_specStr += "".join(qtoks[ix_continue:])
qspec = QuantSpec('spec', new_specStr)
qtoks = qspec.parser.tokenized
if '__maxof' in qtoks:
new_specStr = ""
num = qtoks.count('__maxof')
n_ix = -1
ix_continue = 0
for _ in range(num):
n_ix = qtoks[n_ix + 1:].index('__maxof') + n_ix + 1
new_specStr += "".join(qtoks[ix_continue:n_ix])
rbrace_ix = findEndBrace(qtoks[n_ix + 1:]) + n_ix + 1
ix_continue = rbrace_ix + 1
#assert qtoks[n_ix+2] == '[', "Error in max() syntax"
#assert qtoks[rbrace_ix-1] == ']', "Error in max() syntax"
#new_specStr += "".join(qtoks[n_ix+3:rbrace_ix-1]) + ")"
num_args = qtoks[n_ix + 2:ix_continue].count(',') + 1
if num_args > 4:
raise NotImplementedError(
"Min of more than 4 arguments not currently supported in C")
new_specStr += '__maxof%s(' % str(num_args)
new_specStr += "".join(
[q for q in qtoks[n_ix + 2:ix_continue] if q not in ('[', ']')])
new_specStr += "".join(qtoks[ix_continue:])
qspec = QuantSpec('spec', new_specStr)
qtoks = qspec.parser.tokenized
return new_specStr
pc_info_list = list(paren_contents(qtoks))
for fname in ('__maxof', '__minof'):
qtoks = add_arglen_to_fn_names(qtoks, fname, pc_info_list)
return "".join(qtoks)

def _format_user_code(self, code):
before = '/* Verbose code insert -- begin */'
Expand Down
90 changes: 89 additions & 1 deletion PyDSTool/parseUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1948,6 +1948,8 @@ def readArgs(argstr, lbchar='(', rbchar=')'):
Returns a triple: [success_boolean, list of arguments, number of args]"""
bracetest = argstr[0] == lbchar and argstr[-1] == rbchar
if not bracetest:
raise ValueError("argument string must begin and end with specified braces")
rest = argstr[1:-1].replace(" ","")
pieces = []
while True:
Expand Down Expand Up @@ -1985,7 +1987,8 @@ def findEndBrace(s, lbchar='(', rbchar=')'):
'(' and ')'. Change them with the optional second and third arguments.
"""
pos = 0
assert s[0] == lbchar, 'string argument must begin with left brace'
if s[0] != lbchar:
raise ValueError('string argument must begin with left brace')
stemp = s
leftbrace_count = 0
notDone = True
Expand Down Expand Up @@ -2118,6 +2121,30 @@ def wrapArgInCall(source, callfn, wrapL, wrapR=None, argnums=[0],
output += source[currpos:]
return output

def paren_contents(seq, lbr='(', rbr=')'):
"""Generate info about parenthesized contents in sequence as
(unique tree position key,
tree depth,
seq index of opening paren,
seq index of closing paren).
"""
past_outkeys = []
key_stack = []
stack = []
for i, c in enumerate(seq):
if c == lbr:
if len(past_outkeys) == 0:
key_stack = key_stack + [0]
else:
if key_stack in past_outkeys:
key_stack = key_stack[:-1] + [key_stack[-1] + 1]
else:
key_stack = key_stack + [0]
stack.append((i, key_stack))
elif c == rbr and stack:
start, key_stack = stack.pop()
past_outkeys.append(key_stack.copy())
yield (tuple(key_stack), len(stack), start, i)

##def replaceCallsWithDummies(source, callfns, used_dummies=None, notFirst=False):
## """Replace all function calls in source with dummy names,
Expand Down Expand Up @@ -2289,8 +2316,69 @@ def replaceCallsWithDummies(source, callfns, used_dummies=None, notFirst=False):
return output, dummies


def add_arglen_to_fn_names(qtoks, fname, pc_info_list=None, max_args=4):
"""Add the number of args to instances of the given function's calls to the
name of the function. This is used for python -> C code generation where
the C equivalents of python functions such as min() or max() must take a
fixed number of arguments known at compile time.
See core/codegenerators/c.py for usage.
qtoks is a list of tokens (e.g. created by QuantSpec.parser.tokenized).
pc_info_list is the optional pre-computed output from paren_contents
converted to a list (to avoid re-computing for several calls to this
function).
max_args (default 4) is the maximum number of support arguments for this
parse scheme.
"""
if pc_info_list is None:
# convert generator to list for repeated iterations
pc_info_list = list(paren_contents(qtoks))
rep_ix = list(np.where(np.array(qtoks) == fname)[0])
if len(rep_ix) != 0:
for key, depth, i, j in pc_info_list:
if i-1 in rep_ix:
temp_qtoks = qtoks[:]
# matching function name
# get args for this call between parens
# blocking out any interior calls at higher depth
for key2, depth2, i2, j2 in pc_info_list:
if depth2 == depth + 1 and key2[:len(key)] == key:
temp_qtoks = temp_qtoks[:i2] + ['_']*(j2-i2) + temp_qtoks[j2:]
num_args = temp_qtoks[i:j].count(',') + 1
if num_args > max_args:
raise ValueError("Too many arguments found for this to parse")
qtoks = qtoks[:i-1] + [fname+str(num_args)] + qtoks[i:]
rep_ix.remove(i-1)
if rep_ix == []:
break # for
return qtoks


def add_args_to_calls_tok(qtoks, callfns, new_arg_list):
"""qtoks is a list of tokens (e.g. created by QuantSpec.parser.tokenized).
callfns is a list of function names, for which the additional args will be
added to call instances.
new_arg_list is a list of string tokens to add to calls (no commas).
"""
for fname in callfns:
while True:
rep_ix = np.where(np.array(qtoks) == fname)[0]
if len(rep_ix) == 0:
break # while
else:
for key, depth, i, j in paren_contents(qtoks):
if i-1 in rep_ix:
# matching function name
# get args for this call between parens
qtoks = qtoks[:j] + \
[',' + arg for arg in new_arg_list] + qtoks[j:]
break # for


def addArgToCalls(source, callfns, arg, notFirst=''):
"""Add an argument to calls in source, to the functions listed in callfns.
This is a legacy function. A better version for future use is
add_args_to_calls_tok.
"""
# This function used to work on lists of callfns directly, but I can't
# see why it stopped working. So I just added this recursing part at
Expand Down

0 comments on commit 308e8ea

Please sign in to comment.