from sklearn.datasets import fetch_olivetti_faces
from sklearn.decomposition import ProjectedGradientNMF, MultiplicativeNMF
from time import time
# Load the Olivetti faces (one flattened image per row of X).
faces = fetch_olivetti_faces(shuffle=True)
X = faces.data

# Benchmark grid: every initialization scheme, two stopping tolerances and an
# increasing iteration budget (max_iter = 1, 11, ..., 101). Each fit is timed
# end to end (construction + fit).
solvers = (('PG', ProjectedGradientNMF), ('MULT', MultiplicativeNMF))

# One row per fit: (init, n_components, method, tol_name, n_iter_, err, time).
rec_errors = []
# One row per (init, n_components, method, tol_name): the components learned
# by the fit with the largest iteration budget.
components = []
for init in ('random', 'nndsvd', 'nndsvda', 'nndsvdar'):
    for n_components in (9,):
        for tol, tol_name in zip((1e-2, 1e-5), ('high', 'low')):
            last_fit = {}
            # `range` instead of the Python-2-only `xrange`; works on both.
            for n_iter in range(1, 102, 10):
                for method, Estimator in solvers:
                    t0 = time()
                    model = Estimator(n_components=n_components, max_iter=n_iter,
                                      init=init, tol=tol)
                    err = model.fit(X).reconstruction_err_
                    elapsed = time() - t0
                    # Record the iterations actually performed (n_iter_), which
                    # can be below max_iter when the tolerance is reached.
                    rec_errors.append((init, n_components, method, tol_name,
                                       model.n_iter_, err, elapsed))
                    last_fit[method] = model
                # TODO: add the L-BFGS solver (LBfgsNMF) here once it is available;
                # its row must follow the same 7-field schema as above.
            for method, _ in solvers:
                components.append((init, n_components, method, tol_name,
                                   last_fit[method].components_))
# [cell output]
# /Users/vene/code/scikit-learn/sklearn/decomposition/nmf.py:728: UserWarning: Iteration limit reached during fit
#   warnings.warn("Iteration limit reached during fit")
import pandas
# Tabulate the benchmark: one row per timed fit, columns in the same order as
# the tuples appended to `rec_errors`.
rec_errors_df = pandas.DataFrame(
    rec_errors,
    columns=['init', 'n_components', 'method', 'tol', 'n_iter', 'err', 'time'],
)
%pylab inline --no-import-all
# Convergence curves: reconstruction error against the iteration count (left
# panel of each pair) and against wall-clock time (right panel), one row of
# panels per initialization scheme and one pair of panels per tolerance.
colors = {9: 'r', 16: 'b'}         # line color per n_components
styles = {'PG': ':', 'MULT': '-'}  # line style per solver
init_names = np.unique(rec_errors_df['init'])
plt.figure(figsize=(22, 25))
for i, init_name in enumerate(init_names):
    for n_components in (9,):
        for k, tol_name in enumerate(('high', 'low')):
            # Build one combined boolean mask. Chaining df[mask_a][mask_b]
            # relies on index re-alignment and makes pandas warn
            # "Boolean Series key will be reindexed to match DataFrame index".
            base_mask = ((rec_errors_df['init'] == init_name) &
                         (rec_errors_df['n_components'] == n_components) &
                         (rec_errors_df['tol'] == tol_name))
            curves = {method: rec_errors_df[base_mask & (rec_errors_df['method'] == method)]
                      for method in ('PG', 'MULT')}
            # Same selection, two x axes: iterations, then seconds.
            for offset, x_col, x_label in ((1, 'n_iter', "n_iter"),
                                           (2, 'time', "Time (s)")):
                plt.subplot(len(init_names), 4, 4 * i + 2 * k + offset)
                for method in ('PG', 'MULT'):
                    curve = curves[method]
                    plt.plot(curve[x_col], curve['err'],
                             color=colors[n_components],
                             ls=styles[method],
                             label=method)
                plt.xlabel(x_label)
                plt.ylabel("Reconstruction error")
                plt.legend()
                plt.title("{} tol={}".format(init_name, tol_name))
# [cell output] Populating the interactive namespace from numpy and matplotlib
# Final accuracy vs. cost: for each (init, solver) pair, plot the run with the
# most iterations. The error is divided by ||X||_F, i.e. a relative error.
selected_items = rec_errors_df[rec_errors_df['n_components'] == 9]
colors = {'random': 'r', 'nndsvd': 'g', 'nndsvda': 'b', 'nndsvdar': 'y'}
markers = {'PG': 'x', 'MULT': 'o'}
# Frobenius norm of the data: loop-invariant, so compute it once.
X_norm = np.sqrt(np.sum(X ** 2))
plt.figure(figsize=(15, 6))
for k, tol_name in enumerate(['high', 'low']):
    plt.subplot(1, 2, k + 1)
    for init in ('random', 'nndsvd', 'nndsvda', 'nndsvdar'):
        for method in ('MULT', 'PG'):
            # One combined mask instead of chained boolean indexing.
            mask = ((selected_items['init'] == init) &
                    (selected_items['method'] == method) &
                    (selected_items['tol'] == tol_name))
            these_items = selected_items[mask]
            longest = these_items[these_items['n_iter'] == these_items['n_iter'].max()]
            # Several runs can stop at the same n_iter_ (early convergence);
            # take the first one, as before.
            t = longest['time'].iloc[0]
            err = longest['err'].iloc[0]
            plt.scatter(t, err / X_norm, marker=markers[method], color=colors[init],
                        label="{}-{}".format(method, init))
    plt.legend()
    plt.xlabel('Time (s)')
    plt.ylabel('Relative reconstruction error')
    plt.title("tol={}".format(tol_name))
# Tabulate the learned factorizations; the 'components' column holds the
# `components_` array of each recorded model.
components_df = pandas.DataFrame(
    data=components,
    columns='init n_components method tol components'.split(),
)
def plot_gallery(title, images, n_col, n_row, image_shape=(64, 64)):
    """Show flattened images as a titled ``n_row`` x ``n_col`` gray-scale grid.

    Parameters
    ----------
    title : str
        Figure-level title.
    images : sequence of arrays
        Flattened images; each one is reshaped to ``image_shape``.
    n_col, n_row : int
        Grid size; ``n_col * n_row`` must be at least ``len(images)``.
    image_shape : tuple of int, optional
        Shape used to un-flatten each image. Defaults to (64, 64), the size of
        the Olivetti faces used in this notebook.
    """
    plt.figure(figsize=(2. * n_col, 2.26 * n_row))
    plt.suptitle(title, size=16)
    for i, comp in enumerate(images):
        plt.subplot(n_row, n_col, i + 1)
        # Symmetric color range around zero: zero maps to mid-gray, and the
        # largest magnitude sets both ends of the gray scale.
        vmax = max(comp.max(), -comp.min())
        plt.imshow(comp.reshape(image_shape), cmap=plt.cm.gray,
                   interpolation='nearest',
                   vmin=-vmax, vmax=vmax)
        plt.xticks(())
        plt.yticks(())
    plt.subplots_adjust(0.01, 0.05, 0.99, 0.93, 0.04, 0.)
# One gallery per (init, solver, tolerance) showing the learned components.
# The inits are listed explicitly rather than taken from a plotting dict, so
# the order is deterministic and independent of earlier cells.
for init in ('random', 'nndsvd', 'nndsvda', 'nndsvdar'):
    for method in ('PG', 'MULT'):
        for tol in ('low', 'high'):
            for n_components in (9,):
                # One combined mask instead of chained boolean indexing.
                mask = ((components_df['init'] == init) &
                        (components_df['method'] == method) &
                        (components_df['n_components'] == n_components) &
                        (components_df['tol'] == tol))
                these_comps = components_df.loc[mask, 'components'].iloc[0]
                # Near-square grid that holds every component, also when
                # n_components is not a perfect square (int(sqrt(n))**2 < n).
                n_col = int(np.ceil(np.sqrt(n_components)))
                n_row = int(np.ceil(n_components / float(n_col)))
                plot_gallery("{}-{} tol={}".format(method, init, tol), these_comps,
                             n_col=n_col, n_row=n_row)