class HtmlObject(object):
    """Thin wrapper that exposes a pre-rendered HTML string to rich displays.

    The wrapped markup is stored on the ``html`` attribute and handed back
    unchanged by ``_repr_html_``.
    """

    def __init__(self, html):
        """Remember the HTML markup to display."""
        self.html = html

    def _repr_html_(self):
        """Return the stored HTML markup as-is."""
        markup = self.html
        return markup
class ContingencyTable(object):
    """A generic two-way contingency table with a chi-squared test.

    The table stores row labels, column labels and a 2-D array of observed
    counts, and can render itself (or its expected counts) as an HTML table.
    """

    def __init__(self, row_names, col_names, counts):
        """Store the labels and the matrix of observed counts.

        Args:
            row_names: Sequence of labels, one per row of ``counts``.
            col_names: Sequence of labels, one per column of ``counts``.
            counts: 2-D array-like of observed counts.  It is passed through
                ``np.asarray`` so nested lists are accepted; an existing
                ndarray is stored without copying.

        Raises:
            AssertionError: If ``counts`` does not have shape
                ``(len(row_names), len(col_names))``.
        """
        counts = np.asarray(counts)
        assert counts.shape == (len(row_names), len(col_names))
        self.row_names = row_names
        self.col_names = col_names
        self.counts = counts

    def get_chisq_stat(self):
        """Compute the chi-squared test statistic.

        Returns:
            The sum over all cells of ``(observed - expected)**2 / expected``.
            Cells whose expected count is zero make the result NaN or inf.
        """
        expected = self.get_expected_counts()
        return np.sum((self.counts - expected) ** 2 / expected)

    def get_degrees_of_freedom(self):
        """Get the number of degrees of freedom for the chi-squared distribution.

        Returns:
            ``(rows - 1) * (columns - 1)`` as a plain ``int``.
        """
        n_rows, n_cols = self.counts.shape
        # Algebraically identical to rows*cols - (rows + cols) + 1.
        return (n_rows - 1) * (n_cols - 1)

    def get_expected_counts(self):
        """Get a matrix of the expected counts under the null hypothesis.

        Returns:
            A float array with the same shape as ``counts`` whose ``[i, j]``
            entry is ``row_total[i] * col_total[j] / grand_total``.
        """
        # Summing as float keeps the division a true division on Python 2.
        row_totals = self.counts.sum(axis=1, dtype=float)
        col_totals = self.counts.sum(axis=0, dtype=float)
        return np.outer(row_totals, col_totals) / row_totals.sum()

    def get_p_value(self):
        """Compute a p-value.

        Returns:
            The upper-tail probability of the chi-squared distribution at the
            observed statistic.
        """
        chisq_stat = self.get_chisq_stat()
        dof = self.get_degrees_of_freedom()
        # sf() is 1 - cdf() without the cancellation that rounds tiny
        # p-values to exactly 0.0.
        return scipy.stats.chi2.sf(chisq_stat, dof)

    def marginal(self, i, j):
        """Compute a marginal (sum of row/column) of the observed counts.

        Each argument can be either an index or a '.'; at least one of them
        must be '.'.  ``('.', '.')`` gives the grand total.
        """
        return self._marginal(i, j, self.counts)

    def to_null_hypothesis_html(self):
        """Get a table in HTML format showing the expected counts.

        Returns:
            An ``HtmlObject`` wrapping the rendered table (without the
            chi-squared summary line).
        """
        expected = self.get_expected_counts()
        return HtmlObject(
            self._to_html(expected, caption='Expected Counts', summary=False))

    def _marginal(self, i, j, cell_values):
        """Sum a row, a column, or all of ``cell_values``.

        ``i`` / ``j`` are either an index or '.', where '.' means "sum over
        this axis".  Asking for a single cell (neither is '.') is rejected
        by an assertion.
        """
        i_dot = i == '.'
        j_dot = j == '.'
        assert i_dot or j_dot
        if i_dot and j_dot:
            return cell_values.sum()
        if i_dot:
            return cell_values[:, j].sum()
        return cell_values[i].sum()

    def _to_html(self, cell_values, caption='Counts', summary=True):
        """Render ``cell_values`` as an HTML table with marginal totals.

        Args:
            cell_values: 2-D array indexed as ``cell_values[i, j]`` with one
                row per row name and one column per column name.
            caption: Text placed in the table's ``<caption>``.
            summary: When true, append a line with the chi-squared
                statistic, degrees of freedom and p-value of ``self.counts``.

        Returns:
            The markup as a single newline-joined string.
        """
        # NOTE(review): labels, caption and cell text are inserted without
        # HTML escaping -- confirm callers never pass untrusted strings.
        lines = ['<table>', '<caption>{0}</caption>'.format(caption)]

        def add_row(cells):
            lines.append('<tr>')
            for contents in cells:
                lines.append('<td>{0}</td>'.format(str(contents)))
            lines.append('</tr>')

        col_indices = range(len(self.col_names))

        # Header row: blank corner, column names, blank cell above the
        # row-total column.
        add_row([''] + list(self.col_names) + [''])

        # One row per label: name, cell values, then the row total.
        for i, row_name in enumerate(self.row_names):
            row = [row_name]
            row.extend(cell_values[i, j] for j in col_indices)
            row.append(self._marginal(i, '.', cell_values))
            add_row(row)

        # Footer row: blank label, column totals, then the grand total.
        footer = ['']
        footer.extend(self._marginal('.', j, cell_values)
                      for j in col_indices)
        footer.append(self._marginal('.', '.', cell_values))
        add_row(footer)

        lines.append('</table>')

        if summary:
            chi2 = self.get_chisq_stat()
            df = self.get_degrees_of_freedom()
            # Only exactly one degree of freedom takes the singular form.
            s = '' if df == 1 else 's'
            pval = self.get_p_value()
            fmt = "$\\chi^2 = {0}$. {1} degree{2} of freedom. P-value {3}."
            lines.append(fmt.format(chi2, df, s, pval))
        return '\n'.join(lines)

    def _repr_html_(self):
        """Return the observed counts as an HTML table with the test summary."""
        return self._to_html(self.counts)