import pandas.rpy.common as com
iris = com.load_data('iris')
# R column names contain dots; normalise them to lower-case with underscores
# so they are valid Python identifiers and usable as patsy formula terms.
iris.columns = [name.lower().replace('.', '_') for name in iris.columns]
iris.columns  # echo the normalised names in the notebook
Index([sepal_length, sepal_width, petal_length, petal_width, species], dtype=object)
# Class-balance check: number of rows per species.
iris.species.value_counts()
setosa        50
versicolor    50
virginica     50
# Scatter plot of petal width vs. sepal width, one colour per species, to see
# how well these two features separate the classes before fitting a tree.
# plot/legend/xlabel/ylabel are not imported here; they are presumably the
# pylab functions injected into the notebook namespace (e.g. via %pylab).
spec = iris.groupby('species')
cols = ['k', 'r', 'g']
# enumerate replaces the hand-maintained counter; indexing cols[i] (rather
# than zipping) still raises if there are more species than colours.
for i, (s, df) in enumerate(spec):
    plot(df['petal_width'], df['sepal_width'], 'o', color=cols[i], label=s)
legend()
xlabel('Petal Width')
ylabel('Sepal Width');  # trailing ';' suppresses the notebook echo of the return value
import patsy as pt
from sklearn import tree

# Build design matrices from a formula; "- 1" drops the intercept column.
# NOTE: `species` is categorical and sits on the left-hand side, so patsy
# returns `y` as a one-hot indicator matrix (one column per species). The
# tree below is therefore fitted as a multi-output classifier with one yes/no
# target per species, which is why predict_proba later yields a pair of
# complementary probability columns per species.
iris_formula = 'species ~ sepal_width + petal_width - 1'
y, X = pt.dmatrices(iris_formula, iris)

# Limit depth to 3 so the rendered tree stays readable.
clf = tree.DecisionTreeClassifier(max_depth=3)
clf.fit(X, y);  # fit() works in place; ';' keeps the notebook from echoing clf
import StringIO
import pydot
from IPython.core.display import HTML

# Render the fitted tree: export it as Graphviz DOT text into an in-memory
# buffer, turn that into a PNG with pydot, then embed the image in the notebook.
dot_data = StringIO.StringIO()
# Label splits with the patsy column names (sepal_width / petal_width)
# instead of the default positional X[0], X[1].
tree.export_graphviz(clf, out_file=dot_data,
                     feature_names=X.design_info.column_names)
graph = pydot.graph_from_dot_data(dot_data.getvalue())
# pydot >= 1.2 returns a list of graphs; older releases return the graph itself.
if isinstance(graph, list):
    graph = graph[0]
graph.write_png('tree.png')
HTML('<img src="files/tree.png" width=600 height=500/>')
# Split thresholds of the fitted tree, in node order. Leaf nodes
# (children_left == -1) do not split on anything, so their threshold slots
# hold meaningless values (e.g. 2.2e-314 in the raw array). Mask them out so
# only real split points are shown.
clf.tree_.threshold[clf.tree_.children_left != -1]
array([ 8.00000012e-001, -2.32035018e+077, 1.75000000e+000, 1.34999996e+000, 2.23443806e-314, 2.28130340e-314, 1.84999996e+000, 6.93069750e-310, 2.28041841e-314])
# Load the IPython R magics (%R / %%R). NOTE(review): newer IPython no longer
# ships this extension; with rpy2 it is loaded as `%load_ext rpy2.ipython`.
# Confirm against the installed versions.
%load_ext rmagic
%%R -o rnewdata
# Run this cell in R; `-o rnewdata` pushes the R variable back into the
# Python namespace, where it is used below as `rnewdata`.
# Fixed seed so the random points are reproducible.
set.seed(32313)
# 20 new points to classify: sepal_width uniform on [2, 4.5] and petal_width
# uniform on [0, 2.5]. The bounds are presumably chosen to span the observed
# iris ranges; confirm.
rnewdata <- data.frame(sepal_width = runif(20,2,4.5),petal_width = runif(20,0,2.5))
import numpy as np
import pandas as pd

# New observations generated in R above. NOTE(review): the `.T` presumably
# reorients what rmagic returned so that each row is one observation; confirm
# against the rmagic/rpy2 version in use. Column order matches the right-hand
# side of the formula `clf` was fitted with (sepal_width, petal_width).
newdata = pd.DataFrame(rnewdata.T, columns=['sepal_width', 'petal_width'])

# patsy built `y` with the categorical `species` on the left-hand side, so it
# is a one-hot indicator matrix and `clf` is a multi-output tree. As a result,
# predict_proba returns a *list* with one (n_samples, 2) array per species.
# Within each array the columns follow clf.classes_[k] (the values 0 and 1),
# i.e. P(not this species) and P(this species). Stacking them raw produces
# six complementary columns that are hard to read.
pred1 = clf.predict_proba(newdata)

# Keep only the P(this species) column of each output, giving one probability
# per species per row. Label the columns with patsy's indicator-column names
# (e.g. "species[setosa]") so labels are guaranteed to line up with the
# outputs, instead of relying on the order of iris['species'].unique().
prob_per_species = np.column_stack([
    proba[:, list(classes).index(1)]
    for proba, classes in zip(pred1, clf.classes_)
])
pd.DataFrame(prob_per_species, columns=y.design_info.column_names)
| | setosa_0 | setosa_1 | versicolor_0 | versicolor_1 | virginica_0 | virginica_1 |
---|---|---|---|---|---|---|
0 | 1 | 0 | 1.000000 | 0.000000 | 0.000000 | 1.000000 |
1 | 0 | 1 | 1.000000 | 0.000000 | 1.000000 | 0.000000 |
2 | 0 | 1 | 1.000000 | 0.000000 | 1.000000 | 0.000000 |
3 | 1 | 0 | 0.000000 | 1.000000 | 1.000000 | 0.000000 |
4 | 1 | 0 | 0.000000 | 1.000000 | 1.000000 | 0.000000 |
5 | 0 | 1 | 1.000000 | 0.000000 | 1.000000 | 0.000000 |
6 | 0 | 1 | 1.000000 | 0.000000 | 1.000000 | 0.000000 |
7 | 1 | 0 | 1.000000 | 0.000000 | 0.000000 | 1.000000 |
8 | 1 | 0 | 1.000000 | 0.000000 | 0.000000 | 1.000000 |
9 | 0 | 1 | 1.000000 | 0.000000 | 1.000000 | 0.000000 |
10 | 0 | 1 | 1.000000 | 0.000000 | 1.000000 | 0.000000 |
11 | 1 | 0 | 0.000000 | 1.000000 | 1.000000 | 0.000000 |
12 | 1 | 0 | 1.000000 | 0.000000 | 0.000000 | 1.000000 |
13 | 1 | 0 | 0.192308 | 0.807692 | 0.807692 | 0.192308 |
14 | 1 | 0 | 1.000000 | 0.000000 | 0.000000 | 1.000000 |
15 | 1 | 0 | 1.000000 | 0.000000 | 0.000000 | 1.000000 |
16 | 1 | 0 | 0.000000 | 1.000000 | 1.000000 | 0.000000 |
17 | 1 | 0 | 0.192308 | 0.807692 | 0.807692 | 0.192308 |
18 | 1 | 0 | 0.192308 | 0.807692 | 0.807692 | 0.192308 |
19 | 1 | 0 | 0.000000 | 1.000000 | 1.000000 | 0.000000 |
# Second example: the Cars93 data set from R's MASS package.
Cars93 = com.load_data('Cars93', package='MASS')
# Normalise R-style dotted column names (lower case, '.' -> '_') so they are
# valid Python identifiers and usable as patsy formula terms.
Cars93.columns = [name.lower().replace('.', '_') for name in Cars93.columns]
# Preview the first 6 rows and first 15 columns. `.ix` (mixed label/position
# indexing) was deprecated and later removed from pandas. `.iloc` is the
# explicit positional equivalent here, since the row labels start at 1 in order.
Cars93.iloc[:6, :15]
| | manufacturer | model | type | min_price | price | max_price | mpg_city | mpg_highway | airbags | drivetrain | cylinders | enginesize | horsepower | rpm | rev_per_mile |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
1 | Acura | Integra | Small | 12.9 | 15.9 | 18.8 | 25 | 31 | None | Front | 4 | 1.8 | 140 | 6300 | 2890 |
2 | Acura | Legend | Midsize | 29.2 | 33.9 | 38.7 | 18 | 25 | Driver & Passenger | Front | 6 | 3.2 | 200 | 5500 | 2335 |
3 | Audi | 90 | Compact | 25.9 | 29.1 | 32.3 | 20 | 26 | Driver only | Front | 6 | 2.8 | 172 | 5500 | 2280 |
4 | Audi | 100 | Midsize | 30.8 | 37.7 | 44.6 | 19 | 26 | Driver & Passenger | Front | 6 | 2.8 | 172 | 5500 | 2535 |
5 | BMW | 535i | Midsize | 23.7 | 30.0 | 36.2 | 22 | 30 | Driver only | Rear | 4 | 3.5 | 208 | 5700 | 2545 |
6 | Buick | Century | Midsize | 14.2 | 15.7 | 17.3 | 22 | 31 | Driver only | Front | 4 | 2.2 | 110 | 5200 | 2565 |
import numpy as np

# Classify drive train from numeric and categorical car attributes; "- 1"
# drops the intercept column. Categorical predictors (e.g. airbags) are
# expanded by patsy into indicator columns.
cars_formula = ('drivetrain ~ mpg_city + mpg_highway + airbags + '
                'enginesize + width + length + weight + price + '
                'cylinders + horsepower + wheelbase - 1')
y, X = pt.dmatrices(cars_formula, Cars93)
# patsy one-hot encodes the categorical response into one indicator column per
# drivetrain level. Fitting on that matrix would make scikit-learn grow a
# multi-output tree: an independent yes/no target per level, with impurity
# combined across targets. Collapse it back to a single label per car (the
# name of the indicator column that is set) so the tree solves the intended
# multiclass problem.
y = np.asarray(y.design_info.column_names)[np.asarray(y).argmax(axis=1)]
# No depth limit: nodes are split until leaves are pure or too small to split.
clf = tree.DecisionTreeClassifier().fit(X, y)
# Render the Cars93 tree: DOT text -> PNG via pydot -> inline image.
dot_data = StringIO.StringIO()
# Use the patsy design-matrix column names (including the indicator columns
# of categorical predictors) as split labels instead of X[0], X[1], ...
tree.export_graphviz(clf, out_file=dot_data,
                     feature_names=X.design_info.column_names)
graph = pydot.graph_from_dot_data(dot_data.getvalue())
# pydot >= 1.2 returns a list of graphs; older releases return the graph itself.
if isinstance(graph, list):
    graph = graph[0]
graph.write_png('tree2.png')
HTML('<img src="files/tree2.png" width=1000 height=1500/>')
# This tree is unpruned (no depth limit), so on a small data set it likely
# overfits. scikit-learn had no post-pruning when this was written; growth can
# be limited up front with max_depth / min_samples_leaf. scikit-learn >= 0.22
# also offers minimal cost-complexity pruning via the `ccp_alpha` parameter.