The goal of this activity was to speed up the permutation test. I wrote two different implementations. The "naive" implementation uses the method presented in the segment: it expands the $N$ counts in the table into an $N \times 2$ matrix of (row, column) labels, shuffles the second column, and then reassembles the result into a new table. The "hypergeometric" implementation fills in the cells of a new table one at a time by sampling from the hypergeometric distribution, given the row and column totals that remain to be placed. It relies on numpy's hypergeometric sampler, whose implementation (and the sampling algorithm behind it) I consider questionable.
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
# Seed numpy's global random generator. The permutation code below draws
# from it (np.random.permutation / np.random.hypergeometric), so re-running
# the notebook top to bottom gives the same permutations.
np.random.seed(42)
# In class R/Python/Matlab A/B/C example.
# NOTE(review): presumably rows are languages and columns are grades -- confirm.
language_grade_table = np.array([[4, 1, 14],
                                 [10, 6, 14],
                                 [51, 35, 60]])
# Display the table (notebook cell output).
language_grade_table
array([[ 4, 1, 14], [10, 6, 14], [51, 35, 60]])
# Drinking and pregnancy.
# A 2 x 5 table with much larger counts than the language/grade table; it is
# used below to show how each implementation scales with the counts.
alcohol_table = np.array([[17066, 14464, 788, 126, 37],
                          [   48,    38,   5,   1,  1]])
# Display the table (notebook cell output).
alcohol_table
array([[17066, 14464, 788, 126, 37], [ 48, 38, 5, 1, 1]])
class BasePermutationTest(object):
    """Base class for permutation tests on a contingency table.

    The test statistic is Pearson's chi-squared statistic. Subclasses need to
    override the `permute` method, which returns a random table with the same
    row and column marginals as `self.table`.

    Attributes:
        table: private copy of the observed table (a 2-D numpy array).
        table_stat: Pearson statistic of the observed table.
        permutation_stats: numpy array of Pearson statistics of permuted
            tables, or None until `compute_permutation_stats` is called.
    """

    def __init__(self, table):
        # np.array always copies, so later changes to the caller's table do not
        # leak into the test; it also accepts nested lists.
        self.table = np.array(table)
        self.table_stat = self.pearson(self.table)
        self.permutation_stats = None

    def compute_permutation_stats(self, num_stats=10000):
        """Compute `num_stats` Pearson stats from random permutations.

        The results are stored in `self.permutation_stats`.
        """
        # `range` (not `xrange`) so the code also runs on Python 3.
        self.permutation_stats = np.array([self.pearson(self.permute())
                                           for _ in range(num_stats)])

    def compute_pvalue(self):
        """Return the permutation p-value.

        This is the fraction of permuted tables whose Pearson statistic is at
        least as large as that of the observed table.

        Raises:
            RuntimeError: if `compute_permutation_stats` has not been called.
        """
        if self.permutation_stats is None:
            # Fail loudly instead of comparing None against the statistic,
            # which does not produce a meaningful p-value.
            raise RuntimeError(
                'call compute_permutation_stats() before compute_pvalue()')
        return np.mean(self.permutation_stats >= self.table_stat)

    def pearson(self, table):
        """Return Pearson's chi-squared statistic of `table`.

        Expected counts come from the table's own marginals, so a row or
        column whose total is zero yields nan.
        """
        expected = scipy.stats.contingency.expected_freq(table)
        return np.sum((table - expected)**2 / expected)

    def permute(self):
        """Return a random permutation of the table (marginals preserved)."""
        raise NotImplementedError()
class NaivePermutationTest(BasePermutationTest):
    """Naive/slow implementation of permutation tests.

    Each permutation expands the N counts of the table into an N x 2 matrix
    of (row, col) labels, shuffles the column labels, and counts the shuffled
    pairs back into a table. The per-item Python loops are kept as a slow
    baseline; this could easily be optimized a bit with matrix operations.
    """

    def permute(self):
        """Return a random table with the same marginals as `self.table`."""
        expanded = self._expand(self.table)
        # Rebuild with the original shape: inferring it from the labels would
        # drop trailing rows/columns whose counts are all zero.
        return self._rebuild(self._shuffle(expanded), shape=self.table.shape)

    def _expand(self, table):
        """Expand all the counts into an (N, 2) matrix of [row, col] labels."""
        items = []
        for row in range(table.shape[0]):
            for col in range(table.shape[1]):
                items.extend([[row, col]] * table[row, col])
        # reshape keeps the result 2-D even when the table is empty (N == 0).
        return np.array(items, dtype=int).reshape(-1, 2)

    def _shuffle(self, expanded_table):
        """Shuffle the second (column-label) column in place and return it."""
        expanded_table[:, 1] = np.random.permutation(expanded_table[:, 1])
        return expanded_table

    def _rebuild(self, expanded_table, shape=None):
        """Rebuild an (N, 2) label matrix into a normal table of counts.

        Args:
            expanded_table: (N, 2) int array of [row, col] labels.
            shape: shape of the output table. If None, it is inferred as
                max label + 1 along each axis (which cannot recover trailing
                empty rows/columns).
        """
        if shape is None:
            shape = np.max(expanded_table, axis=0) + 1
        table = np.zeros(shape, dtype=int)
        for item in expanded_table:
            table[tuple(item)] += 1
        return table
class HypergeometricPermutationTest(BasePermutationTest):
    """Hypergeometric implementation of permutation tests.

    Rather than shuffling individual items, permute() fills in a new table
    cell by cell, drawing each cell from the hypergeometric distribution
    given the row and column totals that are still unplaced. The number of
    sampling calls per permutation depends only on the table's dimensions,
    not on the size of its counts. See the comments in permute() for details.
    """

    def __init__(self, table, debug=False):
        """Create the test for `table`.

        Args:
            table: the observed contingency table of counts.
            debug: if True, permute() prints every intermediate state.
        """
        super(HypergeometricPermutationTest, self).__init__(table)
        self.debug = debug

    def permute(self):
        """Return a random table with the same row and column marginals.

        The (nrow - 1) x (ncol - 1) top-left cells are sampled one at a time;
        the last row and the last column are then forced by the marginals.
        """
        table = self.table
        nrow, ncol = table.shape
        self._debug('Original table:')
        self._debug(table)
        # `permuted` starts out as a matrix of zeros with the same
        # dimensions as the original table. `row_counts` and `col_counts`
        # start off containing the row and column marginals. For example,
        # if the original table is:
        #      4  1 14
        #     10  6 14
        #     51 35 60
        #
        # then we would start out as:
        #
        #      0  0  0 |  19
        #      0  0  0 |  30
        #      0  0  0 | 146
        #     -------------
        #     65 42 88
        #
        # As counts are placed in `permuted`, they are subtracted from
        # the row and column counts in order to keep track of how many
        # items still need to be placed in the table.
        permuted = np.zeros((nrow, ncol), dtype=int)
        col_counts = np.sum(table, axis=0)
        row_counts = np.sum(table, axis=1)

        def print_state():
            # Print the current table with the remaining row counts as an
            # extra column and the remaining column counts as an extra row.
            if self.debug:
                m = np.zeros((nrow + 1, ncol + 1), dtype=int)
                m[:-1, :-1] = permuted
                m[-1, :ncol] = col_counts
                m[:nrow, -1] = row_counts
                self._debug(m)

        self._debug('Original state:')
        print_state()
        # With fixed marginals, the table has f = (nrow - 1) x (ncol - 1)
        # degrees of freedom. Loop through the f cells at the top left of
        # the table and fill them in by sampling from the hypergeometric
        # distribution.
        for row in range(nrow - 1):
            for col in range(ncol - 1):
                self._debug('Filling in ({0},{1})'.format(row, col))
                # Row `row` still has n_row items to spread over columns
                # col..ncol-1. Of the slots those columns have left for rows
                # row..nrow-1, n_col are in column `col` and n_other_cols are
                # in the columns to its right.
                n_col = col_counts[col]
                # The last column is only decremented after this loop, but the
                # leftovers of the earlier rows (row_counts[:row]) are already
                # committed to it, so they must not be counted as available.
                # Without this the population is too large, the sample is
                # biased, and cells can even come out negative.
                n_other_cols = (col_counts[(col + 1):].sum()
                                - row_counts[:row].sum())
                n_row = row_counts[row]
                if n_row > 0:
                    self._debug('hyper({}, {}, {})'.format(n_col, n_other_cols, n_row))
                    n_cell = np.random.hypergeometric(n_col, n_other_cols, n_row)
                else:
                    # This row is already full.
                    n_cell = 0
                # Fill in the cell and update the counts.
                permuted[row, col] = n_cell
                row_counts[row] -= n_cell
                col_counts[col] -= n_cell
                print_state()
        self._debug('Filling in the first {0} values in the last row.'.format(ncol - 1))
        # Whatever is left in columns 0..ncol-2 must go in the last row.
        permuted[-1, :-1] += col_counts[:-1]
        row_counts[-1] -= col_counts[:-1].sum()
        col_counts[:-1] = 0
        print_state()
        self._debug('Filling in the first {0} values in the last column.'.format(nrow - 1))
        # Whatever is left in rows 0..nrow-2 must go in the last column.
        permuted[:-1, -1] += row_counts[:-1]
        col_counts[-1] -= row_counts[:-1].sum()
        row_counts[:-1] = 0
        print_state()
        self._debug('Filling in the final cell at the bottom right corner')
        # Both remaining counts now refer to the single bottom-right cell.
        assert row_counts[-1] == col_counts[-1]
        permuted[-1, -1] = col_counts[-1]
        row_counts[-1] = 0
        col_counts[-1] = 0
        print_state()
        return permuted

    def _debug(self, message):
        """Print `message` when debugging is enabled."""
        if self.debug:
            # Parenthesized form prints identically on Python 2 and 3.
            print(message)
def sanity_check_hyper(table, num_trials=100):
    """Do some sanity checks of the hypergeometric permutation maker.

    Draws `num_trials` permutations of `table` and checks that each one has
    the same shape as `table`, contains no negative counts, and preserves
    the row and column marginals.

    Raises:
        AssertionError: if any check fails (checks use `assert`, so they are
            skipped when Python runs with -O).
    """
    table = np.asarray(table)
    col_marginals = np.sum(table, axis=0)
    row_marginals = np.sum(table, axis=1)
    hyper = HypergeometricPermutationTest(table)
    for _ in range(num_trials):
        permuted = hyper.permute()
        assert permuted.shape == table.shape
        assert np.all(permuted >= 0)
        # array_equal also returns False (instead of broadcasting or raising)
        # if the marginal vectors differ in length.
        assert np.array_equal(np.sum(permuted, axis=0), col_marginals)
        assert np.array_equal(np.sum(permuted, axis=1), row_marginals)
# Sanity-check the hypergeometric sampler on a square table, a non-square
# table, and both example tables.
sanity_check_hyper(np.arange(9).reshape(3, 3))
sanity_check_hyper(np.arange(8).reshape(2, 4))
sanity_check_hyper(language_grade_table)
sanity_check_hyper(alcohol_table)
# Let's see one example hypergeometric permutation in detail.
# debug=True prints the table and remaining row/column counts after each step.
HypergeometricPermutationTest(language_grade_table, debug=True).permute()
Original table: [[ 4 1 14] [10 6 14] [51 35 60]] Original state: [[ 0 0 0 19] [ 0 0 0 30] [ 0 0 0 146] [ 65 42 88 0]] Filling in (0,0) hyper(65, 130, 19) [[ 4 0 0 15] [ 0 0 0 30] [ 0 0 0 146] [ 61 42 88 0]] Filling in (0,1) hyper(42, 88, 15) [[ 4 8 0 7] [ 0 0 0 30] [ 0 0 0 146] [ 61 34 88 0]] Filling in (1,0) hyper(61, 122, 30) [[ 4 8 0 7] [ 6 0 0 24] [ 0 0 0 146] [ 55 34 88 0]] Filling in (1,1) hyper(34, 88, 24) [[ 4 8 0 7] [ 6 9 0 15] [ 0 0 0 146] [ 55 25 88 0]] Filling in the first 2 values in the last row. [[ 4 8 0 7] [ 6 9 0 15] [55 25 0 66] [ 0 0 88 0]] Filling in the first 2 values in the last column. [[ 4 8 7 0] [ 6 9 15 0] [55 25 0 66] [ 0 0 66 0]] Filling in the final cell at the bottom right corner [[ 4 8 7 0] [ 6 9 15 0] [55 25 66 0] [ 0 0 0 0]]
array([[ 4, 8, 7], [ 6, 9, 15], [55, 25, 66]])
# And try a full permutation test with the hypergeometric implementation.
hyper = HypergeometricPermutationTest(language_grade_table)
hyper.compute_permutation_stats(num_stats=10000)
print 'p-val: {0}'.format(hyper.compute_pvalue())
# Histogram of the Pearson statistics of the 10000 permuted tables.
plt.hist(hyper.permutation_stats, bins=50)
plt.title('Distribution of pearson statistics')
p-val: 0.1039
<matplotlib.text.Text at 0xb21d52c>
def compare(perm_class_1, perm_class_2, table, num_stats=10000):
    """Time two permutation-test implementations on the same table.

    For each of `perm_class_1` and `perm_class_2` (classes constructed as
    `cls(table)`), print the class name, time
    `compute_permutation_stats(num_stats=num_stats)`, and print the
    resulting p-value.

    Note: the `%time` line is an IPython line magic, not Python syntax, so
    this function only runs inside IPython/Jupyter.
    """
    print 'Doing {0} permutations'.format(num_stats)
    print
    for cls in (perm_class_1, perm_class_2):
        perm_test = cls(table)
        print '{0}: '.format(perm_test.__class__.__name__)
        %time perm_test.compute_permutation_stats(num_stats=num_stats)
        print 'pval: {0}'.format(perm_test.compute_pvalue())
        print
First, let's try 10,000 permutations on the language/grade table.
language_grade_table
array([[ 4, 1, 14], [10, 6, 14], [51, 35, 60]])
# Uses compare's default of num_stats=10000 permutations per implementation.
compare(NaivePermutationTest, HypergeometricPermutationTest, language_grade_table)
Doing 10000 permutations NaivePermutationTest: CPU times: user 9.78 s, sys: 0 ns, total: 9.78 s Wall time: 9.86 s pval: 0.1016 HypergeometricPermutationTest: CPU times: user 3.12 s, sys: 0 ns, total: 3.12 s Wall time: 3.13 s pval: 0.1023
The hypergeometric method runs about 3 times faster on this table (about 3.1 s vs. 9.8 s). Decent, but not that impressive, especially considering that the naive implementation could easily be optimized.
The real benefits of the hypergeometric method start to show when the counts in the table get larger. The naive implementation takes significantly longer (and more memory) because it expands every item in the table. The hypergeometric implementation, on the other hand, doesn't slow down with more counts in the table. To see this, let's try doubling the counts in the table.
# Same dimensions as before, but every count is doubled.
language_grade_table * 2
array([[  8,   2,  28], [ 20,  12,  28], [102,  70, 120]])
compare(NaivePermutationTest, HypergeometricPermutationTest, language_grade_table * 2)
Doing 10000 permutations NaivePermutationTest: CPU times: user 17.6 s, sys: 4 ms, total: 17.6 s Wall time: 17.7 s pval: 0.0028 HypergeometricPermutationTest: CPU times: user 3.1 s, sys: 0 ns, total: 3.1 s Wall time: 3.1 s pval: 0.0055
Now let's try things out on the alcohol table. This should be interesting since the counts are so much larger.
alcohol_table
array([[17066, 14464,   788,   126,    37], [   48,    38,     5,     1,     1]])
# Only 1000 permutations here, to keep the naive implementation's run time
# manageable on these large counts.
compare(NaivePermutationTest, HypergeometricPermutationTest, alcohol_table, num_stats=1000)
Doing 1000 permutations NaivePermutationTest: CPU times: user 2min 12s, sys: 184 ms, total: 2min 12s Wall time: 2min 12s pval: 0.027 HypergeometricPermutationTest: CPU times: user 324 ms, sys: 0 ns, total: 324 ms Wall time: 325 ms pval: 0.026
The naive method takes over two minutes for 1,000 permutations, while the hypergeometric method completes in well under a second.
The hypergeometric method can handle 100,000 permutations without too much trouble:
hyper = HypergeometricPermutationTest(alcohol_table)
# `%time` is an IPython line magic (notebook-only syntax).
%time hyper.compute_permutation_stats(num_stats=100000)
print 'p-val: {0}'.format(hyper.compute_pvalue())
plt.hist(hyper.permutation_stats, bins=50)
# Show only statistics in [0, 50]; any larger values are off-screen.
plt.xlim(0, 50)
plt.title('Distribution of pearson statistics')
CPU times: user 32.8 s, sys: 12 ms, total: 32.8 s Wall time: 32.9 s p-val: 0.03467
<matplotlib.text.Text at 0xb3616ac>