Proof of concept for canonicalizing gufunc axis argument

In [1]:
import re

def parse_axis_one_variable(axis, var_sig):
    """Return the axis specification for one variable as a flat list.

    Parameters
    ----------
    axis : int or iterable
        A single axis, or an iterable of axes, for this variable.
    var_sig : str
        The variable's core-dimension signature. Currently unused; it is
        accepted so that validation can be added later.

    Returns
    -------
    list
        ``[axis]`` when ``axis`` is an ``int``, otherwise ``list(axis)``.
    """
    # TODO: validate axis based on var_sig
    return [axis] if isinstance(axis, int) else list(axis)
    
def get_nonempty_variable_signatures(gufunc_sig):
    """Return the contents of every non-empty parenthesized group in a signature.

    Parameters
    ----------
    gufunc_sig : str
        A gufunc signature string such as ``'(m,n),(n,p)'``.

    Returns
    -------
    list of str
        The text between each pair of parentheses, in order of appearance,
        with whitespace removed (``'(m, n)'`` yields ``'m,n'``). Groups with
        nothing inside them, like ``'()'``, are skipped.
    """
    # Whitespace carries no meaning in a gufunc signature, so drop it before
    # matching; otherwise '(m, n)' would fail the character class below and
    # be silently treated like an empty variable.
    compact_sig = ''.join(gufunc_sig.split())
    # ignore variables with signatures like '()'
    return re.findall(r'\(([\w?+,]+)\)', compact_sig)

def parse_axis(gufunc_sig, axis):
    """Canonicalize an ``axis`` argument into one list of axes per variable.

    Parameters
    ----------
    gufunc_sig : str
        The gufunc signature; only its non-empty variables are considered.
    axis : int or sequence
        Either a single integer, a sequence with one entry per non-empty
        variable, or (when there is exactly one non-empty variable) a
        sequence of axes for that variable.

    Returns
    -------
    list of list
        The canonical nested-list form of ``axis``.

    Raises
    ------
    ValueError
        If ``axis`` is a sequence whose length does not match the number of
        non-empty variables and there is not exactly one such variable.
    """
    signatures = get_nonempty_variable_signatures(gufunc_sig)

    # A bare integer always canonicalizes to a single one-axis entry.
    if isinstance(axis, int):
        return [[axis]]

    # One entry per non-empty variable: canonicalize each entry on its own.
    if len(axis) == len(signatures):
        canonical = []
        for entry, signature in zip(axis, signatures):
            canonical.append(parse_axis_one_variable(entry, signature))
        return canonical

    # Otherwise the whole sequence can only belong to a lone variable.
    if len(signatures) == 1:
        return [parse_axis_one_variable(axis, signatures[0])]

    raise ValueError((axis, gufunc_sig))

# Each entry is (gufunc signature, axis argument, expected canonical axis).
test_cases = [
    # a single variable with the signature '(n)'
    ('(n)', 0, [[0]]),
    ('(n)', (0,), [[0]]),
    ('(n)', ((0,),), [[0]]),
    ('(n)', (0, 1), [[0, 1]]),
    # an extra '()' variable, which the signature parsing ignores
    ('(n)()', (0, 1), [[0, 1]]),
    ('()(n)', (0, 1), [[0, 1]]),
    # two non-empty variables, written without a separating comma
    ('(n)(m)', (0, 1), [[0], [1]]),
    ('(n)(m)', ((0,), (1,)), [[0], [1]]),
    # a single variable with the signature '(n,m)'
    ('(n,m)', (0, 1), [[0, 1]]),
    ('(n,m)', ((0, 1),), [[0, 1]]),
    # two variables, mixing integer and tuple entries
    ('(m,n),(n,p)', (0, 1), [[0], [1]]),
    ('(m,n),(n,p)', ((0, 1), 2), [[0, 1], [2]]),
    ('(m,n),(n,p)', (0, (1, 2)), [[0], [1, 2]]),
    ('(m,n),(n,p)', ((0, 1), (2, 3)), [[0, 1], [2, 3]]),
]

# Check every example against its expected canonical form; on failure the
# assertion message shows the inputs alongside what was actually produced.
for sig, axis, canonical in test_cases:
    actual = parse_axis(sig, axis)
    assert actual == canonical, (sig, axis, actual)
In [2]:
import pandas as pd
# Tabulate the examples. The parenthesized single-argument form of print
# behaves the same on Python 2 and Python 3, whereas the bare print statement
# is a syntax error on Python 3.
print(pd.DataFrame(test_cases, columns=['gufunc_sig', 'input axis', 'canoncial axis']))
     gufunc_sig        input axis    canoncial axis
0           (n)                 0             [[0]]
1           (n)              (0,)             [[0]]
2           (n)           ((0,),)             [[0]]
3           (n)            (0, 1)          [[0, 1]]
4         (n)()            (0, 1)          [[0, 1]]
5         ()(n)            (0, 1)          [[0, 1]]
6        (n)(m)            (0, 1)        [[0], [1]]
7        (n)(m)      ((0,), (1,))        [[0], [1]]
8         (n,m)            (0, 1)          [[0, 1]]
9         (n,m)         ((0, 1),)          [[0, 1]]
10  (m,n),(n,p)            (0, 1)        [[0], [1]]
11  (m,n),(n,p)       ((0, 1), 2)     [[0, 1], [2]]
12  (m,n),(n,p)       (0, (1, 2))     [[0], [1, 2]]
13  (m,n),(n,p)  ((0, 1), (2, 3))  [[0, 1], [2, 3]]

Notes:

  • For readability, the canonical form is printed as nested lists, though arguably tuples would make more sense.
  • This proof of concept does not parse the individual terms inside each variable's part of the gufunc signature, so there is no need to worry about optional axes here.