#!/usr/bin/env python
# coding: utf-8

# # Birth and merge variational inference for Dirichlet process mixtures of Gaussians
#
# ## Goal
#
# This demo walks you through a "Hello World" example of using **bnpy** from within Python.
#
# We'll train a Dirichlet process (DP) Gaussian mixture model using memoized variational inference.
#
# We can use the following import statements to load bnpy.

# In[1]:

import bnpy


# In[2]:

import os
get_ipython().run_line_magic('pylab', 'inline')

from bnpy.viz.PlotUtil import ExportInfo
bnpy.viz.PlotUtil.ConfigPylabDefaults(pylab)


# ## Toy dataset: `AsteriskK8`
#
# We'll use a simple dataset of 2D points, drawn from 8 well-separated Gaussian clusters.
#
# Our goal will be to recover the true set of 8 clusters.

# In[3]:

import AsteriskK8
Data = AsteriskK8.get_data()


# We can visualize this dataset as follows:

# In[4]:

pylab.plot(Data.X[:, 0], Data.X[:, 1], 'k.')
pylab.axis('image')
pylab.xlim([-1.75, 1.75])
pylab.xticks([-1, 0, 1])
pylab.ylim([-1.75, 1.75])
pylab.yticks([-1, 0, 1])


# # Running inference with **bnpy**
#
# We'll fit a DP Gaussian mixture model using birth and merge moves. The `moves` keyword argument specifies which moves to employ.
#
# We'll try several different initial numbers of clusters, and show that the birth/merge moves consistently reach the ideal set of 8 clusters.

# ## 1 initial cluster, repeated for 5 independent initializations

# In[5]:

hmodel, RInfo = bnpy.run(
    'AsteriskK8', 'DPMixtureModel', 'Gauss', 'moVB',
    K=1, moves='birth,merge',
    jobname='demobirthmerge-Kinit=1',
    nTask=5, nLap=20)


# ## 4 initial clusters, repeated for 5 independent initializations

# In[6]:

hmodel, RInfo = bnpy.run(
    'AsteriskK8', 'DPMixtureModel', 'Gauss', 'moVB',
    K=4, moves='birth,merge',
    jobname='demobirthmerge-Kinit=4',
    nTask=5, nLap=20)


# ## 16 initial clusters, repeated for 5 independent initializations

# In[7]:

hmodel, RInfo = bnpy.run(
    'AsteriskK8', 'DPMixtureModel', 'Gauss', 'moVB',
    K=16, moves='birth,merge',
    jobname='demobirthmerge-Kinit=16',
    nTask=5, nLap=20)


# # Using bnpy visualization tools to assess performance
#
# ## Trace plot of objective function over time
#
# Here, for each run we plot a trace of the objective function (sometimes called the ELBO) as the algorithm sees more training data.
#
# The different colors correspond to different numbers of initial clusters.
#
# **Conclusion:** All the different initial conditions quickly converge to similar, high-quality objective scores.

# In[8]:

bnpy.viz.PlotELBO.plotJobsThatMatch('AsteriskK8/demobirthmerge-*')
pylab.legend(loc='lower right')


# ## Trace plot of number of clusters over time
#
# **Conclusion:** Across very different initial conditions, we consistently reach exactly 8 learned clusters.

# In[9]:

bnpy.viz.PlotTrace.plotJobsThatMatch('AsteriskK8/demobirthmerge-*', yvar='K')
pylab.ylabel('num. states K')
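
# As a quick sanity check (not part of the original demo), we can also read the learned number of clusters directly off the model object returned by `bnpy.run`. This is a minimal sketch that assumes the allocation and observation models expose a `K` attribute, as in recent bnpy releases; here `hmodel` is the fitted model from the last call above (the Kinit=16 job).

# In[ ]:

# Assumed attributes: hmodel.allocModel.K and hmodel.obsModel.K hold the
# current number of mixture components. Both should report 8 after training.
print('Learned K (allocation model): %d' % hmodel.allocModel.K)
print('Learned K (observation model): %d' % hmodel.obsModel.K)
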

# # Plot the learned cluster centers
#
# ## First, the 5 different *initializations*
#
# Each subplot corresponds to one single initialization.

# In[10]:

figH, axH = pylab.subplots(nrows=1, ncols=5, figsize=(20, 4))
for plotID, rank in enumerate([1, 2, 3, 4, 5]):
    pylab.subplot(1, 5, plotID + 1)
    taskidstr = '.rank%d' % (rank)
    bnpy.viz.PlotComps.plotCompsForJob(
        'AsteriskK8/demobirthmerge-Kinit=4/',
        taskids=[taskidstr], figH=figH, lap=0)
    pylab.axis('image')
    pylab.xlim([-1.75, 1.75])
    pylab.xticks([-1, 0, 1])
    pylab.ylim([-1.75, 1.75])
    pylab.yticks([-1, 0, 1])
    pylab.title('Run %d/5' % (rank))
pylab.tight_layout()


# ## After 1 lap
#
# Showing each run after one complete pass through the dataset (all 10 batches).

# In[11]:

figH, axH = pylab.subplots(nrows=1, ncols=5, figsize=(20, 4))
for plotID, rank in enumerate([1, 2, 3, 4, 5]):
    pylab.subplot(1, 5, plotID + 1)
    taskidstr = '.rank%d' % (rank)
    bnpy.viz.PlotComps.plotCompsForJob(
        'AsteriskK8/demobirthmerge-Kinit=4/',
        taskids=[taskidstr], figH=figH, lap=1)
    pylab.axis('image')
    pylab.xlim([-1.75, 1.75])
    pylab.xticks([-1, 0, 1])
    pylab.ylim([-1.75, 1.75])
    pylab.yticks([-1, 0, 1])
    pylab.title('Run %d/5' % (rank))
pylab.tight_layout()


# ## After 2 laps

# In[12]:

figH, axH = pylab.subplots(nrows=1, ncols=5, figsize=(20, 4))
for plotID, rank in enumerate([1, 2, 3, 4, 5]):
    pylab.subplot(1, 5, plotID + 1)
    taskidstr = '.rank%d' % (rank)
    bnpy.viz.PlotComps.plotCompsForJob(
        'AsteriskK8/demobirthmerge-Kinit=4/',
        taskids=[taskidstr], figH=figH, lap=2)
    pylab.axis('image')
    pylab.xlim([-1.75, 1.75])
    pylab.xticks([-1, 0, 1])
    pylab.ylim([-1.75, 1.75])
    pylab.yticks([-1, 0, 1])
    pylab.title('Run %d/5' % (rank))
pylab.tight_layout()


# ## After 4 laps

# In[13]:

figH, axH = pylab.subplots(nrows=1, ncols=5, figsize=(20, 4))
for plotID, rank in enumerate([1, 2, 3, 4, 5]):
    pylab.subplot(1, 5, plotID + 1)
    taskidstr = '.rank%d' % (rank)
    bnpy.viz.PlotComps.plotCompsForJob(
        'AsteriskK8/demobirthmerge-Kinit=4/',
        taskids=[taskidstr], figH=figH, lap=4)
    pylab.axis('image')
    pylab.xlim([-1.75, 1.75])
    pylab.xticks([-1, 0, 1])
    pylab.ylim([-1.75, 1.75])
    pylab.yticks([-1, 0, 1])
    pylab.title('Run %d/5' % (rank))
pylab.tight_layout()


# ## After 10 laps

# In[14]:

figH, axH = pylab.subplots(nrows=1, ncols=5, figsize=(20, 4))
for plotID, rank in enumerate([1, 2, 3, 4, 5]):
    pylab.subplot(1, 5, plotID + 1)
    taskidstr = '.rank%d' % (rank)
    bnpy.viz.PlotComps.plotCompsForJob(
        'AsteriskK8/demobirthmerge-Kinit=4/',
        taskids=[taskidstr], figH=figH, lap=10)
    pylab.axis('image')
    pylab.xlim([-1.75, 1.75])
    pylab.xticks([-1, 0, 1])
    pylab.ylim([-1.75, 1.75])
    pylab.yticks([-1, 0, 1])
    pylab.title('Run %d/5' % (rank))
pylab.tight_layout()


# ## After 20 laps
#
# **Conclusion:** All runs identify the ideal set of 8 true clusters.

# In[15]:

figH, axH = pylab.subplots(nrows=1, ncols=5, figsize=(20, 4))
for plotID, rank in enumerate([1, 2, 3, 4, 5]):
    pylab.subplot(1, 5, plotID + 1)
    taskidstr = '.rank%d' % (rank)
    bnpy.viz.PlotComps.plotCompsForJob(
        'AsteriskK8/demobirthmerge-Kinit=4/',
        taskids=[taskidstr], figH=figH)
    pylab.axis('image')
    pylab.xlim([-1.75, 1.75])
    pylab.xticks([-1, 0, 1])
    pylab.ylim([-1.75, 1.75])
    pylab.yticks([-1, 0, 1])
    pylab.title('Run %d/5' % (rank))
pylab.tight_layout()
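
# The same final-lap visualization can be repeated for the other two jobs. The sketch below is not part of the original demo; it reuses the `plotCompsForJob` call from above and assumes the Kinit=1 and Kinit=16 runs saved results under the job names given earlier, showing one representative task (`.rank1`) from each.

# In[ ]:

figH, axH = pylab.subplots(nrows=1, ncols=2, figsize=(8, 4))
for plotID, jobpath in enumerate(['AsteriskK8/demobirthmerge-Kinit=1/',
                                  'AsteriskK8/demobirthmerge-Kinit=16/']):
    pylab.subplot(1, 2, plotID + 1)
    # Omitting the lap argument shows the final saved clusters, as in the cell above.
    bnpy.viz.PlotComps.plotCompsForJob(jobpath, taskids=['.rank1'], figH=figH)
    pylab.axis('image')
    pylab.xlim([-1.75, 1.75])
    pylab.xticks([-1, 0, 1])
    pylab.ylim([-1.75, 1.75])
    pylab.yticks([-1, 0, 1])
    pylab.title(jobpath.split('-')[-1].rstrip('/'))
pylab.tight_layout()
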

# ## Grid plot: Evolution of learned clusters over time
#
# Each column is a separate run; each row shows the learned clusters after the indicated number of laps.

# In[16]:

laps = [1, 4, 20]
ranks = [1, 2, 3]
nrows = len(laps)
ncols = len(ranks)
figH, axH = pylab.subplots(nrows=nrows, ncols=ncols, figsize=(2 * ncols, 2 * nrows))
for plotID in range(nrows * ncols):
    # Subplots are filled in row-major order: row index selects the lap,
    # column index selects the run.
    lap = laps[plotID // ncols]
    rank = ranks[plotID % ncols]
    pylab.subplot(nrows, ncols, plotID + 1)
    taskidstr = '.rank%d' % (rank)
    bnpy.viz.PlotComps.plotCompsForJob(
        'AsteriskK8/demobirthmerge-Kinit=4/',
        taskids=[taskidstr], figH=figH, lap=lap)
    pylab.axis('image')
    pylab.xlim([-1.75, 1.75])
    pylab.xticks([])
    pylab.ylim([-1.75, 1.75])
    pylab.yticks([])
    if rank == 1:
        pylab.ylabel('Lap %d' % (lap))
pylab.tight_layout()


# Ignore this block. Only needed for auto-generation of documentation.

if ExportInfo['doExport']:
    W_in, H_in = pylab.gcf().get_size_inches()
    figpath100 = '../docs/source/_static/GaussianToyData_DPMixtureModel_MemoizedWithBirthsAndMerges_%dx%d.png' % (100, 100)
    pylab.savefig(figpath100,
                  bbox_inches=0, pad_inches=0,
                  dpi=ExportInfo['dpi'] / W_in)