Using tricky key functions for sort/min/max in Python

At the moment I am rather enamoured with key functions in Python. The functions sorted, min and max take an optional key argument that, if specified, should be a function that is applied to each item in the iterable. The result of this key function is what is used when comparing items in the list to order them.

Here is an example simplified from my work. I have a list of 'sentences'. Each sentence is essentially a sentence of a forecast, describing a weather element or type, such as precipitation (precip), thunderstorms (TS), sky, fog (FG), large hail (H), gusty winds (w). Each sentence has a time range associated with it (TS may only be present in the late evening, for example). So what order should things be reported in?

  • In general, things should be reported in the order that they occur, so based on the start time. If the start times are equal, then compare the end times.
  • If the start and end times are equal, report them in the same order as a reference order (e.g. [sky, FG, precip, TS, w, H])
    • But also, make sure the TS are reported after the precip, if there is precip. (This is a bit of business logic: because the TS and precip are related phenomena and generally occur together, it could be jarring or absurd to those who know about such things to see TS mentioned before precip, even if the TS is being reported as starting earlier.)
  • Also also, some items like H only occur in the presence of other elements like TS. In this case their sentence will actually be something like Thunderstorms possibly severe in the afternoon with large hail which emphasises their relationship to the TS. In this case we need to mention the H directly after the TS regardless of the time ranges.

Complicated rules like this make it tempting to give up on using the built-in methods at all and just use a hand-crafted method which picks apart the elements and painstakingly pieces them back together according to requirements. However I would argue that is more error-prone, difficult to debug, hard to modify, and that with a couple of tricks we can keep using the built-in methods and still have a clue what is going on.

I will also add, don't overlook the benefit of using these with min and max, where you may want to find the "best" result amongst a set according to cascading criteria. I use this frequently at work in combination with exhaustive search (aka brute-force search or "generate and test"). If the combinatorics are not going to explode on you too badly, I think it's a good way of knowing you have arrived at the "best" result.

OK, so our data looks something like this:

In [99]:
from collections import namedtuple

class Sentence(namedtuple('Sentence', 'type start end words')):
    def __repr__(self):
        return 'Sentence({0.type}, {0.start}-{0.end})'.format(self)

priorityOrder = ['sky', 'FG', 'precip', 'TS', 'w', 'H']

precip = Sentence('precip', 12, 24, 'Isolated showers from midday, becoming widespread in the evening')
winds = Sentence('w', 12, 24, 'Gusty winds in the early afternoon')
ts = Sentence('TS', 9, 24, 'Scattered thunderstorms from the late morning')
hail = Sentence('H', 15, 21, 'Thunderstorms possibly severe in the afternoon with large hail')
sky = Sentence('sky', 0, 0, 'Partly cloudy')

sentences = set([winds, ts, hail, precip, sky])

expected = [sky, precip, ts, hail, winds]

OK so let's have a crack at a really simple key function. We want to sort on start time, then end time.

In [100]:
def byTime(sentence):
    return (sentence.start, sentence.end)

# Or more briefly:
byTime = lambda s: (s.start, s.end)

from pprint import pprint
pprint(sorted(sentences, key=byTime))
[Sentence(sky, 0-0),
 Sentence(TS, 9-24),
 Sentence(precip, 12-24),
 Sentence(w, 12-24),
 Sentence(H, 15-21)]

Already we are using the first feature of Python sorting - tuple sorting. Python sorts tuples by comparing all first elements, then if they are equal, second elements, and so on. Hence our winds (w) sentence is coming before the H, because 12 < 15. It seems pretty obvious, but this is the basis for making quite more complicated key functions.

However we haven't incorporated the requirement to sort by priority, if the time ranges are equal. It happens that 'w' is sorting after 'precip' in the above example but we haven't enforced it, which we can confirm directly:

In [101]:
byTime(precip) < byTime(winds)
Out[101]:
False

So let's add another element to our key function's return value to account for priority.

In [102]:
def byTimeAndPriority(sentence):
    priority = priorityOrder.index(sentence.type)
    return (sentence.start, sentence.end, priority)

pprint(sorted(sentences, key=byTimeAndPriority))

byTimeAndPriority(precip) < byTimeAndPriority(winds)
[Sentence(sky, 0-0),
 Sentence(TS, 9-24),
 Sentence(precip, 12-24),
 Sentence(w, 12-24),
 Sentence(H, 15-21)]

Out[102]:
True

Also to confirm that what is happening is what we expect, we can print out the value of calling the key function on each item in the iterable. This is super useful in debugging your sort and is a major benefit over the old method of influencing sort by writing a cmp method (that takes two items from an iterable, and returns -1/0/1 according to which should be ordered first). I can't praise this highly enough. My workmate recently rewrote a cmp function that was "mostly" right (read: I couldn't figure out how to fix it) into a key function that I have the highest confidence in - because of this exact reason.

In [103]:
for item in sorted(sentences, key=byTimeAndPriority):
    print byTimeAndPriority(item), "<=", item
(0, 0, 0) <= Sentence(sky, 0-0)
(9, 24, 3) <= Sentence(TS, 9-24)
(12, 24, 2) <= Sentence(precip, 12-24)
(12, 24, 4) <= Sentence(w, 12-24)
(15, 21, 5) <= Sentence(H, 15-21)

Now I have confidence in our sorting!

To return to the trickier requirements. The TS is being listed ahead of the precip. What can we do about that?

Well, maybe a useful thing to do is try and come up with the results of a key function that would give us what we want, and work backwards.

(0, 0, 0, ...), Sentence(sky, 0-0)
(12, 24, 2, ...) Sentence(precip, 12-24)
(?, ?, ?, ...) Sentence(TS, 9-24)
(?, ?, ?, ...) Sentence(H, 15-21)
(12, 24, 4...) Sentence(w, 12-24)

So the thing that jumps out here is that the TS and H need to have the same values as the precip, so that they are sorted next to each other. But in our key function we only have one element of the iterable at a time. If the TS Sentence is passed to the key function we wouldn't generally have access to information about other elements in the iterable. So what we can we do? Well making a closure can help us out...

In [104]:
def generateKeyFn(items):
    precipStart = None
    precipEnd = None
    precipPriority = priorityOrder.index('precip')
    precipSentences = [item for item in items if item.type == 'precip']
    if precipSentences:
        precipStart = min(item.start for item in precipSentences)
        precipEnd = max(item.end for item in precipSentences)
        
    mustFollowPrecip = ('TS', 'H')
    
    def keyFn(item):
        if precipStart and item.type in mustFollowPrecip:
            return (precipStart, precipEnd, precipPriority)
        priority = priorityOrder.index(item.type)
        return (item.start, item.end, priority)

    return keyFn


bySpecialTime = generateKeyFn(sentences)
for item in sorted(sentences, key=bySpecialTime):
    print bySpecialTime(item), "<=", item
(0, 0, 0) <= Sentence(sky, 0-0)
(12, 24, 2) <= Sentence(precip, 12-24)
(12, 24, 2) <= Sentence(H, 15-21)
(12, 24, 2) <= Sentence(TS, 9-24)
(12, 24, 4) <= Sentence(w, 12-24)

This is a good start! Our precip/TS/H are all together. But we need to add extra elements to our key function to sort them correctly.

In [105]:
def generateKeyFn(items):
    precipStart, precipEnd = None, None
    precipPriority = priorityOrder.index('precip')
    precipSentences = [item for item in items if item.type == 'precip']
    if precipSentences:
        precipStart = min(item.start for item in precipSentences)
        precipEnd = max(item.end for item in precipSentences)
    
    mustFollowPrecip = ('TS', 'H')
    
    def keyFn(item):
        start, end = item.start, item.end
        priority = priorityOrder.index(item.type)
        mustFollow = item.type in mustFollowPrecip
        if precipStart and mustFollow:
            start, end, priority = precipStart, precipEnd, precipPriority
        return (start, end, priority, mustFollow)

    return keyFn


bySpecialTime = generateKeyFn(sentences)
for item in sorted(sentences, key=bySpecialTime):
    print bySpecialTime(item), "<=", item
(0, 0, 0, False) <= Sentence(sky, 0-0)
(12, 24, 2, False) <= Sentence(precip, 12-24)
(12, 24, 2, True) <= Sentence(H, 15-21)
(12, 24, 2, True) <= Sentence(TS, 9-24)
(12, 24, 4, False) <= Sentence(w, 12-24)

Here we are relying on Boolean sorting - False < True. The value of this fourth field is irrelevant to the non-precip/TS/H sentences as these are already being sorted on the first three fields.

Now TS is following precip, but we also need the H to follow the TS.

In [106]:
def generateKeyFn(items):
    precipStart, precipEnd = None, None
    precipPriority = priorityOrder.index('precip')
    precipSentences = [item for item in items if item.type == 'precip']
    if precipSentences:
        precipStart = min(item.start for item in precipSentences)
        precipEnd = max(item.end for item in precipSentences)
    
    typesMustFollowPrecip = ('TS', 'H')
    typesMustFollowTS = ('H',)
    
    def keyFn(item):
        start, end = item.start, item.end
        priority = priorityOrder.index(item.type)
        mustFollowTS = item.type in typesMustFollowTS
        mustFollowPrecip = item.type in typesMustFollowPrecip
        if precipStart and mustFollowPrecip:
            start, end, priority = precipStart, precipEnd, precipPriority
        return (start, end, priority, mustFollowPrecip, mustFollowTS)

    return keyFn


bySpecialTime = generateKeyFn(sentences)
for item in sorted(sentences, key=bySpecialTime):
    print bySpecialTime(item), "<=", item

sorted(sentences, key=bySpecialTime) == expected
(0, 0, 0, False, False) <= Sentence(sky, 0-0)
(12, 24, 2, False, False) <= Sentence(precip, 12-24)
(12, 24, 2, True, False) <= Sentence(TS, 9-24)
(12, 24, 2, True, True) <= Sentence(H, 15-21)
(12, 24, 4, False, False) <= Sentence(w, 12-24)

Out[106]:
True

Success! The fifth field is forcing the H after the TS.

Now this key function is not totally correct. There is one immediate error. If we have TS and H but no precip, we haven't done anything to ensure that the H will still follow the precip. Witness:

In [107]:
sentencesNoPrecip = set([winds, ts, hail, sky])
sortKey = generateKeyFn(sentencesNoPrecip)
for item in sorted(sentencesNoPrecip, key=sortKey):
    print sortKey(item), "<=", item
(0, 0, 0, False, False) <= Sentence(sky, 0-0)
(9, 24, 3, True, False) <= Sentence(TS, 9-24)
(12, 24, 4, False, False) <= Sentence(w, 12-24)
(15, 21, 5, True, True) <= Sentence(H, 15-21)

Note that we had to generate a fresh key function because our input data was different.

Correcting generateKeyFn is left as an exercise to the reader. :) Here are some test cases, though. (Written for py.test, likely to work with nose although untested)

In [108]:
import py


def test_sortingTime():
    input = [
        Sentence('precip', 0, 12, 'rain in the morning'),
        Sentence('sky', 12, 24, 'sunny afternoon'),
        ]
    expected = ['precip', 'sky']
    sortKey = generateKeyFn(input)
    assert [s.type for s in sorted(input, key=sortKey)] == expected


def test_sortingPriority():
    input = [
        Sentence('sky', 0, 24, 'cloudy'),
        Sentence('precip', 0, 24, 'rain'),
        ]
    expected = ['precip', 'sky']
    sortKey = generateKeyFn(input)
    assert [s.type for s in sorted(input, key=sortKey)] == expected


def test_sortingTSAfterPrecip():
    input = [
        Sentence('TS', 0, 12, 'thunderstorms'),
        Sentence('precip', 12, 24, 'rain'),
        ]
    expected = ['precip', 'TS']
    sortKey = generateKeyFn(input)
    assert [s.type for s in sorted(input, key=sortKey)] == expected

    
def test_sortingHailAfterTSWithPrecip():
    input = [
        Sentence('precip', 12, 24, 'rain'),
        Sentence('TS', 0, 12, 'thunderstorms'),
        Sentence('H', 0, 6, 'hail'),
        ]
    expected = ['precip', 'TS', 'H']
    sortKey = generateKeyFn(input)
    assert [s.type for s in sorted(input, key=sortKey)] == expected


def test_sortingHailAfterTSNoPrecip():
    input = [
        Sentence('TS', 0, 12, 'thunderstorms'),
        Sentence('H', 0, 6, 'hail'),
        ]
    expected = ['TS', 'H']
    sortKey = generateKeyFn(input)
    assert [s.type for s in sorted(input, key=sortKey)] == expected

And as expected, this last test fails when I run the tests:

In [109]:
$py.test test_sort.py -v
=============================================== test session starts ===============================================
platform linux2 -- Python 2.7.5 -- pytest-2.3.5 -- /usr/bin/python
plugins: xdist, cov
collected 5 items 

test_sort.py:34: test_sortingTime PASSED
test_sort.py:44: test_sortingPriority PASSED
test_sort.py:54: test_sortingTSAfterPrecip PASSED
test_sort.py:64: test_sortingHailAfterTSWithPrecip PASSED
test_sort.py:75: test_sortingHailAfterTSNoPrecip FAILED

==================================================== FAILURES =====================================================
_________________________________________ test_sortingHailAfterTSNoPrecip _________________________________________
test_sort.py:82: in test_sortingHailAfterTSNoPrecip
>       assert [s.type for s in sorted(input, key=sortKey)] == expected
E       assert ['H', 'TS'] == ['TS', 'H']
E         At index 0 diff: 'H' != 'TS'
======================================= 1 failed, 4 passed in 0.02 seconds =======================================

Written with thanks to this post for a reminder to blog every now and then.

Reference reading: