Draw a co-occurrence matrix with matplotlib.pyplot.matshow (a thin wrapper around imshow)
https://en.wikipedia.org/wiki/Co-occurrence_matrix
Created with the IPython Notebook in an Anaconda 1.9.1 environment
%pylab inline --no-import-all
# import numpy as np
# import matplotlib.pyplot as plt
# import pylab
import collections
from pprint import pformat
Populating the interactive namespace from numpy and matplotlib
# Sample input for generate_data: each entry is (categories, count), meaning the
# given tuple of categories is observed `count` times.
DATA_CONSTRAINTS = (
    (('english',), 5),
    (('math',), 7),
    (('social studies',), 3),
    (('science',), 3),
    (('english', 'science'), 1),
    (('english', 'social studies'), 1),
    (('math', 'science'), 2),
    (('science', 'social studies'), 1),
)
def generate_data(constraints):
    """
    Expand (categories, count) constraints into a flat list of category tuples.

    Args:
        constraints: iterable of (categories_tuple, count_int) pairs. Each
            categories_tuple is sorted alphabetically and emitted count_int
            times, in the order the constraints are given.
    Returns:
        list of sorted category tuples whose per-tuple counts reproduce
        ``constraints`` exactly.
    Raises:
        ValueError: if the expanded data does not reproduce the constraints,
            e.g. a count below 1, or the same category set listed twice
            (possibly in a different order).
    """
    def _standardize_constraints(constraints):
        # Sort each category tuple so ('b', 'a') and ('a', 'b') compare equal.
        return [(tuple(sorted(c[0])), c[1]) for c in constraints]

    def fits_constraints(data, constraints):
        counts = collections.Counter(tuple(sorted(elem)) for elem in data)
        # .items() (not the Python-2-only .iteritems()) works on 2 and 3.
        return sorted(counts.items()) == sorted(_standardize_constraints(constraints))

    def _expand(standardized):
        # Named distinctly so it does not shadow the enclosing generate_data.
        for categories, count in standardized:
            for _ in range(count):
                yield categories

    # Materialize once: constraints is read twice (expand + verify), which
    # would silently fail for a one-shot iterator.
    constraints = list(constraints)
    data = list(_expand(_standardize_constraints(constraints)))
    if not fits_constraints(data, constraints):
        raise ValueError(
            "constraints cannot be reproduced (count < 1 or duplicate "
            "category set?): %r" % (constraints,))
    return data
# One tuple of categories per observation, expanded from the constraint table.
data = generate_data(constraints=DATA_CONSTRAINTS)
data  # trailing expression: the notebook displays the list
[('english',), ('english',), ('english',), ('english',), ('english',), ('math',), ('math',), ('math',), ('math',), ('math',), ('math',), ('math',), ('social studies',), ('social studies',), ('social studies',), ('science',), ('science',), ('science',), ('english', 'science'), ('english', 'social studies'), ('math', 'science'), ('math', 'science'), ('science', 'social studies')]
def iter_adjacencies(data, all_pairs=False):
    """
    Yield the co-occurrence edges found in each row of ``data``.

    For every row, first yield a self edge (category, category) for each
    category in the row, then the edges between the row's categories.

    Args:
        data: iterable of category sequences (one per observation).
        all_pairs: if False (the default, original behavior), only
            consecutive categories are paired: (row[i], row[i+1]). If True,
            every pair (row[i], row[j]) with i < j is yielded, which a full
            co-occurrence count needs for rows with more than two categories.
            Both modes give identical output for rows of length <= 2.
    Returns:
        iterator of ((row_index, row), (category_x, category_y)) pairs,
        including self edges.
    """
    import itertools

    for row_n, row in enumerate(data):
        key = (row_n, row)
        for category in row:
            yield key, (category, category)
        if all_pairs:
            pairs = itertools.combinations(row, 2)
        else:
            # (row[0], row[1]), (row[1], row[2]), ... ; empty for len(row) <= 1.
            pairs = zip(row, row[1:])
        for pair in pairs:
            yield key, pair
# Materialize the (row, (category_x, category_y)) edges so they can be shown and reused.
adj_list = list(iter_adjacencies(data=data))
adj_list  # trailing expression: the notebook displays the edge list
[((0, ('english',)), ('english', 'english')), ((1, ('english',)), ('english', 'english')), ((2, ('english',)), ('english', 'english')), ((3, ('english',)), ('english', 'english')), ((4, ('english',)), ('english', 'english')), ((5, ('math',)), ('math', 'math')), ((6, ('math',)), ('math', 'math')), ((7, ('math',)), ('math', 'math')), ((8, ('math',)), ('math', 'math')), ((9, ('math',)), ('math', 'math')), ((10, ('math',)), ('math', 'math')), ((11, ('math',)), ('math', 'math')), ((12, ('social studies',)), ('social studies', 'social studies')), ((13, ('social studies',)), ('social studies', 'social studies')), ((14, ('social studies',)), ('social studies', 'social studies')), ((15, ('science',)), ('science', 'science')), ((16, ('science',)), ('science', 'science')), ((17, ('science',)), ('science', 'science')), ((18, ('english', 'science')), ('english', 'english')), ((18, ('english', 'science')), ('science', 'science')), ((18, ('english', 'science')), ('english', 'science')), ((19, ('english', 'social studies')), ('english', 'english')), ((19, ('english', 'social studies')), ('social studies', 'social studies')), ((19, ('english', 'social studies')), ('english', 'social studies')), ((20, ('math', 'science')), ('math', 'math')), ((20, ('math', 'science')), ('science', 'science')), ((20, ('math', 'science')), ('math', 'science')), ((21, ('math', 'science')), ('math', 'math')), ((21, ('math', 'science')), ('science', 'science')), ((21, ('math', 'science')), ('math', 'science')), ((22, ('science', 'social studies')), ('science', 'science')), ((22, ('science', 'social studies')), ('social studies', 'social studies')), ((22, ('science', 'social studies')), ('science', 'social studies'))]
def build_array_from_adj_list(data, adj_list):
    """
    Build a symmetric co-occurrence count matrix from an adjacency list.

    Prints notebook diagnostics: the per-row counts of ``data``, the
    category -> index mapping, the matrix shape, the matrix itself, and the
    per-category totals.

    Args:
        data: iterable of category tuples (one per observation). Every
            category appearing in it gets a row/column, in sorted order.
        adj_list: iterable of (row, (category_x, category_y)) pairs, as
            produced by iter_adjacencies.
    Returns:
        (adj, categories): ``adj`` is a float numpy array of shape
        (n_categories, n_categories); ``categories`` is an OrderedDict
        mapping category -> row/column index.
    Raises:
        KeyError: if adj_list references a category that is not in data.
        Exception: if row and column totals differ (a sanity check; the
            matrix is built symmetric, so this should not happen).
    """
    data = list(data)  # iterated twice below; a generator would be exhausted
    # list() so Python 3 prints the pairs rather than a dict_items view.
    print(pformat(list(collections.Counter(data).items())))
    categories = collections.OrderedDict(
        (category, index)
        for index, category in enumerate(sorted({item for row in data for item in row})))
    print("Indices: %s" % categories)
    adjacency_dimensions = len(categories), len(categories)
    print(adjacency_dimensions)
    adj = np.zeros(adjacency_dimensions)
    for _, pair in adj_list:
        # Plain [] lookup, not .get(): .get() returns None for an unknown
        # category, and numpy treats None as np.newaxis, so the increment
        # would silently hit every cell instead of failing.
        x, y = categories[pair[0]], categories[pair[1]]
        adj[x, y] += 1
        if x != y:
            # Off-diagonal co-occurrence is undirected: mirror the count.
            adj[y, x] += 1
    print(adj)
    subtotal_0 = adj.sum(axis=0)
    subtotal_1 = adj.sum(axis=1)
    if not np.array_equal(subtotal_0, subtotal_1):
        raise Exception("Should be the same")
    # list() so Python 3 prints the pairs rather than a zip object.
    totals = list(zip(categories.keys(), subtotal_0))
    print("Totals: %s" % totals)
    return adj, categories
# Count co-occurrences into a square matrix and keep the category -> index map.
adj, categories = build_array_from_adj_list(data=data, adj_list=adj_list)
[(('english',), 5), (('math', 'science'), 2), (('math',), 7), (('science',), 3), (('english', 'science'), 1), (('science', 'social studies'), 1), (('social studies',), 3), (('english', 'social studies'), 1)] Indices: OrderedDict([('english', 0), ('math', 1), ('science', 2), ('social studies', 3)]) (4, 4) [[ 7. 0. 1. 1.] [ 0. 9. 2. 0.] [ 1. 2. 7. 1.] [ 1. 0. 1. 5.]] Totals: [('english', 9.0), ('math', 11.0), ('science', 11.0), ('social studies', 7.0)]
def draw_co_ocurrence_diagram(adj, categories, figsize=(4,5)):
    """
    Render the co-occurrence matrix ``adj`` as a grey-scale image with labels.

    Side effect: writes ``figsize`` into the global rcParams['figure.figsize'],
    so figures created afterwards inherit it as well.
    NOTE(review): plt.matshow sizes its new figure via figaspect(adj), which
    appears to use only the height from rcParams['figure.figsize']; confirm
    against the installed matplotlib version.

    Args:
        adj: square array of counts, one row/column per category.
        categories: mapping whose keys, in iteration order, label rows and
            columns 0..n-1.
        figsize: (width, height) stored into rcParams['figure.figsize'].
    """
    plt.rcParams['figure.figsize'] = figsize
    plt.matshow(adj, cmap="Greys")
    positions = np.arange(len(categories))
    labels = categories.keys()
    plt.xticks(positions, labels, rotation=90)
    plt.yticks(positions, labels)
    plt.colorbar(orientation='horizontal')
draw_co_ocurrence_diagram(adj, categories, figsize=(4, 5))