!date
Fri Mar 8 12:13:38 PST 2013
The plan with this notebook is to work through the RStan Getting Started Guide to see what Python wrappers would be necessary for equivalent functionality.
This is an example in Section 5.5 of Gelman et al. (2003), which studied coaching effects from eight schools. For simplicity, we call this example "eight schools."
Translating it to Python is a matter of converting single quotes to triple quotes, and a handful of other, similarly simple substitutions.
def stan(model_code, data, iter, chains):
    """Placeholder for an RStan-like entry point; real plumbing is defined later."""
    return None  # TODO: plumbing
# The "eight schools" model (Gelman et al., Section 5.5) written in the Stan
# modeling language. Kept as a plain string: it is passed verbatim to stanc.
schools_code = """
data {
int<lower=0> J; // number of schools
real y[J]; // estimated treatment effects
real<lower=0> sigma[J]; // s.e. of effect estimates
}
parameters {
real mu;
real<lower=0> tau;
real eta[J];
}
transformed parameters {
real theta[J];
for (j in 1:J)
theta[j] <- mu + tau * eta[j];
}
model {
eta ~ normal(0, 1);
y ~ normal(theta, sigma);
}
"""
# Observed data: 8 schools' estimated effects and their standard errors.
schools_dat = dict(J = 8,
y = (28, 8, -3, 7, -1, 1, 18, 12),
sigma = (15, 10, 16, 11, 9, 11, 10, 18))
# Mirrors the RStan call signature; at this point stan() is still a stub.
fit = stan(model_code = schools_code, data = schools_dat,
iter = 1000, chains = 4)
Of course, there is some plumbing necessary to make that do anything...
tmp_dir = '/tmp/'  # scratch directory for generated model and data files

def save_temp(txt, fname):
    """Write txt to a file named fname under tmp_dir and return the full path.

    Parameters: txt (str) file contents; fname (str) bare file name.
    Returns the absolute path of the written file.
    """
    fname = tmp_dir + fname  # TODO: avoid collisions, plan for cleanup
    # open() instead of the Python-2-only builtin file(), which was removed
    # in Python 3; behavior is identical here.
    with open(fname, 'w') as f:
        f.write(txt)
    return fname
def rdump(params):
    """Serialize a dict of parameters to R dump format for Stan's --data file.

    Supports int scalars (``key <- 3``) and tuples (``key <- c(1, 2)``);
    values of any other type are silently skipped.
    """
    # If a full implementation of this does not yet exist,
    # I suggest implementing json on the C++ side of Stan
    lines = []
    for key, val in params.items():
        if isinstance(val, int):
            lines.append('%s <- %d\n' % (key, val))
        elif isinstance(val, tuple):
            # Join elements explicitly rather than relying on str(tuple):
            # a one-element tuple would otherwise render as c(5,), which is
            # not valid R syntax.
            lines.append('%s <- c(%s)\n' % (key, ', '.join(str(v) for v in val)))
    return ''.join(lines)
import pandas as pd, StringIO
def load_samples(sample_path):
    """Load a Stan CSV sample file into a pandas DataFrame.

    Lines starting with '#' (Stan's run metadata) are collected into a
    non-standard ``comments`` attribute on the returned DataFrame; all other
    lines are parsed as CSV.
    """
    comments = ''
    samples = ''
    # open() instead of the Python-2-only builtin file(); same behavior.
    with open(sample_path) as f:
        for line in f:
            if line.startswith('#'):
                comments += line
            else:
                samples += line
    df = pd.read_csv(StringIO.StringIO(samples))
    # NOTE(review): an ad-hoc attribute on a DataFrame is fragile -- pandas
    # does not guarantee it survives copies or most operations.
    df.comments = comments
    return df
def stan(model_code, data, iter, chains):
    # Compile and run a Stan model through the command-line interface and
    # return the samples as a DataFrame (via load_samples). Only runnable
    # inside IPython: the %cd and !time lines below are IPython magics,
    # not plain Python.
    # NOTE(review): iter and chains are accepted but never passed to the
    # binary -- the sampler runs with its built-in defaults (the captured
    # output shows iter = 2000, one chain). TODO: wire them up.
    code_path = save_temp(model_code, 'model.stan')
    data_path = save_temp(rdump(data), 'model.Rdata')
    # The make target is the model path with the .stan extension stripped.
    model_path = code_path.replace('.stan', '')
    # Build the model binary from inside the Stan source tree...
    %cd ~/notebook/stan-src-1.1.1/
    !time make $model_path
    # ...then sample from the temp dir so samples.csv lands next to the model.
    %cd $tmp_dir
    !time $model_path --data=$data_path
    sample_path = code_path.replace('model.stan', 'samples.csv')
    return load_samples(sample_path)
# Same call as before, now against the real stan() defined above.
fit = stan(model_code = schools_code, data = schools_dat,
iter = 1000, chains = 4)
/snfs2/HOME/abie/new_dm/stan-src-1.1.1 --- Translating Stan graphical model to C++ code --- bin/stanc /tmp/model.stan --o=/tmp/model.cpp Model name=model_model Input file=/tmp/model.stan Output file=/tmp/model.cpp g++ -I src -I lib/eigen_3.1.2 -I lib/boost_1.52.0 -Wall -DBOOST_RESULT_OF_USE_TR1 -DBOOST_NO_DECLTYPE -c -O3 -o /tmp/model.o /tmp/model.cpp src/stan/agrad/agrad.hpp:2191: warning: ‘void stan::agrad::free_memory()’ defined but not used g++ -I src -I lib/eigen_3.1.2 -I lib/boost_1.52.0 -Wall -DBOOST_RESULT_OF_USE_TR1 -DBOOST_NO_DECLTYPE -lpthread -O3 -o /tmp/model /tmp/model.o -Lbin -lstan real 0m16.502s user 0m14.355s sys 0m0.711s /tmp STAN SAMPLING COMMAND data = /tmp/model.Rdata init = random initialization init tries = 1 samples = samples.csv append_samples = 0 save_warmup = 0 seed = 1389169387 (randomly generated) chain_id = 1 (default) iter = 2000 warmup = 1000 thin = 1 (default) equal_step_sizes = 0 leapfrog_steps = -1 max_treedepth = 10 epsilon = -1 epsilon_pm = 0 delta = 0.5 gamma = 0.05 Iteration: 1 / 2000 [ 0%] (Adapting) Iteration: 10 / 2000 [ 0%] (Adapting) Iteration: 20 / 2000 [ 1%] (Adapting) Iteration: 30 / 2000 [ 1%] (Adapting) Iteration: 40 / 2000 [ 2%] (Adapting) Iteration: 50 / 2000 [ 2%] (Adapting) Iteration: 60 / 2000 [ 3%] (Adapting) Iteration: 70 / 2000 [ 3%] (Adapting) Iteration: 80 / 2000 [ 4%] (Adapting) Iteration: 90 / 2000 [ 4%] (Adapting) Iteration: 100 / 2000 [ 5%] (Adapting) Iteration: 110 / 2000 [ 5%] (Adapting) Iteration: 120 / 2000 [ 6%] (Adapting) Iteration: 130 / 2000 [ 6%] (Adapting) Iteration: 140 / 2000 [ 7%] (Adapting) Iteration: 150 / 2000 [ 7%] (Adapting) Iteration: 160 / 2000 [ 8%] (Adapting) Iteration: 170 / 2000 [ 8%] (Adapting) Iteration: 180 / 2000 [ 9%] (Adapting) Iteration: 190 / 2000 [ 9%] (Adapting) Iteration: 200 / 2000 [ 10%] (Adapting) Iteration: 210 / 2000 [ 10%] (Adapting) Iteration: 220 / 2000 [ 11%] (Adapting) Iteration: 230 / 2000 [ 11%] (Adapting) Iteration: 240 / 2000 [ 12%] (Adapting) Iteration: 
250 / 2000 [ 12%] (Adapting) Iteration: 260 / 2000 [ 13%] (Adapting) Iteration: 270 / 2000 [ 13%] (Adapting) Iteration: 280 / 2000 [ 14%] (Adapting) Iteration: 290 / 2000 [ 14%] (Adapting) Iteration: 300 / 2000 [ 15%] (Adapting) Iteration: 310 / 2000 [ 15%] (Adapting) Iteration: 320 / 2000 [ 16%] (Adapting) Iteration: 330 / 2000 [ 16%] (Adapting) Iteration: 340 / 2000 [ 17%] (Adapting) Iteration: 350 / 2000 [ 17%] (Adapting) Iteration: 360 / 2000 [ 18%] (Adapting) Iteration: 370 / 2000 [ 18%] (Adapting) Iteration: 380 / 2000 [ 19%] (Adapting) Iteration: 390 / 2000 [ 19%] (Adapting) Iteration: 400 / 2000 [ 20%] (Adapting) Iteration: 410 / 2000 [ 20%] (Adapting) Iteration: 420 / 2000 [ 21%] (Adapting) Iteration: 430 / 2000 [ 21%] (Adapting) Iteration: 440 / 2000 [ 22%] (Adapting) Iteration: 450 / 2000 [ 22%] (Adapting) Iteration: 460 / 2000 [ 23%] (Adapting) Iteration: 470 / 2000 [ 23%] (Adapting) Iteration: 480 / 2000 [ 24%] (Adapting) Iteration: 490 / 2000 [ 24%] (Adapting) Iteration: 500 / 2000 [ 25%] (Adapting) Iteration: 510 / 2000 [ 25%] (Adapting) Iteration: 520 / 2000 [ 26%] (Adapting) Iteration: 530 / 2000 [ 26%] (Adapting) Iteration: 540 / 2000 [ 27%] (Adapting) Iteration: 550 / 2000 [ 27%] (Adapting) Iteration: 560 / 2000 [ 28%] (Adapting) Iteration: 570 / 2000 [ 28%] (Adapting) Iteration: 580 / 2000 [ 29%] (Adapting) Iteration: 590 / 2000 [ 29%] (Adapting) Iteration: 600 / 2000 [ 30%] (Adapting) Iteration: 610 / 2000 [ 30%] (Adapting) Iteration: 620 / 2000 [ 31%] (Adapting) Iteration: 630 / 2000 [ 31%] (Adapting) Iteration: 640 / 2000 [ 32%] (Adapting) Iteration: 650 / 2000 [ 32%] (Adapting) Iteration: 660 / 2000 [ 33%] (Adapting) Iteration: 670 / 2000 [ 33%] (Adapting) Iteration: 680 / 2000 [ 34%] (Adapting) Iteration: 690 / 2000 [ 34%] (Adapting) Iteration: 700 / 2000 [ 35%] (Adapting) Iteration: 710 / 2000 [ 35%] (Adapting) Iteration: 720 / 2000 [ 36%] (Adapting) Iteration: 730 / 2000 [ 36%] (Adapting) Iteration: 740 / 2000 [ 37%] (Adapting) Iteration: 
750 / 2000 [ 37%] (Adapting) Iteration: 760 / 2000 [ 38%] (Adapting) Iteration: 770 / 2000 [ 38%] (Adapting) Iteration: 780 / 2000 [ 39%] (Adapting) Iteration: 790 / 2000 [ 39%] (Adapting) Iteration: 800 / 2000 [ 40%] (Adapting) Iteration: 810 / 2000 [ 40%] (Adapting) Iteration: 820 / 2000 [ 41%] (Adapting) Iteration: 830 / 2000 [ 41%] (Adapting) Iteration: 840 / 2000 [ 42%] (Adapting) Iteration: 850 / 2000 [ 42%] (Adapting) Iteration: 860 / 2000 [ 43%] (Adapting) Iteration: 870 / 2000 [ 43%] (Adapting) Iteration: 880 / 2000 [ 44%] (Adapting) Iteration: 890 / 2000 [ 44%] (Adapting) Iteration: 900 / 2000 [ 45%] (Adapting) Iteration: 910 / 2000 [ 45%] (Adapting) Iteration: 920 / 2000 [ 46%] (Adapting) Iteration: 930 / 2000 [ 46%] (Adapting) Iteration: 940 / 2000 [ 47%] (Adapting) Iteration: 950 / 2000 [ 47%] (Adapting) Iteration: 960 / 2000 [ 48%] (Adapting) Iteration: 970 / 2000 [ 48%] (Adapting) Iteration: 980 / 2000 [ 49%] (Adapting) Iteration: 990 / 2000 [ 49%] (Adapting) Iteration: 1000 / 2000 [ 50%] (Adapting) Iteration: 1010 / 2000 [ 50%] (Sampling) Iteration: 1020 / 2000 [ 51%] (Sampling) Iteration: 1030 / 2000 [ 51%] (Sampling) Iteration: 1040 / 2000 [ 52%] (Sampling) Iteration: 1050 / 2000 [ 52%] (Sampling) Iteration: 1060 / 2000 [ 53%] (Sampling) Iteration: 1070 / 2000 [ 53%] (Sampling) Iteration: 1080 / 2000 [ 54%] (Sampling) Iteration: 1090 / 2000 [ 54%] (Sampling) Iteration: 1100 / 2000 [ 55%] (Sampling) Iteration: 1110 / 2000 [ 55%] (Sampling) Iteration: 1120 / 2000 [ 56%] (Sampling) Iteration: 1130 / 2000 [ 56%] (Sampling) Iteration: 1140 / 2000 [ 57%] (Sampling) Iteration: 1150 / 2000 [ 57%] (Sampling) Iteration: 1160 / 2000 [ 58%] (Sampling) Iteration: 1170 / 2000 [ 58%] (Sampling) Iteration: 1180 / 2000 [ 59%] (Sampling) Iteration: 1190 / 2000 [ 59%] (Sampling) Iteration: 1200 / 2000 [ 60%] (Sampling) Iteration: 1210 / 2000 [ 60%] (Sampling) Iteration: 1220 / 2000 [ 61%] (Sampling) Iteration: 1230 / 2000 [ 61%] (Sampling) Iteration: 1240 / 2000 [ 
62%] (Sampling) Iteration: 1250 / 2000 [ 62%] (Sampling) Iteration: 1260 / 2000 [ 63%] (Sampling) Iteration: 1270 / 2000 [ 63%] (Sampling) Iteration: 1280 / 2000 [ 64%] (Sampling) Iteration: 1290 / 2000 [ 64%] (Sampling) Iteration: 1300 / 2000 [ 65%] (Sampling) Iteration: 1310 / 2000 [ 65%] (Sampling) Iteration: 1320 / 2000 [ 66%] (Sampling) Iteration: 1330 / 2000 [ 66%] (Sampling) Iteration: 1340 / 2000 [ 67%] (Sampling) Iteration: 1350 / 2000 [ 67%] (Sampling) Iteration: 1360 / 2000 [ 68%] (Sampling) Iteration: 1370 / 2000 [ 68%] (Sampling) Iteration: 1380 / 2000 [ 69%] (Sampling) Iteration: 1390 / 2000 [ 69%] (Sampling) Iteration: 1400 / 2000 [ 70%] (Sampling) Iteration: 1410 / 2000 [ 70%] (Sampling) Iteration: 1420 / 2000 [ 71%] (Sampling) Iteration: 1430 / 2000 [ 71%] (Sampling) Iteration: 1440 / 2000 [ 72%] (Sampling) Iteration: 1450 / 2000 [ 72%] (Sampling) Iteration: 1460 / 2000 [ 73%] (Sampling) Iteration: 1470 / 2000 [ 73%] (Sampling) Iteration: 1480 / 2000 [ 74%] (Sampling) Iteration: 1490 / 2000 [ 74%] (Sampling) Iteration: 1500 / 2000 [ 75%] (Sampling) Iteration: 1510 / 2000 [ 75%] (Sampling) Iteration: 1520 / 2000 [ 76%] (Sampling) Iteration: 1530 / 2000 [ 76%] (Sampling) Iteration: 1540 / 2000 [ 77%] (Sampling) Iteration: 1550 / 2000 [ 77%] (Sampling) Iteration: 1560 / 2000 [ 78%] (Sampling) Iteration: 1570 / 2000 [ 78%] (Sampling) Iteration: 1580 / 2000 [ 79%] (Sampling) Iteration: 1590 / 2000 [ 79%] (Sampling) Iteration: 1600 / 2000 [ 80%] (Sampling) Iteration: 1610 / 2000 [ 80%] (Sampling) Iteration: 1620 / 2000 [ 81%] (Sampling) Iteration: 1630 / 2000 [ 81%] (Sampling) Iteration: 1640 / 2000 [ 82%] (Sampling) Iteration: 1650 / 2000 [ 82%] (Sampling) Iteration: 1660 / 2000 [ 83%] (Sampling) Iteration: 1670 / 2000 [ 83%] (Sampling) Iteration: 1680 / 2000 [ 84%] (Sampling) Iteration: 1690 / 2000 [ 84%] (Sampling) Iteration: 1700 / 2000 [ 85%] (Sampling) Iteration: 1710 / 2000 [ 85%] (Sampling) Iteration: 1720 / 2000 [ 86%] (Sampling) Iteration: 1730 
/ 2000 [ 86%] (Sampling) Iteration: 1740 / 2000 [ 87%] (Sampling) Iteration: 1750 / 2000 [ 87%] (Sampling) Iteration: 1760 / 2000 [ 88%] (Sampling) Iteration: 1770 / 2000 [ 88%] (Sampling) Iteration: 1780 / 2000 [ 89%] (Sampling) Iteration: 1790 / 2000 [ 89%] (Sampling) Iteration: 1800 / 2000 [ 90%] (Sampling) Iteration: 1810 / 2000 [ 90%] (Sampling) Iteration: 1820 / 2000 [ 91%] (Sampling) Iteration: 1830 / 2000 [ 91%] (Sampling) Iteration: 1840 / 2000 [ 92%] (Sampling) Iteration: 1850 / 2000 [ 92%] (Sampling) Iteration: 1860 / 2000 [ 93%] (Sampling) Iteration: 1870 / 2000 [ 93%] (Sampling) Iteration: 1880 / 2000 [ 94%] (Sampling) Iteration: 1890 / 2000 [ 94%] (Sampling) Iteration: 1900 / 2000 [ 95%] (Sampling) Iteration: 1910 / 2000 [ 95%] (Sampling) Iteration: 1920 / 2000 [ 96%] (Sampling) Iteration: 1930 / 2000 [ 96%] (Sampling) Iteration: 1940 / 2000 [ 97%] (Sampling) Iteration: 1950 / 2000 [ 97%] (Sampling) Iteration: 1960 / 2000 [ 98%] (Sampling) Iteration: 1970 / 2000 [ 98%] (Sampling) Iteration: 1980 / 2000 [ 99%] (Sampling) Iteration: 1990 / 2000 [ 99%] (Sampling) Iteration: 2000 / 2000 [100%] (Sampling) real 0m0.048s user 0m0.040s sys 0m0.002s
# Posterior summary table rounded to one decimal. round_ is not defined in
# this file -- presumably numpy.round_ from an earlier session import; verify.
print round_(fit.describe(95).T, 1)
count mean std min 2.5% 50% 97.5% max lp__ 1000 -5.2 2.7 -15.8 -11.5 -4.9 -0.5 0.1 treedepth__ 1000 2.4 0.8 0.0 1.0 3.0 3.0 5.0 stepsize__ 1000 1.1 0.0 1.1 1.1 1.1 1.1 1.1 mu 1000 8.0 6.5 -8.0 -4.7 7.6 24.8 28.8 tau 1000 6.6 5.9 0.0 0.1 5.2 24.0 26.7 eta.1 1000 0.4 1.0 -2.5 -1.5 0.4 2.3 3.2 eta.2 1000 -0.1 0.9 -2.8 -2.0 -0.0 1.7 3.7 eta.3 1000 -0.2 0.9 -2.8 -1.8 -0.3 1.7 2.8 eta.4 1000 -0.1 1.0 -3.3 -1.9 -0.1 1.7 2.9 eta.5 1000 -0.3 0.9 -3.5 -2.0 -0.3 1.5 3.6 eta.6 1000 -0.2 0.9 -3.5 -1.8 -0.2 1.6 2.7 eta.7 1000 0.3 0.9 -3.1 -1.5 0.4 1.9 3.2 eta.8 1000 0.0 0.9 -3.2 -1.7 -0.0 1.7 2.9 theta.1 1000 11.6 8.2 -8.5 -1.7 10.2 32.1 51.8 theta.2 1000 7.8 6.7 -16.9 -4.1 7.5 23.3 31.2 theta.3 1000 5.8 8.5 -20.7 -20.7 6.5 19.6 27.1 theta.4 1000 7.8 7.0 -15.8 -6.3 7.7 24.6 34.5 theta.5 1000 5.2 6.2 -16.1 -8.1 5.9 15.5 22.0 theta.6 1000 6.3 6.5 -13.3 -6.6 6.2 19.5 30.3 theta.7 1000 10.5 6.7 -10.5 -1.5 9.7 24.3 36.7 theta.8 1000 8.6 8.1 -13.3 -7.0 7.8 27.0 50.0
# Dump the '#' metadata lines that load_samples captured from the Stan CSV.
print fit.comments
# Samples Generated by Stan # # stan_version_major=1 # stan_version_minor=1 # stan_version_patch=1 # data=/tmp/model.Rdata # init=random initialization # append_samples=0 # save_warmup=0 # seed=1389169387 # chain_id=1 # iter=2000 # warmup=1000 # thin=1 # equal_step_sizes=0 # leapfrog_steps=-1 # max_treedepth=10 # epsilon=-1 # epsilon_pm=0 # delta=0.5 # gamma=0.05 # # (mcmc::nuts_diag) adaptation finished # step size=1.06744 # parameter step size multipliers: # 2.62742,0.67717,0.596109,0.522487,0.608809,0.576615,0.557494,0.561288,0.561017,0.60503
!date
Fri Mar 8 12:13:58 PST 2013
Patches welcome!