import numbers
import re
def parse_axis_one_variable(axis, var_sig):
    """Normalise the axis specification for a single core variable.

    Parameters
    ----------
    axis : int or iterable
        Either a single axis or an iterable of axes.  Any integer-like
        scalar (anything registered as ``numbers.Integral``, which includes
        NumPy integer scalars) is treated as a single axis.
    var_sig : str
        The variable's core signature, e.g. ``'m,n'``.  Not used yet.

    Returns
    -------
    list
        ``[int(axis)]`` for an integer scalar, otherwise ``list(axis)``.
    """
    # numbers.Integral instead of int: NumPy integer scalars are not int
    # subclasses and are not iterable, so list() on them would raise.
    if isinstance(axis, numbers.Integral):
        return [int(axis)]
    # TODO: validate axis against var_sig.  Note the module's test table
    # expects '(n)' with axis (0, 1) to yield [0, 1], so "one axis per core
    # dimension" is NOT the current contract -- decide the rule first.
    return list(axis)
def get_nonempty_variable_signatures(gufunc_sig):
    """Return the contents of every non-empty ``(...)`` group in a signature.

    Whitespace is removed before matching, so ``'(m, n), (n, p)'`` and
    ``'(m,n),(n,p)'`` both give ``['m,n', 'n,p']``.  Empty groups such as
    ``'()'`` are skipped.  Every parenthesised group in the string is
    matched, wherever it appears.

    Parameters
    ----------
    gufunc_sig : str
        Generalized-ufunc style signature, e.g. ``'(m,n),(n,p)'``.

    Returns
    -------
    list of str
        The inside of each non-empty group, in order, with whitespace removed.
    """
    # The character class below has no whitespace, so without this step a
    # group like '(m, n)' would fail to match and be silently dropped.
    compact_sig = re.sub(r'\s+', '', gufunc_sig)
    return re.findall(r'\(([\w?+,]+)\)', compact_sig)
def parse_axis(gufunc_sig, axis):
    """Normalise ``axis`` into one list of axes per non-empty core variable.

    Parameters
    ----------
    gufunc_sig : str
        Generalized-ufunc style signature such as ``'(m,n),(n,p)'``.
        Variables with an empty signature ``()`` are ignored.
    axis : int or sequence
        One of:

        * a single integer (including NumPy integer scalars);
        * a sequence with one entry per non-empty variable, where each
          entry is an int or a sequence of ints; or
        * when the signature has exactly one non-empty variable, a flat
          sequence of axes for that variable.

    Returns
    -------
    list of list
        ``[[axes for variable 0], [axes for variable 1], ...]``.  A bare
        integer ``axis`` always yields ``[[axis]]``.

    Raises
    ------
    ValueError
        If ``axis`` is a sequence whose length matches neither the number
        of non-empty variables nor the single-variable case.
    """
    var_sigs = get_nonempty_variable_signatures(gufunc_sig)
    # numbers.Integral also covers NumPy integer scalars, which are not int
    # subclasses and would otherwise reach len() below and raise TypeError.
    if isinstance(axis, numbers.Integral):
        return [[int(axis)]]
    if len(axis) == len(var_sigs):
        # One entry per variable, e.g. ((0, 1), 2) -> [[0, 1], [2]].
        return [parse_axis_one_variable(ax, sig)
                for ax, sig in zip(axis, var_sigs)]
    if len(var_sigs) == 1:
        # Single variable given a flat sequence, e.g. (0, 1) -> [[0, 1]].
        return [parse_axis_one_variable(axis, var_sigs[0])]
    raise ValueError(
        'axis {!r} is incompatible with gufunc signature {!r}: expected an '
        'int or a sequence of length {} (one entry per non-empty '
        'variable)'.format(axis, gufunc_sig, len(var_sigs)))
# Self-check table: (signature, axis argument, expected canonical form).
# Every row is run through parse_axis when this module executes.
test_cases = [
    # One non-empty variable: int, 1-tuple, nested 1-tuple, flat tuple.
    ('(n)', 0, [[0]]),
    ('(n)', (0,), [[0]]),
    ('(n)', ((0,),), [[0]]),
    ('(n)', (0, 1), [[0, 1]]),
    # Scalar '()' variables are ignored wherever they appear.
    ('(n)()', (0, 1), [[0, 1]]),
    ('()(n)', (0, 1), [[0, 1]]),
    # Two single-dimension variables: one entry each.
    ('(n)(m)', (0, 1), [[0], [1]]),
    ('(n)(m)', ((0,), (1,)), [[0], [1]]),
    # One two-dimensional variable.
    ('(n,m)', (0, 1), [[0, 1]]),
    ('(n,m)', ((0, 1),), [[0, 1]]),
    # Matrix-multiply style signature with mixed scalar/tuple entries.
    ('(m,n),(n,p)', (0, 1), [[0], [1]]),
    ('(m,n),(n,p)', ((0, 1), 2), [[0, 1], [2]]),
    ('(m,n),(n,p)', (0, (1, 2)), [[0], [1, 2]]),
    ('(m,n),(n,p)', ((0, 1), (2, 3)), [[0, 1], [2, 3]]),
]
for sig, axis, canonical in test_cases:
    # Parse once; the failure message reports the actual parsed value.
    parsed = parse_axis(sig, axis)
    assert parsed == canonical, (sig, axis, parsed)