import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
Define the test statistics and test them on examples from the lectures.
def pearson(counts):
    """Return Pearson's chi-squared statistic for a contingency table.

    ``counts`` is any array-like of observed cell counts. The expected
    counts under independence are obtained from the table's marginals via
    ``scipy.stats.contingency.expected_freq``.
    """
    observed = np.asarray(counts)
    expected = scipy.stats.contingency.expected_freq(observed)
    residuals = observed - expected
    return np.sum(residuals**2 / expected)


# Check against a worked example from the lectures.
assert np.allclose(
    pearson([[17066, 14464, 788, 126, 37], [48, 38, 5, 1, 1]]),
    12.0821,
    atol=0.0001,
)
# TODO: This Wald statistic only handles 2x2 tables, but the lecture segment uses a 3x3 table.
def wald_2x2(counts):
    """Return the pooled two-proportion z statistic for a 2x2 table.

    Each column is treated as one group: the first row holds the counts
    whose proportions are compared and the column sums are the group sizes.
    """
    counts = np.asarray(counts)
    hits_a, hits_b = counts[0, 0], counts[0, 1]
    size_a, size_b = counts.sum(axis=0)

    # Explicit float() keeps true division even without
    # `from __future__ import division`.
    rate_a = hits_a / float(size_a)
    rate_b = hits_b / float(size_b)
    pooled_rate = (hits_a + hits_b) / float(size_a + size_b)

    pooled_variance = pooled_rate * (1 - pooled_rate)
    inverse_sizes = (1.0 / size_a) + (1.0 / size_b)
    standard_error = np.sqrt(pooled_variance * inverse_sizes)
    return (rate_a - rate_b) / standard_error


# Check against a worked example from the lectures.
assert np.allclose(wald_2x2([[8, 3], [16, 26]]), 2.0542, atol=0.0001)
Code for permuting tables and collecting test statistics from the permutations.
def expand(table):
    """Expand a 2-D table of counts into one ``[row, col]`` pair per observation.

    Args:
        table: 2-D array-like of non-negative integer counts.

    Returns:
        Integer array of shape ``(table.sum(), 2)``. Pairs appear in row-major
        cell order, each repeated as many times as the cell's count. An
        all-zero table yields an empty ``(0, 2)`` array.
    """
    table = np.asarray(table)
    items = []
    # `range` (not the Python 2-only `xrange`) runs on both Python 2 and 3.
    for row in range(table.shape[0]):
        for col in range(table.shape[1]):
            items.extend([[row, col]] * table[row, col])
    # The explicit dtype and reshape keep the (n, 2) integer layout even when
    # `items` is empty (a bare np.array([]) would be float with shape (0,)).
    expanded = np.array(items, dtype=int).reshape(-1, 2)
    assert expanded.shape == (table.sum(), 2)
    return expanded
def shuffle(expanded_table):
    """Randomly permute the second column of ``expanded_table`` in place.

    The first column is left untouched. The same (mutated) array object is
    returned so that calls can be chained.
    """
    second_column = expanded_table[:, 1]
    expanded_table[:, 1] = np.random.permutation(second_column)
    return expanded_table
def rebuild(expanded_table, shape=None):
    """Count ``[row, col]`` pairs back into a 2-D contingency table.

    Args:
        expanded_table: Integer array of shape ``(n, 2)``, one ``[row, col]``
            pair per observation.
        shape: Optional ``(rows, cols)`` of the resulting table. When omitted
            the shape is inferred as ``max index + 1`` along each column of
            ``expanded_table``, which cannot see trailing rows or columns
            whose counts are all zero; pass ``shape`` to keep them.

    Returns:
        Integer array in which cell ``[r, c]`` holds the number of ``[r, c]``
        pairs in ``expanded_table``.

    Raises:
        ValueError: If ``expanded_table`` is empty and ``shape`` is not given.
        IndexError: If a pair lies outside an explicitly given ``shape``.
    """
    expanded_table = np.asarray(expanded_table)
    if shape is None:
        if expanded_table.size == 0:
            raise ValueError(
                "cannot infer the table shape from an empty expansion; "
                "pass `shape` explicitly"
            )
        shape = tuple(np.max(expanded_table, axis=0) + 1)
    table = np.zeros(shape, dtype=int)
    if expanded_table.size:
        # Unbuffered add: repeated index pairs each contribute +1, unlike
        # `table[rows, cols] += 1`, which would count duplicates only once.
        np.add.at(table, (expanded_table[:, 0], expanded_table[:, 1]), 1)
    assert expanded_table.shape == (table.sum(), 2)
    return table
def permute(table):
    """Return a randomly permuted copy of ``table``.

    The table is expanded into per-observation pairs, the column labels are
    shuffled, and the pairs are counted back into a table.
    """
    observations = expand(table)
    shuffled = shuffle(observations)
    return rebuild(shuffled)
def sanity_check():
    """Smoke-test the expand / shuffle / rebuild helpers on a 3x3 table."""
    table = np.arange(9).reshape(3, 3)

    # Expanding and rebuilding without shuffling must give the table back.
    round_trip = rebuild(expand(table))
    assert np.all(round_trip == table)

    # Shuffling must leave both the column and the row totals unchanged.
    expected_margins = np.array([table.sum(axis=0), table.sum(axis=1)])
    shuffled = permute(table)
    observed_margins = np.array([shuffled.sum(axis=0), shuffled.sum(axis=1)])
    assert np.all(expected_margins == observed_margins)


sanity_check()
def permutation_stats(table, num_permutations=100000, statistic_func=pearson):
    """Collect a test statistic over random permutations of ``table``.

    Args:
        table: 2-D array of counts to permute.
        num_permutations: Number of permuted tables to draw.
        statistic_func: Callable applied to each permuted table; defaults to
            ``pearson``.

    Returns:
        Array holding one statistic value per permutation.
    """
    # `range` (not the Python 2-only `xrange`) runs on both Python 2 and 3.
    stats = [statistic_func(permute(table)) for _ in range(num_permutations)]
    return np.array(stats)
# Observed 3x3 table of counts analysed in the cells below.
table = np.array([
    [5, 3, 2],
    [2, 3, 6],
    [0, 2, 3],
])
table
array([[5, 3, 2], [2, 3, 6], [0, 2, 3]])
pearson(table)
5.7559917355371892
# Distribution of the Pearson statistic over random permutations of the table.
stats = permutation_stats(table, statistic_func=pearson)
# `density=True` is the replacement for the `normed` keyword, which newer
# Matplotlib releases no longer accept.
plt.hist(stats, bins=np.arange(20), histtype='step', fill=True, density=True)
# Mark the statistic of the observed table on top of the permutation histogram.
plt.axvline(x=pearson(table), color='red', label='Pearson statistic for the table')
plt.title('Permutation Stats')
plt.legend()
<matplotlib.legend.Legend at 0xab0aaac>
# Permutation p-value: fraction of permuted tables whose statistic is at
# least as large as the observed one.
pval = (stats >= pearson(table)).mean()
pval
0.22777
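Not significant! The permutation p-value (about 0.23) is well above the conventional 0.05 threshold, so the table gives no evidence against independence of rows and columns.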