In [1]:
!date
Fri Mar  8 12:13:38 PST 2013

The plan for this notebook is to work through the RStan Getting Started guide and see what Python wrappers would be needed for equivalent functionality.

Example 1: Eight Schools

This is an example in Section 5.5 of Gelman et al (2003), which studied coaching effects from eight schools. For simplicity, we call this example "eight schools."

Translating it to Python is a matter of converting single quotes to triple quotes, and a handful of other, similarly simple substitutions.

In [2]:
def stan(model_code, data, iter, chains):
    return  # TODO: plumbing
In [3]:
schools_code = """
  data {
    int<lower=0> J; // number of schools 
    real y[J]; // estimated treatment effects
    real<lower=0> sigma[J]; // s.e. of effect estimates 
  }
  parameters {
    real mu; 
    real<lower=0> tau;
    real eta[J];
  }
  transformed parameters {
    real theta[J];
    for (j in 1:J)
      theta[j] <- mu + tau * eta[j];
  }
  model {
    eta ~ normal(0, 1);
    y ~ normal(theta, sigma);
  }
"""

schools_dat = dict(J = 8, 
                   y = (28,  8, -3,  7, -1,  1, 18, 12),
                   sigma = (15, 10, 16, 11,  9, 11, 10, 18))

fit = stan(model_code = schools_code, data = schools_dat, 
            iter = 1000, chains = 4)

Of course, there is some plumbing necessary to make that do anything...

In [4]:
tmp_dir = '/tmp/'
def save_temp(txt, fname):
    fname = tmp_dir + fname  # TODO: avoid collisions, plan for cleanup
    with open(fname, 'w') as f:
        f.write(txt)

    return fname

def rdump(params):
    # Minimal R-dump writer: handles only int scalars and flat tuples,
    # which is all the eight-schools data needs.
    # If a full implementation of this does not already exist, I suggest
    # implementing JSON input on the C++ side of Stan instead.
    txt = ''

    for key, val in params.items():
        if type(val) == int:
            txt += '%s <- %d\n' % (key, val)
        elif type(val) == tuple:
            txt += '%s <- c%s\n' % (key, val)
    return txt
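As a quick sanity check (not run as a cell here), rdump on the eight-schools data should produce R-dump lines like the following; note that this minimal writer only handles ints and flat tuples, and dict ordering is not guaranteed in Python 2, so the line order may vary:

print rdump(schools_dat)
# J <- 8
# y <- c(28, 8, -3, 7, -1, 1, 18, 12)
# sigma <- c(15, 10, 16, 11, 9, 11, 10, 18)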
In [5]:
import pandas as pd, StringIO
In [6]:
def load_samples(sample_path):
    # Split Stan's CSV output into the '#' comment header and the sample rows
    comments = ''
    samples = ''

    with open(sample_path) as f:
        for line in f:
            if line.startswith('#'):
                comments += line
            else:
                samples += line

    df = pd.read_csv(StringIO.StringIO(samples))
    df.comments = comments  # stash the header on the DataFrame for later inspection

    return df
In [7]:
def stan(model_code, data, iter, chains):
    # NOTE: iter and chains are not passed through to the sampler yet
    # (see the to-do list at the end)
    code_path = save_temp(model_code, 'model.stan')
    data_path = save_temp(rdump(data), 'model.Rdata')
    model_path = code_path.replace('.stan', '')
    
    %cd ~/notebook/stan-src-1.1.1/
    !time make $model_path
    
    %cd $tmp_dir
    !time $model_path --data=$data_path
    
    sample_path = code_path.replace('model.stan', 'samples.csv')
    return load_samples(sample_path)

fit = stan(model_code = schools_code, data = schools_dat, 
            iter = 1000, chains = 4)
/snfs2/HOME/abie/new_dm/stan-src-1.1.1

--- Translating Stan graphical model to C++ code ---
bin/stanc /tmp/model.stan --o=/tmp/model.cpp
Model name=model_model
Input file=/tmp/model.stan
Output file=/tmp/model.cpp
g++ -I src -I lib/eigen_3.1.2 -I lib/boost_1.52.0 -Wall -DBOOST_RESULT_OF_USE_TR1 -DBOOST_NO_DECLTYPE   -c -O3 -o /tmp/model.o /tmp/model.cpp
src/stan/agrad/agrad.hpp:2191: warning: ‘void stan::agrad::free_memory()’ defined but not used
g++ -I src -I lib/eigen_3.1.2 -I lib/boost_1.52.0 -Wall -DBOOST_RESULT_OF_USE_TR1 -DBOOST_NO_DECLTYPE  -lpthread  -O3 -o /tmp/model /tmp/model.o -Lbin -lstan

real	0m16.502s
user	0m14.355s
sys	0m0.711s
/tmp
STAN SAMPLING COMMAND
data = /tmp/model.Rdata
init = random initialization
init tries = 1
samples = samples.csv
append_samples = 0
save_warmup = 0
seed = 1389169387 (randomly generated)
chain_id = 1 (default)
iter = 2000
warmup = 1000
thin = 1 (default)
equal_step_sizes = 0
leapfrog_steps = -1
max_treedepth = 10
epsilon = -1
epsilon_pm = 0
delta = 0.5
gamma = 0.05

Iteration:    1 / 2000 [  0%]  (Adapting)
Iteration:   10 / 2000 [  0%]  (Adapting)
Iteration:   20 / 2000 [  1%]  (Adapting)
Iteration:   30 / 2000 [  1%]  (Adapting)
Iteration:   40 / 2000 [  2%]  (Adapting)
Iteration:   50 / 2000 [  2%]  (Adapting)
Iteration:   60 / 2000 [  3%]  (Adapting)
Iteration:   70 / 2000 [  3%]  (Adapting)
Iteration:   80 / 2000 [  4%]  (Adapting)
Iteration:   90 / 2000 [  4%]  (Adapting)
Iteration:  100 / 2000 [  5%]  (Adapting)
Iteration:  110 / 2000 [  5%]  (Adapting)
Iteration:  120 / 2000 [  6%]  (Adapting)
Iteration:  130 / 2000 [  6%]  (Adapting)
Iteration:  140 / 2000 [  7%]  (Adapting)
Iteration:  150 / 2000 [  7%]  (Adapting)
Iteration:  160 / 2000 [  8%]  (Adapting)
Iteration:  170 / 2000 [  8%]  (Adapting)
Iteration:  180 / 2000 [  9%]  (Adapting)
Iteration:  190 / 2000 [  9%]  (Adapting)
Iteration:  200 / 2000 [ 10%]  (Adapting)
Iteration:  210 / 2000 [ 10%]  (Adapting)
Iteration:  220 / 2000 [ 11%]  (Adapting)
Iteration:  230 / 2000 [ 11%]  (Adapting)
Iteration:  240 / 2000 [ 12%]  (Adapting)
Iteration:  250 / 2000 [ 12%]  (Adapting)
Iteration:  260 / 2000 [ 13%]  (Adapting)
Iteration:  270 / 2000 [ 13%]  (Adapting)
Iteration:  280 / 2000 [ 14%]  (Adapting)
Iteration:  290 / 2000 [ 14%]  (Adapting)
Iteration:  300 / 2000 [ 15%]  (Adapting)
Iteration:  310 / 2000 [ 15%]  (Adapting)
Iteration:  320 / 2000 [ 16%]  (Adapting)
Iteration:  330 / 2000 [ 16%]  (Adapting)
Iteration:  340 / 2000 [ 17%]  (Adapting)
Iteration:  350 / 2000 [ 17%]  (Adapting)
Iteration:  360 / 2000 [ 18%]  (Adapting)
Iteration:  370 / 2000 [ 18%]  (Adapting)
Iteration:  380 / 2000 [ 19%]  (Adapting)
Iteration:  390 / 2000 [ 19%]  (Adapting)
Iteration:  400 / 2000 [ 20%]  (Adapting)
Iteration:  410 / 2000 [ 20%]  (Adapting)
Iteration:  420 / 2000 [ 21%]  (Adapting)
Iteration:  430 / 2000 [ 21%]  (Adapting)
Iteration:  440 / 2000 [ 22%]  (Adapting)
Iteration:  450 / 2000 [ 22%]  (Adapting)
Iteration:  460 / 2000 [ 23%]  (Adapting)
Iteration:  470 / 2000 [ 23%]  (Adapting)
Iteration:  480 / 2000 [ 24%]  (Adapting)
Iteration:  490 / 2000 [ 24%]  (Adapting)
Iteration:  500 / 2000 [ 25%]  (Adapting)
Iteration:  510 / 2000 [ 25%]  (Adapting)
Iteration:  520 / 2000 [ 26%]  (Adapting)
Iteration:  530 / 2000 [ 26%]  (Adapting)
Iteration:  540 / 2000 [ 27%]  (Adapting)
Iteration:  550 / 2000 [ 27%]  (Adapting)
Iteration:  560 / 2000 [ 28%]  (Adapting)
Iteration:  570 / 2000 [ 28%]  (Adapting)
Iteration:  580 / 2000 [ 29%]  (Adapting)
Iteration:  590 / 2000 [ 29%]  (Adapting)
Iteration:  600 / 2000 [ 30%]  (Adapting)
Iteration:  610 / 2000 [ 30%]  (Adapting)
Iteration:  620 / 2000 [ 31%]  (Adapting)
Iteration:  630 / 2000 [ 31%]  (Adapting)
Iteration:  640 / 2000 [ 32%]  (Adapting)
Iteration:  650 / 2000 [ 32%]  (Adapting)
Iteration:  660 / 2000 [ 33%]  (Adapting)
Iteration:  670 / 2000 [ 33%]  (Adapting)
Iteration:  680 / 2000 [ 34%]  (Adapting)
Iteration:  690 / 2000 [ 34%]  (Adapting)
Iteration:  700 / 2000 [ 35%]  (Adapting)
Iteration:  710 / 2000 [ 35%]  (Adapting)
Iteration:  720 / 2000 [ 36%]  (Adapting)
Iteration:  730 / 2000 [ 36%]  (Adapting)
Iteration:  740 / 2000 [ 37%]  (Adapting)
Iteration:  750 / 2000 [ 37%]  (Adapting)
Iteration:  760 / 2000 [ 38%]  (Adapting)
Iteration:  770 / 2000 [ 38%]  (Adapting)
Iteration:  780 / 2000 [ 39%]  (Adapting)
Iteration:  790 / 2000 [ 39%]  (Adapting)
Iteration:  800 / 2000 [ 40%]  (Adapting)
Iteration:  810 / 2000 [ 40%]  (Adapting)
Iteration:  820 / 2000 [ 41%]  (Adapting)
Iteration:  830 / 2000 [ 41%]  (Adapting)
Iteration:  840 / 2000 [ 42%]  (Adapting)
Iteration:  850 / 2000 [ 42%]  (Adapting)
Iteration:  860 / 2000 [ 43%]  (Adapting)
Iteration:  870 / 2000 [ 43%]  (Adapting)
Iteration:  880 / 2000 [ 44%]  (Adapting)
Iteration:  890 / 2000 [ 44%]  (Adapting)
Iteration:  900 / 2000 [ 45%]  (Adapting)
Iteration:  910 / 2000 [ 45%]  (Adapting)
Iteration:  920 / 2000 [ 46%]  (Adapting)
Iteration:  930 / 2000 [ 46%]  (Adapting)
Iteration:  940 / 2000 [ 47%]  (Adapting)
Iteration:  950 / 2000 [ 47%]  (Adapting)
Iteration:  960 / 2000 [ 48%]  (Adapting)
Iteration:  970 / 2000 [ 48%]  (Adapting)
Iteration:  980 / 2000 [ 49%]  (Adapting)
Iteration:  990 / 2000 [ 49%]  (Adapting)
Iteration: 1000 / 2000 [ 50%]  (Adapting)
Iteration: 1010 / 2000 [ 50%]  (Sampling)
Iteration: 1020 / 2000 [ 51%]  (Sampling)
Iteration: 1030 / 2000 [ 51%]  (Sampling)
Iteration: 1040 / 2000 [ 52%]  (Sampling)
Iteration: 1050 / 2000 [ 52%]  (Sampling)
Iteration: 1060 / 2000 [ 53%]  (Sampling)
Iteration: 1070 / 2000 [ 53%]  (Sampling)
Iteration: 1080 / 2000 [ 54%]  (Sampling)
Iteration: 1090 / 2000 [ 54%]  (Sampling)
Iteration: 1100 / 2000 [ 55%]  (Sampling)
Iteration: 1110 / 2000 [ 55%]  (Sampling)
Iteration: 1120 / 2000 [ 56%]  (Sampling)
Iteration: 1130 / 2000 [ 56%]  (Sampling)
Iteration: 1140 / 2000 [ 57%]  (Sampling)
Iteration: 1150 / 2000 [ 57%]  (Sampling)
Iteration: 1160 / 2000 [ 58%]  (Sampling)
Iteration: 1170 / 2000 [ 58%]  (Sampling)
Iteration: 1180 / 2000 [ 59%]  (Sampling)
Iteration: 1190 / 2000 [ 59%]  (Sampling)
Iteration: 1200 / 2000 [ 60%]  (Sampling)
Iteration: 1210 / 2000 [ 60%]  (Sampling)
Iteration: 1220 / 2000 [ 61%]  (Sampling)
Iteration: 1230 / 2000 [ 61%]  (Sampling)
Iteration: 1240 / 2000 [ 62%]  (Sampling)
Iteration: 1250 / 2000 [ 62%]  (Sampling)
Iteration: 1260 / 2000 [ 63%]  (Sampling)
Iteration: 1270 / 2000 [ 63%]  (Sampling)
Iteration: 1280 / 2000 [ 64%]  (Sampling)
Iteration: 1290 / 2000 [ 64%]  (Sampling)
Iteration: 1300 / 2000 [ 65%]  (Sampling)
Iteration: 1310 / 2000 [ 65%]  (Sampling)
Iteration: 1320 / 2000 [ 66%]  (Sampling)
Iteration: 1330 / 2000 [ 66%]  (Sampling)
Iteration: 1340 / 2000 [ 67%]  (Sampling)
Iteration: 1350 / 2000 [ 67%]  (Sampling)
Iteration: 1360 / 2000 [ 68%]  (Sampling)
Iteration: 1370 / 2000 [ 68%]  (Sampling)
Iteration: 1380 / 2000 [ 69%]  (Sampling)
Iteration: 1390 / 2000 [ 69%]  (Sampling)
Iteration: 1400 / 2000 [ 70%]  (Sampling)
Iteration: 1410 / 2000 [ 70%]  (Sampling)
Iteration: 1420 / 2000 [ 71%]  (Sampling)
Iteration: 1430 / 2000 [ 71%]  (Sampling)
Iteration: 1440 / 2000 [ 72%]  (Sampling)
Iteration: 1450 / 2000 [ 72%]  (Sampling)
Iteration: 1460 / 2000 [ 73%]  (Sampling)
Iteration: 1470 / 2000 [ 73%]  (Sampling)
Iteration: 1480 / 2000 [ 74%]  (Sampling)
Iteration: 1490 / 2000 [ 74%]  (Sampling)
Iteration: 1500 / 2000 [ 75%]  (Sampling)
Iteration: 1510 / 2000 [ 75%]  (Sampling)
Iteration: 1520 / 2000 [ 76%]  (Sampling)
Iteration: 1530 / 2000 [ 76%]  (Sampling)
Iteration: 1540 / 2000 [ 77%]  (Sampling)
Iteration: 1550 / 2000 [ 77%]  (Sampling)
Iteration: 1560 / 2000 [ 78%]  (Sampling)
Iteration: 1570 / 2000 [ 78%]  (Sampling)
Iteration: 1580 / 2000 [ 79%]  (Sampling)
Iteration: 1590 / 2000 [ 79%]  (Sampling)
Iteration: 1600 / 2000 [ 80%]  (Sampling)
Iteration: 1610 / 2000 [ 80%]  (Sampling)
Iteration: 1620 / 2000 [ 81%]  (Sampling)
Iteration: 1630 / 2000 [ 81%]  (Sampling)
Iteration: 1640 / 2000 [ 82%]  (Sampling)
Iteration: 1650 / 2000 [ 82%]  (Sampling)
Iteration: 1660 / 2000 [ 83%]  (Sampling)
Iteration: 1670 / 2000 [ 83%]  (Sampling)
Iteration: 1680 / 2000 [ 84%]  (Sampling)
Iteration: 1690 / 2000 [ 84%]  (Sampling)
Iteration: 1700 / 2000 [ 85%]  (Sampling)
Iteration: 1710 / 2000 [ 85%]  (Sampling)
Iteration: 1720 / 2000 [ 86%]  (Sampling)
Iteration: 1730 / 2000 [ 86%]  (Sampling)
Iteration: 1740 / 2000 [ 87%]  (Sampling)
Iteration: 1750 / 2000 [ 87%]  (Sampling)
Iteration: 1760 / 2000 [ 88%]  (Sampling)
Iteration: 1770 / 2000 [ 88%]  (Sampling)
Iteration: 1780 / 2000 [ 89%]  (Sampling)
Iteration: 1790 / 2000 [ 89%]  (Sampling)
Iteration: 1800 / 2000 [ 90%]  (Sampling)
Iteration: 1810 / 2000 [ 90%]  (Sampling)
Iteration: 1820 / 2000 [ 91%]  (Sampling)
Iteration: 1830 / 2000 [ 91%]  (Sampling)
Iteration: 1840 / 2000 [ 92%]  (Sampling)
Iteration: 1850 / 2000 [ 92%]  (Sampling)
Iteration: 1860 / 2000 [ 93%]  (Sampling)
Iteration: 1870 / 2000 [ 93%]  (Sampling)
Iteration: 1880 / 2000 [ 94%]  (Sampling)
Iteration: 1890 / 2000 [ 94%]  (Sampling)
Iteration: 1900 / 2000 [ 95%]  (Sampling)
Iteration: 1910 / 2000 [ 95%]  (Sampling)
Iteration: 1920 / 2000 [ 96%]  (Sampling)
Iteration: 1930 / 2000 [ 96%]  (Sampling)
Iteration: 1940 / 2000 [ 97%]  (Sampling)
Iteration: 1950 / 2000 [ 97%]  (Sampling)
Iteration: 1960 / 2000 [ 98%]  (Sampling)
Iteration: 1970 / 2000 [ 98%]  (Sampling)
Iteration: 1980 / 2000 [ 99%]  (Sampling)
Iteration: 1990 / 2000 [ 99%]  (Sampling)
Iteration: 2000 / 2000 [100%]  (Sampling)



real	0m0.048s
user	0m0.040s
sys	0m0.002s
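The %cd and ! lines above only work inside IPython. Outside the notebook, the same plumbing could be done with subprocess; here is a rough sketch under the same assumptions (the stan-src-1.1.1 checkout in ~/notebook, and the --data flag used above), with stan_subprocess a hypothetical name and iter/chains still unused:

import os, subprocess

def stan_subprocess(model_code, data, iter, chains):
    # Same plumbing as stan() above, without IPython magics
    code_path = save_temp(model_code, 'model.stan')
    data_path = save_temp(rdump(data), 'model.Rdata')
    model_path = code_path.replace('.stan', '')

    stan_dir = os.path.expanduser('~/notebook/stan-src-1.1.1')

    # Compile the model through the Stan makefile (the "make /tmp/model" step above)
    subprocess.check_call(['make', model_path], cwd=stan_dir)

    # Run the compiled sampler; it writes samples.csv into its working directory
    subprocess.check_call([model_path, '--data=' + data_path], cwd=tmp_dir)

    return load_samples(tmp_dir + 'samples.csv')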
In [8]:
print round_(fit.describe(95).T, 1)
             count  mean  std   min  2.5%   50%  97.5%   max
lp__          1000  -5.2  2.7 -15.8 -11.5  -4.9   -0.5   0.1
treedepth__   1000   2.4  0.8   0.0   1.0   3.0    3.0   5.0
stepsize__    1000   1.1  0.0   1.1   1.1   1.1    1.1   1.1
mu            1000   8.0  6.5  -8.0  -4.7   7.6   24.8  28.8
tau           1000   6.6  5.9   0.0   0.1   5.2   24.0  26.7
eta.1         1000   0.4  1.0  -2.5  -1.5   0.4    2.3   3.2
eta.2         1000  -0.1  0.9  -2.8  -2.0  -0.0    1.7   3.7
eta.3         1000  -0.2  0.9  -2.8  -1.8  -0.3    1.7   2.8
eta.4         1000  -0.1  1.0  -3.3  -1.9  -0.1    1.7   2.9
eta.5         1000  -0.3  0.9  -3.5  -2.0  -0.3    1.5   3.6
eta.6         1000  -0.2  0.9  -3.5  -1.8  -0.2    1.6   2.7
eta.7         1000   0.3  0.9  -3.1  -1.5   0.4    1.9   3.2
eta.8         1000   0.0  0.9  -3.2  -1.7  -0.0    1.7   2.9
theta.1       1000  11.6  8.2  -8.5  -1.7  10.2   32.1  51.8
theta.2       1000   7.8  6.7 -16.9  -4.1   7.5   23.3  31.2
theta.3       1000   5.8  8.5 -20.7 -20.7   6.5   19.6  27.1
theta.4       1000   7.8  7.0 -15.8  -6.3   7.7   24.6  34.5
theta.5       1000   5.2  6.2 -16.1  -8.1   5.9   15.5  22.0
theta.6       1000   6.3  6.5 -13.3  -6.6   6.2   19.5  30.3
theta.7       1000  10.5  6.7 -10.5  -1.5   9.7   24.3  36.7
theta.8       1000   8.6  8.1 -13.3  -7.0   7.8   27.0  50.0
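Since the fit object is just a pandas DataFrame, further posterior summaries are one-liners. For example (illustrative only; column names follow the samples.csv header shown in the table above):

# Posterior mean of the population-level effect
print fit['mu'].mean()

# Posterior probability that the school 1 effect exceeds the school 3 effect
print (fit['theta.1'] > fit['theta.3']).mean()

# Posterior probability of a substantial between-school spread
print (fit['tau'] > 10).mean()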
In [9]:
print fit.comments
# Samples Generated by Stan
#
# stan_version_major=1
# stan_version_minor=1
# stan_version_patch=1
# data=/tmp/model.Rdata
# init=random initialization
# append_samples=0
# save_warmup=0
# seed=1389169387
# chain_id=1
# iter=2000
# warmup=1000
# thin=1
# equal_step_sizes=0
# leapfrog_steps=-1
# max_treedepth=10
# epsilon=-1
# epsilon_pm=0
# delta=0.5
# gamma=0.05
#
# (mcmc::nuts_diag) adaptation finished
# step size=1.06744
# parameter step size multipliers:
# 2.62742,0.67717,0.596109,0.522487,0.608809,0.576615,0.557494,0.561288,0.561017,0.60503

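Those # key=value lines would be easy to turn into a dict if the wrapper ever needs the run configuration programmatically. A minimal sketch (parse_comments is a hypothetical helper, not used above):

def parse_comments(comments):
    # Pull key=value pairs out of the '# ...' header lines
    config = {}
    for line in comments.splitlines():
        line = line.lstrip('# ')
        if '=' in line:
            key, _, val = line.partition('=')
            config[key.strip()] = val.strip()
    return config

# e.g. parse_comments(fit.comments)['seed'] would give '1389169387'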
In [10]:
!date
Fri Mar  8 12:13:58 PST 2013

Still to do:

  • Use the iter and chains options
  • Call the C++ code directly, instead of going through temp files and the command-line interface
  • Reproduce the rest of the Getting Started guide

Patches welcome!