In [1]:
import numpy as np
import milk

Simulate some labeled data in a two-dimensional feature space.

In [2]:
features = np.random.randn(100,2) # 2d array of features: 100 examples of 2 features each
labels = np.empty(100)
for i in range(100):
    if features[i,0] < 0:
        if features[i,1] < -1:
            labels[i] = np.random.rand() < .001
        else:
            labels[i] = np.random.rand() < .999
    else:
        if features[i,1] < 1:
            labels[i] = np.random.rand() < .001
        else:
            labels[i] = np.random.rand() < .999

What is the decision tree for this data?

Since the data is two-dimensional, we can take a look at it easily.

In [3]:
plot(features[labels==True,0], features[labels==True,1], 'kx', mec='k', ms=6, mew=3)
plot(features[labels==False,0], features[labels==False,1], 'wo', mec='k', ms=8, mew=1)
grid()

Fitting the model is easy:

In [4]:
learner = milk.supervised.tree_learner()
model = learner.train(features, labels)

Using it is easy, too:

In [5]:
model.apply([-1,1])
Out[5]:
True

Visualizing the decision boundary is a bit of a pain...

In [6]:
x_range = np.linspace(-3,3,100)
y_range = np.linspace(-3,3,100)

val = np.zeros((len(x_range), len(y_range)))

for i, x_i in enumerate(x_range):
    for j, y_j in enumerate(y_range):
        val[i,j] = model.apply([x_i,y_j])

imshow(val[::1,::-1].T, extent=[x_range[0],x_range[-1],y_range[0],y_range[-1]], cmap=cm.Greys)
plot(features[labels==True,0], features[labels==True,1], 'kx', mec='w', ms=7, mew=5)
plot(features[labels==True,0], features[labels==True,1], 'kx', mec='k', ms=5, mew=3)
plot(features[labels==False,0], features[labels==False,1], 'wo', mec='k', ms=8, mew=1)
grid()

And can we have a picture of the decision tree itself? It is hidden in the model instance somewhere...

In [7]:
model.tree
Out[7]:
<milk.supervised.tree.Node at 0x1fc54850>
In [8]:
model.tree.featid, model.tree.featval, model.tree.left, model.tree.right
Out[8]:
(0,
 -0.024916427900016365,
 <milk.supervised.tree.Node at 0x1fc5ab50>,
 <milk.supervised.tree.Node at 0x1fc54810>)
In [15]:
def describe_tree(node, prefix=''):
    print prefix + 'if x[%d] < %.2f:' % (node.featid, node.featval)
    if isinstance(node.left, milk.supervised.tree.Node):
        describe_tree(node.left, prefix+'    ')
    else:
        print prefix+'   ', node.left
    print prefix + 'else:'
    if isinstance(node.right, milk.supervised.tree.Node):
        describe_tree(node.right, prefix+'    ')
    else:
        print prefix+'   ', node.right
In [16]:
describe_tree(model.tree)
if x[0] < -0.02:
    if x[1] < -0.97:
        if x[0] < -1.51:
            Leaf(0.0,1.0)
        else:
            if x[0] < -1.12:
                Leaf(0.0,1.0)
            else:
                if x[0] < -0.43:
                    Leaf(0.0,1.0)
                else:
                    Leaf(0.0,3.0)
    else:
        if x[0] < -0.99:
            if x[0] < -1.50:
                if x[0] < -2.53:
                    Leaf(1.0,1.0)
                else:
                    if x[0] < -1.87:
                        Leaf(1.0,1.0)
                    else:
                        if x[0] < -1.85:
                            Leaf(1.0,1.0)
                        else:
                            Leaf(1.0,3.0)
            else:
                if x[0] < -1.41:
                    Leaf(1.0,2.0)
                else:
                    if x[0] < -1.30:
                        Leaf(1.0,1.0)
                    else:
                        if x[0] < -1.28:
                            Leaf(1.0,1.0)
                        else:
                            if x[0] < -1.16:
                                Leaf(1.0,1.0)
                            else:
                                Leaf(1.0,3.0)
        else:
            if x[0] < -0.94:
                Leaf(1.0,1.0)
            else:
                if x[0] < -0.63:
                    if x[0] < -0.82:
                        Leaf(1.0,1.0)
                    else:
                        Leaf(1.0,3.0)
                else:
                    if x[0] < -0.37:
                        if x[0] < -0.63:
                            Leaf(1.0,1.0)
                        else:
                            if x[0] < -0.49:
                                Leaf(1.0,1.0)
                            else:
                                if x[0] < -0.49:
                                    Leaf(1.0,1.0)
                                else:
                                    Leaf(1.0,3.0)
                    else:
                        if x[0] < -0.25:
                            if x[0] < -0.37:
                                Leaf(1.0,1.0)
                            else:
                                if x[0] < -0.35:
                                    Leaf(1.0,1.0)
                                else:
                                    if x[0] < -0.35:
                                        Leaf(1.0,1.0)
                                    else:
                                        Leaf(1.0,3.0)
                        else:
                            if x[0] < -0.13:
                                if x[0] < -0.17:
                                    Leaf(1.0,1.0)
                                else:
                                    Leaf(1.0,3.0)
                            else:
                                if x[0] < -0.13:
                                    Leaf(1.0,1.0)
                                else:
                                    if x[0] < -0.12:
                                        Leaf(1.0,1.0)
                                    else:
                                        if x[0] < -0.05:
                                            Leaf(1.0,1.0)
                                        else:
                                            Leaf(1.0,3.0)
else:
    if x[1] < 1.05:
        if x[0] < 0.00:
            Leaf(0.5,2.0)
        else:
            if x[0] < 0.79:
                if x[0] < 0.14:
                    if x[0] < 0.02:
                        Leaf(0.0,1.0)
                    else:
                        if x[0] < 0.04:
                            Leaf(0.0,1.0)
                        else:
                            if x[0] < 0.11:
                                Leaf(0.0,1.0)
                            else:
                                Leaf(0.0,3.0)
                else:
                    if x[0] < 0.35:
                        if x[0] < 0.15:
                            Leaf(0.0,1.0)
                        else:
                            if x[0] < 0.18:
                                Leaf(0.0,1.0)
                            else:
                                if x[0] < 0.21:
                                    Leaf(0.0,1.0)
                                else:
                                    Leaf(0.0,3.0)
                    else:
                        if x[0] < 0.53:
                            if x[0] < 0.38:
                                Leaf(0.0,1.0)
                            else:
                                Leaf(0.0,3.0)
                        else:
                            if x[0] < 0.58:
                                Leaf(0.0,1.0)
                            else:
                                if x[0] < 0.61:
                                    Leaf(0.0,1.0)
                                else:
                                    if x[0] < 0.71:
                                        Leaf(0.0,1.0)
                                    else:
                                        Leaf(0.0,3.0)
            else:
                if x[0] < 0.95:
                    if x[0] < 0.80:
                        Leaf(0.0,1.0)
                    else:
                        if x[0] < 0.85:
                            Leaf(0.0,1.0)
                        else:
                            if x[0] < 0.85:
                                Leaf(0.0,1.0)
                            else:
                                Leaf(0.0,3.0)
                else:
                    if x[0] < 1.31:
                        if x[0] < 1.01:
                            Leaf(0.0,1.0)
                        else:
                            if x[0] < 1.06:
                                Leaf(0.0,1.0)
                            else:
                                if x[0] < 1.09:
                                    Leaf(0.0,1.0)
                                else:
                                    Leaf(0.0,3.0)
                    else:
                        if x[0] < 1.46:
                            if x[0] < 1.35:
                                Leaf(0.0,1.0)
                            else:
                                Leaf(0.0,3.0)
                        else:
                            if x[0] < 1.48:
                                Leaf(0.0,1.0)
                            else:
                                if x[0] < 1.51:
                                    Leaf(0.0,1.0)
                                else:
                                    if x[0] < 1.73:
                                        Leaf(0.0,1.0)
                                    else:
                                        Leaf(0.0,3.0)
    else:
        if x[0] < 0.37:
            Leaf(1.0,1.0)
        else:
            if x[0] < 0.48:
                Leaf(1.0,1.0)
            else:
                if x[0] < 0.62:
                    Leaf(1.0,1.0)
                else:
                    if x[0] < 1.34:
                        Leaf(1.0,1.0)
                    else:
                        Leaf(1.0,3.0)

Not as simple as it seemed in the lecture, huh?

Back to top