In part 1, we saw how to create a simple GUI to control a Python function interactively. Here, we will see a more advanced method that gives more control over the GUI.
Specifically, we will design a small GUI to control a Support Vector Classifier with scikit-learn.
Let's start by doing some imports.
from IPython.html import widgets
from IPython.display import display, clear_output
from functools import partial
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
We also need to import scikit-learn.
import sklearn
import sklearn.datasets as ds
import sklearn.cross_validation as cv
import sklearn.grid_search as gs
import sklearn.svm as svm
We create a random dataset of 2D points, where each point's class is given by an XOR operation on the signs of its coordinates. This dataset is not linearly separable.
X = np.random.randn(200, 2)
y = np.logical_xor(X[:, 0]>0, X[:, 1]>0)
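To see why this dataset is not linearly separable, here is a minimal numpy-only sketch. The arrays P, X_demo and y_demo are hypothetical and used for illustration only. It applies the same XOR rule to one point per quadrant, then checks that, away from the axes, the label equals the test x0·x1 < 0. The two classes are therefore separated by both coordinate axes, which a single straight line cannot reproduce but a product of the coordinates can.

```python
import numpy as np

# One hand-picked point per quadrant (hypothetical, for illustration).
P = np.array([[1.0, 1.0],
              [-1.0, 1.0],
              [-1.0, -1.0],
              [1.0, -1.0]])

# Same labeling rule as above: XOR of the coordinate signs.
labels = np.logical_xor(P[:, 0] > 0, P[:, 1] > 0)
print(labels)  # [False  True False  True]

# The same labels come from the sign of the product x0 * x1,
# a degree-2 feature rather than a linear one.
print(np.array_equal(labels, P[:, 0] * P[:, 1] < 0))  # True

# Same check on random data (a coordinate exactly equal to 0 has probability zero).
rng = np.random.RandomState(0)
X_demo = rng.randn(1000, 2)
y_demo = np.logical_xor(X_demo[:, 0] > 0, X_demo[:, 1] > 0)
print(np.array_equal(y_demo, X_demo[:, 0] * X_demo[:, 1] < 0))  # True
```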
Now, we create a function that uses matplotlib to display the points with their classes, the decision function, and the decision boundaries.
# We generate a grid in the square [-3, 3]^2.
xx, yy = np.meshgrid(np.linspace(-3, 3, 500),
                     np.linspace(-3, 3, 500))

# This function takes an SVM estimator as input.
def plot_decision_function(est):
    # We evaluate the decision function on the grid.
    Z = est.decision_function(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    cmap = plt.cm.Blues
    # We display the decision function on the grid.
    plt.imshow(Z,
               extent=(xx.min(), xx.max(), yy.min(), yy.max()),
               aspect='auto', origin='lower', cmap=cmap)
    # We display the boundaries (the zero level set of the decision function).
    plt.contour(xx, yy, Z, levels=[0], linewidths=2,
                colors='k')
    # We display the points with their true labels.
    plt.scatter(X[:, 0], X[:, 1], s=30, c=.5+.5*y, lw=1,
                cmap=cmap, vmin=0, vmax=1)
    plt.axis([-3, 3, -3, 3])
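plot_decision_function flattens the 500 × 500 grid into 250,000 points, evaluates the estimator on them, then reshapes the scores back into a grid. The following minimal sketch replays those steps on a 5 × 5 grid. The function toy_decision_function is hypothetical: it returns -x·y, which is positive exactly where the XOR label is True. It stands in for est.decision_function so that the snippet runs without scikit-learn.

```python
import numpy as np

# Small grid (5x5 instead of 500x500) over the same square [-3, 3]^2.
xx_s, yy_s = np.meshgrid(np.linspace(-3, 3, 5), np.linspace(-3, 3, 5))

# np.c_ stacks the flattened coordinates as columns: one (x, y) row per grid node.
points = np.c_[xx_s.ravel(), yy_s.ravel()]
print(points.shape)  # (25, 2)

# Hypothetical stand-in for est.decision_function: one score per row.
def toy_decision_function(pts):
    return -pts[:, 0] * pts[:, 1]

Z_s = toy_decision_function(points).reshape(xx_s.shape)
print(Z_s.shape)  # (5, 5)

# After reshaping, Z_s[i, j] is the score at (xx_s[i, j], yy_s[i, j]).
# Row 0 corresponds to y = -3, hence origin='lower' in the imshow call.
i, j = 1, 3
print(Z_s[i, j] == -xx_s[i, j] * yy_s[i, j])  # True
```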
Let's create a GUI controlling the classifier. We need four controls and two buttons.
kernel_control = widgets.RadioButtonsWidget(description='kernel',
                                            values=['linear', 'poly',
                                                    'rbf', 'sigmoid'])
degree_control = widgets.IntSliderWidget(description='degree',
                                         value=3, min=3, max=10, step=1,
                                         visible=False)
c_control = widgets.FloatSliderWidget(description='log(c)',
                                      value=1, min=-3, max=3, step=.5)
gamma_control = widgets.FloatSliderWidget(description='log(gamma)',
                                          value=1, min=-3, max=3, step=.5)
start_button = widgets.ButtonWidget(description='Go!')
gs_button = widgets.ButtonWidget(description='grid search')
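The C and gamma sliders hold base-10 logarithms rather than the hyperparameters themselves, because useful values span several orders of magnitude. The computation function later converts a slider value v with 10**v, so the initial value 1 corresponds to C = gamma = 10. This numpy-only sketch (slider_values and C_values are illustrative names) lists what the slider positions map to.

```python
import numpy as np

# The C and gamma sliders hold log10 values in [-3, 3] with a step of 0.5.
slider_values = np.arange(-3, 3.5, 0.5)
print(len(slider_values))  # 13

# Each slider value v becomes a hyperparameter 10**v,
# so the 13 positions cover C (or gamma) from 1e-3 to 1e3.
C_values = 10.0 ** slider_values

# Each step of 0.5 multiplies the hyperparameter by sqrt(10), about 3.16.
print(np.allclose(C_values[1:] / C_values[:-1], np.sqrt(10)))  # True
```

A linear slider over [0.001, 1000] would instead devote almost its entire range to large values, making small values of C or gamma hard to select.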
The next step is to organize the layout of these controls. For this, we can use ContainerWidgets, which contain a list of child widgets. We can also set the hbox and vbox CSS classes on these containers to align their children horizontally or vertically.
# The main container, vbox, will contain three rows.
vbox = widgets.ContainerWidget()
hbox1 = widgets.ContainerWidget()
hbox2 = widgets.ContainerWidget()
hbox3 = widgets.ContainerWidget()
# We put the controls in each row.
hbox1.children = [kernel_control, degree_control]
hbox2.children = [c_control, gamma_control]
hbox3.children = [start_button, gs_button]
vbox.children = [hbox1, hbox2, hbox3]
# We display the GUI here.
display(vbox)
Setting CSS classes for containers is not particularly convenient in IPython 2.0: this should be improved in a later version.
# We put the CSS in a function so that we can reuse it later.
def set_css():
    # HACK: for horizontal alignment, we need to first remove the vbox class, and
    # then add the hbox class. This needs to happen *after* the widget has been displayed.
    for row in (hbox1, hbox2, hbox3):
        row.remove_class('vbox')
        row.add_class('hbox')
    vbox.add_class('vbox')
    # Setting the width of the sliders.
    for slider in (degree_control, c_control, gamma_control):
        slider.set_css('width', 100)

set_css()
We can customize the CSS at will. For example, try hbox1.set_css('background-color', '#ccc').
We also add some logic to the GUI. The degree parameter is only used by the polynomial kernel, so when the kernel is poly, we display the degree control; otherwise, we hide it. This can be done by registering a callback function with on_trait_change. Here, the function on_kernel_change will be called as soon as the value attribute of the kernel control changes.
def on_kernel_change(name, value):
    # The degree slider is only shown for the polynomial kernel.
    degree_control.visible = (value == 'poly')

kernel_control.on_trait_change(on_kernel_change, 'value')
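This callback mechanism is an instance of the observer pattern: the control keeps a list of registered functions and calls each of them when its value changes. The following conceptual sketch mimics it in plain Python. ToyControl, its on_change method, degree_stub and kernel_stub are hypothetical names, not IPython's API or implementation; only the callback(name, value) signature is taken from on_kernel_change above.

```python
from types import SimpleNamespace


class ToyControl:
    """Hypothetical stand-in for a widget with a 'value' attribute
    (a conceptual sketch, not IPython's implementation)."""

    def __init__(self, value):
        self._value = value
        self._callbacks = []

    def on_change(self, callback):
        # Register a callback, later called as callback(name, new_value).
        self._callbacks.append(callback)

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new_value):
        # Notify observers only when the value actually changes.
        if new_value != self._value:
            self._value = new_value
            for callback in self._callbacks:
                callback('value', new_value)


# Hypothetical stand-in for degree_control: only its visibility matters here.
degree_stub = SimpleNamespace(visible=False)


def on_kernel_change_stub(name, value):
    # Same logic as on_kernel_change above.
    degree_stub.visible = (value == 'poly')


kernel_stub = ToyControl('linear')
kernel_stub.on_change(on_kernel_change_stub)

kernel_stub.value = 'poly'
print(degree_stub.visible)  # True
kernel_stub.value = 'rbf'
print(degree_stub.visible)  # False
```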
Now, we write the function that implements the actual computation, and that will be called by clicking on the GUI buttons.
def start(gs=False, e=None):
    """Fits an SVC on the data.

    Arguments:
    * gs=False: whether to use a grid search or not
    * e=None: the button instance that IPython passes to on_click callbacks (unused here)
    """
    # We create the SVC estimator, using the hyperparameters as specified
    # in the GUI. The .value attributes of the controls are updated in
    # real time in the notebook when interacting with the controls.
    estimator = svm.SVC(kernel=kernel_control.value,
                        C=10**c_control.value,
                        gamma=10**gamma_control.value,
                        degree=degree_control.value)
    # We clear the previous plots in the output area.
    clear_output()
    # Using a grid search...
    if gs:
        # The `gs` argument shadows the `gs` module alias imported above,
        # hence the fully qualified name here.
        estimator = sklearn.grid_search.GridSearchCV(estimator, dict(
            C=np.logspace(-3., 3., 10),
            gamma=np.logspace(-3., 3., 10),
        ))
    # We launch the fitting.
    estimator.fit(X, y)
    # If we used a grid search, we retrieve the best estimator found.
    if gs:
        estimator = estimator.best_estimator_
        # We update the GUI controls in the notebook with that best estimator.
        c_control.value = np.log10(estimator.get_params()['C'])
        gamma_control.value = np.log10(estimator.get_params()['gamma'])
    # Finally, we display the results.
    plot_decision_function(estimator)
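With gs=True, the grid search evaluates the 10 × 10 = 100 (C, gamma) pairs built from np.logspace(-3., 3., 10), while the kernel and degree stay as chosen in the GUI. The numpy-only sketch below (grid, exponents, ratios and on_step are illustrative names) shows what these candidates look like in the log10 units of the sliders. The exponents are evenly spaced by 6/9 ≈ 0.67, so only some of them coincide with the sliders' 0.5 step. How the slider widget displays an off-step value is not covered here.

```python
import numpy as np

# Same candidate grid as in start(): 10 values log-spaced between 1e-3 and 1e3.
grid = np.logspace(-3., 3., 10)
exponents = np.log10(grid)
print(np.round(exponents, 3))

# Consecutive candidates differ by a constant factor of 10**(6/9), about 4.64.
ratios = grid[1:] / grid[:-1]
print(np.allclose(ratios, 10 ** (6 / 9)))  # True

# The exponents are what start() writes back into the log sliders.
# Count how many of them fall on a multiple of the sliders' 0.5 step.
on_step = np.isclose(exponents * 2, np.round(exponents * 2))
print(on_step.sum())  # 4
```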
We now bind this start function to our two buttons, using either False or True for the first parameter. We use the convenient partial function from the standard functools module for this.
start_button.on_click(partial(start, False))
gs_button.on_click(partial(start, True))
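To illustrate what partial does here, the following minimal sketch uses a hypothetical start_stub with the same signature as start, which records its arguments instead of fitting anything. partial(start_stub, False) fixes the first positional argument gs, so the single argument that on_click passes to its callback (the clicked button) lands in e. The placeholder strings stand in for real button objects.

```python
from functools import partial

calls = []

# Hypothetical stand-in with the same signature as start().
def start_stub(gs=False, e=None):
    calls.append((gs, e))

on_go = partial(start_stub, False)
on_grid = partial(start_stub, True)

# A button calls its callback with one argument; placeholder strings here.
on_go('button-1')
on_grid('button-2')
print(calls)  # [(False, 'button-1'), (True, 'button-2')]
```

An equivalent without functools would be lambda b: start(False, b); partial simply avoids writing one lambda per button.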
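Everything is ready. We can display the GUI again and play with it. Note that changing the controls in this new copy also changes the controls in the copy displayed above, and vice versa. This is because all displayed instances of the GUI are dynamically bound in the notebook: they are views of the same widgets and share the same attributes (one model, multiple views). We call set_css() again because, as noted in its HACK comment, the CSS classes must be set after the widget has been displayed.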
display(vbox)
set_css()
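The "one model, multiple views" idea can be illustrated with a conceptual sketch in plain Python. ToyModel and ToyView are hypothetical classes, not IPython's implementation (the real widgets keep a Python-side model in sync with views rendered in the browser). The point is only that every view reads from, and writes to, the same shared state.

```python
class ToyModel:
    """Hypothetical shared model: holds the state and notifies its views."""

    def __init__(self, value):
        self.value = value
        self.views = []

    def set_value(self, value):
        self.value = value
        # Every attached view re-renders from the shared state.
        for view in self.views:
            view.render()


class ToyView:
    """Hypothetical view: displays the model's value and keeps no state of its own."""

    def __init__(self, model):
        self.model = model
        self.shown = None
        model.views.append(self)
        self.render()

    def render(self):
        self.shown = self.model.value

    def user_input(self, value):
        # Interacting with any view writes to the shared model.
        self.model.set_value(value)


c_model = ToyModel(1.0)
view_top = ToyView(c_model)     # like the first display(vbox)
view_bottom = ToyView(c_model)  # like the second display(vbox)
view_bottom.user_input(2.5)
print(view_top.shown, view_bottom.shown)  # 2.5 2.5
```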