import numpy as np
import milk
Simulate some labeled data in a two-dimensional feature space.
# Simulate labeled data. Each example is a point in the plane, and its label
# follows the small hand-written decision tree spelled out by the nested
# if/else below; every label is flipped with probability 0.001 to add noise.
n_examples = 100
features = np.random.randn(n_examples, 2)  # 2d array of features: n_examples rows of 2 features each
labels = np.empty(n_examples)  # float array; every entry is overwritten with 0.0 or 1.0 below
for i in range(n_examples):
    if features[i, 0] < 0:
        # Left half-plane: positive unless the second feature is below -1.
        if features[i, 1] < -1:
            labels[i] = np.random.rand() < .001
        else:
            labels[i] = np.random.rand() < .999
    else:
        # Right half-plane: positive only when the second feature is at least 1.
        if features[i, 1] < 1:
            labels[i] = np.random.rand() < .001
        else:
            labels[i] = np.random.rand() < .999
What is the decision tree for this data?
Since the data is two-dimensional, we can take a look at it easily.
# Scatter the two classes: positive examples as black crosses, negative
# examples as white circles with a black outline.
is_positive = labels == True
is_negative = labels == False
plot(features[is_positive, 0], features[is_positive, 1], 'kx', mec='k', ms=6, mew=3)
plot(features[is_negative, 0], features[is_negative, 1], 'wo', mec='k', ms=8, mew=1)
grid()
Fitting the model is easy:
# Create milk's decision-tree learner, then train it on the simulated
# features and labels; train() returns the fitted model used below.
learner = milk.supervised.tree_learner()
model = learner.train(features, labels)
Using it is easy, too:
# Classify one new point, given as [feature 0, feature 1].
model.apply([-1,1])
True
Visualizing the decision boundary is a bit of a pain...
# Evaluate the model at every point of a 100 x 100 grid over [-3, 3] x [-3, 3].
# The model is applied one point at a time, hence the explicit double loop.
x_range = np.linspace(-3, 3, 100)
y_range = np.linspace(-3, 3, 100)
val = np.zeros((x_range.size, y_range.size))
for x_pos, x_coord in enumerate(x_range):
    for y_pos, y_coord in enumerate(y_range):
        val[x_pos, y_pos] = model.apply([x_coord, y_coord])

# val is indexed [x, y]; reverse the y axis and transpose so that image rows
# run from the largest y to the smallest and image columns follow x.
bounds = [x_range[0], x_range[-1], y_range[0], y_range[-1]]
imshow(val[:, ::-1].T, extent=bounds, cmap=cm.Greys)

# Overlay the training points. Positive examples are drawn twice (a wide white
# cross under a narrower black one) so they stay visible on dark regions.
positives = features[labels == True]
negatives = features[labels == False]
plot(positives[:, 0], positives[:, 1], 'kx', mec='w', ms=7, mew=5)
plot(positives[:, 0], positives[:, 1], 'kx', mec='k', ms=5, mew=3)
plot(negatives[:, 0], negatives[:, 1], 'wo', mec='k', ms=8, mew=1)
grid()
And can we have a picture of the decision tree itself? It is hidden in the model instance somewhere...
# The fitted tree is reachable through the model's `tree` attribute (its root node).
model.tree
<milk.supervised.tree.Node at 0x1fc54850>
# A node stores the index of the feature it tests, the threshold it compares
# against, and its two children (left and right).
model.tree.featid, model.tree.featval, model.tree.left, model.tree.right
(0, -0.024916427900016365, <milk.supervised.tree.Node at 0x1fc5ab50>, <milk.supervised.tree.Node at 0x1fc54810>)
def describe_tree(node, prefix=''):
    """Print the subtree rooted at *node* as nested if/else pseudo-code.

    Each internal node is printed as ``if x[featid] < featval:`` followed by
    its left child, then ``else:`` followed by its right child. Children that
    are ``milk.supervised.tree.Node`` instances are described recursively with
    a longer prefix; any other child (a leaf) is printed as-is.

    Parameters
    ----------
    node : milk.supervised.tree.Node
        Node to describe; its ``featid``, ``featval``, ``left`` and ``right``
        attributes are read.
    prefix : str, optional
        Text prepended to every line printed for this node. It grows by one
        step per level of recursion.

    Returns
    -------
    None
        The description is written to standard output.
    """
    def describe_child(child):
        # Internal nodes recurse one level deeper; leaves are printed directly.
        if isinstance(child, milk.supervised.tree.Node):
            describe_tree(child, prefix + ' ')
        else:
            print('%s %s' % (prefix + ' ', child))

    # Every print call takes a single parenthesized argument so the function
    # produces the same output on Python 2 and Python 3.
    print(prefix + 'if x[%d] < %.2f:' % (node.featid, node.featval))
    describe_child(node.left)
    print(prefix + 'else:')
    describe_child(node.right)
# Print the whole learned tree, starting from its root.
describe_tree(model.tree)
if x[0] < -0.02: if x[1] < -0.97: if x[0] < -1.51: Leaf(0.0,1.0) else: if x[0] < -1.12: Leaf(0.0,1.0) else: if x[0] < -0.43: Leaf(0.0,1.0) else: Leaf(0.0,3.0) else: if x[0] < -0.99: if x[0] < -1.50: if x[0] < -2.53: Leaf(1.0,1.0) else: if x[0] < -1.87: Leaf(1.0,1.0) else: if x[0] < -1.85: Leaf(1.0,1.0) else: Leaf(1.0,3.0) else: if x[0] < -1.41: Leaf(1.0,2.0) else: if x[0] < -1.30: Leaf(1.0,1.0) else: if x[0] < -1.28: Leaf(1.0,1.0) else: if x[0] < -1.16: Leaf(1.0,1.0) else: Leaf(1.0,3.0) else: if x[0] < -0.94: Leaf(1.0,1.0) else: if x[0] < -0.63: if x[0] < -0.82: Leaf(1.0,1.0) else: Leaf(1.0,3.0) else: if x[0] < -0.37: if x[0] < -0.63: Leaf(1.0,1.0) else: if x[0] < -0.49: Leaf(1.0,1.0) else: if x[0] < -0.49: Leaf(1.0,1.0) else: Leaf(1.0,3.0) else: if x[0] < -0.25: if x[0] < -0.37: Leaf(1.0,1.0) else: if x[0] < -0.35: Leaf(1.0,1.0) else: if x[0] < -0.35: Leaf(1.0,1.0) else: Leaf(1.0,3.0) else: if x[0] < -0.13: if x[0] < -0.17: Leaf(1.0,1.0) else: Leaf(1.0,3.0) else: if x[0] < -0.13: Leaf(1.0,1.0) else: if x[0] < -0.12: Leaf(1.0,1.0) else: if x[0] < -0.05: Leaf(1.0,1.0) else: Leaf(1.0,3.0) else: if x[1] < 1.05: if x[0] < 0.00: Leaf(0.5,2.0) else: if x[0] < 0.79: if x[0] < 0.14: if x[0] < 0.02: Leaf(0.0,1.0) else: if x[0] < 0.04: Leaf(0.0,1.0) else: if x[0] < 0.11: Leaf(0.0,1.0) else: Leaf(0.0,3.0) else: if x[0] < 0.35: if x[0] < 0.15: Leaf(0.0,1.0) else: if x[0] < 0.18: Leaf(0.0,1.0) else: if x[0] < 0.21: Leaf(0.0,1.0) else: Leaf(0.0,3.0) else: if x[0] < 0.53: if x[0] < 0.38: Leaf(0.0,1.0) else: Leaf(0.0,3.0) else: if x[0] < 0.58: Leaf(0.0,1.0) else: if x[0] < 0.61: Leaf(0.0,1.0) else: if x[0] < 0.71: Leaf(0.0,1.0) else: Leaf(0.0,3.0) else: if x[0] < 0.95: if x[0] < 0.80: Leaf(0.0,1.0) else: if x[0] < 0.85: Leaf(0.0,1.0) else: if x[0] < 0.85: Leaf(0.0,1.0) else: Leaf(0.0,3.0) else: if x[0] < 1.31: if x[0] < 1.01: Leaf(0.0,1.0) else: if x[0] < 1.06: Leaf(0.0,1.0) else: if x[0] < 1.09: Leaf(0.0,1.0) else: Leaf(0.0,3.0) else: if x[0] < 1.46: if x[0] < 1.35: Leaf(0.0,1.0) 
else: Leaf(0.0,3.0) else: if x[0] < 1.48: Leaf(0.0,1.0) else: if x[0] < 1.51: Leaf(0.0,1.0) else: if x[0] < 1.73: Leaf(0.0,1.0) else: Leaf(0.0,3.0) else: if x[0] < 0.37: Leaf(1.0,1.0) else: if x[0] < 0.48: Leaf(1.0,1.0) else: if x[0] < 0.62: Leaf(1.0,1.0) else: if x[0] < 1.34: Leaf(1.0,1.0) else: Leaf(1.0,3.0)
Not as simple as it seemed in the lecture, huh?