```python
# Utils
from IPython.display import Image
import matplotlib.pyplot as plt
import numpy as np
from numpy.random import multinomial
from numpy.random import dirichlet
%matplotlib inline
```

```python
# display the LDA graphical model
lda_pm = Image(filename='images/lda_pm.png')
lda_pm
```

```python
# 25-word vocabulary
W = 25
# image side length: each "word" is a pixel of an L x L image
L = int(np.sqrt(W))
# 10 topics
T = 10
# 100 documents
D = 100
# 100 words per document
N = 100
```

```python
# phi is given as the horizontal and vertical topics on the 5x5 images:
# the first L topics are the rows, the remaining L topics are the columns
phi = [np.zeros((L, L)) for i in range(T)]
line = 0
for phi_t in phi:
    if line >= L:
        trueLine = line - L
        phi_t[:, trueLine] = 1 / L * np.ones(L)
    else:
        phi_t[line] = 1 / L * np.ones(L)
    line += 1
```

```python
# plot the topics
f, axs = plt.subplots(1, T + 1, figsize=(15, 1))
ax = axs[0]
ax.text(0, 0.4, "Topics: ", fontsize=16)
ax.axis("off")
ax.get_yaxis().set_visible(False)
ax.get_xaxis().set_visible(False)
for (ax, phi_t) in zip(axs[1:], phi):
    ax.get_yaxis().set_visible(False)
    ax.get_xaxis().set_visible(False)
    ax.imshow(phi_t, interpolation='none', cmap='Greys_r')
```

```python
# sample theta from a symmetric Dirichlet with alpha = 1
theta = [dirichlet(np.ones(T)) for i in range(D)]

# sample documents from theta and phi
B = []
for d in range(D):
    doc = np.zeros(W)
    # number of words drawn from each topic in this document
    theta_sample = multinomial(N, theta[d])
    for t, count in enumerate(theta_sample):
        doc += multinomial(count, phi[t].flatten())
    doc = doc.reshape(L, L)
    B.append(doc)
```

```python
# plot the documents in a sqrt(D) x sqrt(D) grid
j = int(np.sqrt(D))
for row in range(j):
    f, axs = plt.subplots(1, j + 1, figsize=(10, 1))
    ax = axs[0]
    if row == 0:
        ax.text(0, 0.4, "Docs: ", fontsize=12)
    ax.axis("off")
    ax.get_yaxis().set_visible(False)
    ax.get_xaxis().set_visible(False)
    for (ax, doc) in zip(axs[1:], B[row * j:row * j + j]):
        ax.get_yaxis().set_visible(False)
        ax.get_xaxis().set_visible(False)
        ax.imshow(doc, interpolation='none', cmap='Greys_r')
```

```python
# convert from bag of words to a flat word list (word ids, document-major order)
words = []
for b in B:
    doc = []
    for idx, el in enumerate(b.flatten()):
        doc += [idx] * int(el)
    words += doc
```

```python
# naive implementation of the LDA Gibbs sampler
def ldaGibbsSampler(maxIt, words, T, D, N, W, alpha, beta):
    # initialize z: sample theta from the prior, then a topic for each word
    theta = [dirichlet(alpha * np.ones(T)) for i in range(D)]
    z = np.zeros(D * N)
    for i in range(D):
        for j in range(N):
            cat_val = multinomial(1, theta[i])
            z[i * N + j] = np.flatnonzero(cat_val > 0)[0]
    CWT, CDT, sum_CWT, sum_CDT = compute_counts(words, z, N, W, T, D)

    # iterate the sampler, storing an uncollapsed phi every 10 iterations
    phis, phi_iter = [], []
    for it in range(maxIt):
        if it % 10 == 0:
            phi_l, theta_l = uncollapse(CWT, CDT, sum_CWT, sum_CDT, alpha, beta, T, W, D)
            phis.append(phi_l)
            phi_iter.append(it)
        for i in range(D):
            for j in range(N):
                idx = i * N + j
                # remove word idx from the counts, resample its topic, add it back
                CWT, CDT = update_counts(False, CWT, CDT, sum_CWT, sum_CDT, words, z, idx, N)
                z_ij = update_conditional(idx, CWT, CDT, sum_CWT, sum_CDT, alpha, beta, W, T, words, N)
                z[idx] = z_ij
                CWT, CDT = update_counts(True, CWT, CDT, sum_CWT, sum_CDT, words, z, idx, N)
    phi_l, theta_l = uncollapse(CWT, CDT, sum_CWT, sum_CDT, alpha, beta, T, W, D)
    phis.append(phi_l)
    phi_iter.append(maxIt)
    return CWT, CDT, sum_CWT, sum_CDT, phis, phi_iter


def update_counts(increment, CWT, CDT, sum_CWT, sum_CDT, words, z, idx, N):
    w = words[idx]
    t = int(z[idx])
    d = idx // N  # document of word idx (integer division)
    if increment:
        CWT[w, t] += 1
        CDT[d, t] += 1
        sum_CWT[t] += 1
        sum_CDT[d] += 1
    else:
        CWT[w, t] -= 1
        CDT[d, t] -= 1
        sum_CWT[t] -= 1
        sum_CDT[d] -= 1
    return CWT, CDT


def compute_counts(words, z, N, W, T, D):
    # CWT[w, t]: how often word w is assigned to topic t
    # CDT[d, t]: how often topic t appears in document d
    CWT = np.zeros((W, T))
    CDT = np.zeros((D, T))
    sum_CWT = np.zeros(T)
    sum_CDT = np.zeros(D)
    for (idx, (t, w)) in enumerate(zip(z, words)):
        t = int(t)
        CWT[w, t] += 1
        sum_CWT[t] += 1
        d = idx // N
        CDT[d, t] += 1
        sum_CDT[d] += 1
    return CWT, CDT, sum_CWT, sum_CDT
```
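For reference, `update_conditional` below samples from the standard collapsed-Gibbs full conditional for LDA: with the current word removed from the counts,

$$
p(z_i = t \mid \mathbf{z}_{-i}, \mathbf{w}) \;\propto\;
\frac{C^{WT}_{w_i t} + \beta}{\sum_{w} C^{WT}_{w t} + W\beta}
\cdot
\frac{C^{DT}_{d_i t} + \alpha}{\sum_{t'} C^{DT}_{d_i t'} + T\alpha},
$$

where $w_i$ is the word at position $i$ and $d_i$ is its document. The second denominator does not depend on $t$, so it only rescales the unnormalized weights; the code keeps it anyway (plus a small `eps` to guard against division by zero) and normalizes at the end.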
```python
def update_conditional(idx, CWT, CDT, sum_CWT, sum_CDT, alpha, beta, W, T, words, N, eps=1e-10):
    # sample a new topic for word idx from its full conditional
    probs = np.zeros(T)
    sum_probs = 0
    w = words[idx]
    d = idx // N
    for k in range(T):
        val = (CWT[w, k] + beta) * (CDT[d, k] + alpha) \
            / (sum_CWT[k] + beta * W + eps) / (sum_CDT[d] + alpha * T + eps)
        probs[k] = val
        sum_probs += val
    probs_norm = probs / sum_probs
    z_ij_cat = multinomial(1, probs_norm)
    return np.flatnonzero(z_ij_cat > 0)[0]


def uncollapse(CWT, CDT, sum_CWT, sum_CDT, alpha, beta, T, W, D):
    # recover point estimates of phi and theta from the collapsed counts
    phi_learned = np.zeros((T, W))
    theta_learned = np.zeros((D, T))
    for w in range(W):
        for t in range(T):
            phi_learned[t, w] = (CWT[w, t] + beta) / (sum_CWT[t] + beta * W)
    for d in range(D):
        for t in range(T):
            theta_learned[d, t] = (CDT[d, t] + alpha) / (sum_CDT[d] + alpha * T)
    return phi_learned, theta_learned
```

```python
%%time
CWT, CDT, sum_CWT, sum_CDT, phis, phi_iter = ldaGibbsSampler(100, words, T, D, N, W, 1, 1)
```

```python
%%bash
mkdir -p phi_gif
rm -f phi_gif/*.png
```

```python
# let's plot phi: one row of topic images per stored iteration,
# saving each row as a frame for the gif
for (row, (it, phi_learned)) in enumerate(zip(phi_iter, phis)):
    f, axs = plt.subplots(1, T + 1, figsize=(15, 1))
    ax = axs[0]
    ax.text(0, 0.4, "Iter. {}".format(it), fontsize=16)
    ax.axis("off")
    ax.get_yaxis().set_visible(False)
    ax.get_xaxis().set_visible(False)
    for (phi_t, ax) in zip(phi_learned, axs[1:]):
        ax.get_yaxis().set_visible(False)
        ax.get_xaxis().set_visible(False)
        ax.imshow(phi_t.reshape(L, L), interpolation='none', cmap='Greys_r')
    f.savefig('phi_gif/{0:04d}.png'.format(row), format='png')
```

```python
%%bash
convert -delay 60 -loop 0 phi_gif/*.png phi_gif/phi.gif
```
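To quantify how well the bars were recovered, here is a small optional check, a sketch assuming the notebook variables `phi` (the true topics) and `phis` (the stored snapshots) from the cells above: it matches each true topic to its closest learned topic by total variation distance, so values near zero indicate good recovery.

```python
# A sketch of a recovery check; `phi` and `phis` are the notebook
# variables defined above (true topics and learned snapshots).
phi_true = np.array([p.flatten() for p in phi])  # (T, W) true topic-word distributions
phi_final = phis[-1]                             # (T, W) learned topics after the last iteration

# total variation distance between every (true, learned) topic pair
tv = 0.5 * np.abs(phi_true[:, None, :] - phi_final[None, :, :]).sum(axis=2)
print("mean TV distance to best-matching learned topic:", tv.min(axis=1).mean())
```

Note that this matching is greedy per true topic; it does not enforce a one-to-one assignment, but for this toy example it is enough to see whether every bar shows up among the learned topics.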