import numpy as np, pymc as pm, pandas as pd
from scipy import stats
# Simulate probit-regression data: 9 evenly spaced covariate values in
# [-3, 3], each repeated n_obs times, with P(y = 1 | x) = Phi(alpha + beta * x)
# where Phi is the standard normal CDF.
true_alpha = -0.1
true_beta = 0.8
n_obs = 100
# Fixed seed so the simulated data -- and therefore the estimates that are
# compared against the true parameter values -- are reproducible between runs.
seed = 42
x = np.repeat(np.linspace(start=-3, stop=3, num=9), n_obs)
Normal = stats.norm
p = Normal.cdf(true_alpha + true_beta * x)
Bernoulli = stats.bernoulli
y = Bernoulli.rvs(p, random_state=seed)
with pm.Model() as model:
    # Priors: vague normals, parameterised by precision (tau=0.001 -> sd ~ 31.6).
    alpha = pm.Normal('alpha', mu=0, tau=0.001)
    beta = pm.Normal('beta', mu=0, tau=0.001)

    # Linear predictor on the probit scale.
    theta_p = alpha + beta * x

    # A logit link is the alternative: theta = exp(theta_p) / (1 + exp(theta_p)).

    def phi(z):
        """Return the standard normal CDF, Phi(z) = 0.5 * (1 + erf(z / sqrt(2))).

        z is the symbolic linear predictor; the result is a tensor of
        success probabilities.
        """
        # NOTE(review): theano is the PyMC3-era backend; newer PyMC releases
        # use pytensor instead -- adjust this import to the installed version.
        import theano.tensor as t
        # Divide by sqrt(2). The previous code divided by t.sqr(2) == 4, which
        # inflated both coefficient estimates by 4 / sqrt(2) ~= 2.83 (MAP gave
        # beta ~= 2.3 for a true value of 0.8).
        return 0.5 * (1 + t.erf(z / np.sqrt(2.0)))

    theta = phi(theta_p)

    # Likelihood. Bind the random variable to its own name so the simulated
    # data array `y` is not shadowed; a later cell pushes `y` to R as data.
    y_obs = pm.Bernoulli('y', p=theta, observed=y)
with model:
    # Optimise the log posterior for a point estimate; it also serves as the
    # starting value for the sampler.
    start = pm.find_MAP()

# Report the MAP estimate of each coefficient next to the simulation truth.
print("MAP found:")
for _param in ("alpha", "beta"):
    print(_param + ":", start[_param])
print("Compare with true values:")
print("true_alpha", true_alpha)
print("true_beta", true_beta)
MAP found: ('alpha:', array(-0.2625527488896039)) ('beta:', array(2.316800630909899)) Compare with true values: ('true_alpha', -0.1) ('true_beta', 0.8)
with model:
    # NUTS sampler, initialised at the MAP estimate found above.
    step = pm.NUTS()
    # Pass `step` by keyword rather than positionally. Keyword form works
    # across PyMC versions. NOTE(review): recent PyMC releases make the
    # arguments after `draws` keyword-only and rename `start` to `initvals`;
    # confirm against the installed version.
    trace = pm.sample(2000,
                      step=step,
                      start=start,
                      progressbar=True)  # draw posterior samples
[-----------------100%-----------------] 2000 of 2000 complete in 2.5 sec
# Tabulate per-parameter posterior summaries (mean, sd, intervals, quantiles).
pm.summary(trace)
alpha: Mean SD MC Error 95% HPD interval ------------------------------------------------------------------- -0.270 0.163 0.004 [-0.605, 0.044] Posterior quantiles: 2.5 25 50 75 97.5 |--------------|==============|==============|--------------| -0.591 -0.374 -0.271 -0.163 0.061 beta: Mean SD MC Error 95% HPD interval ------------------------------------------------------------------- 2.333 0.131 0.004 [2.076, 2.570] Posterior quantiles: 2.5 25 50 75 97.5 |--------------|==============|==============|--------------| 2.095 2.243 2.327 2.418 2.593
%load_ext rmagic
from __future__ import division
The rmagic extension is already loaded. To reload it, use: %reload_ext rmagic
## push x and y from the Python namespace into the R session for the JAGS fit
# NOTE(review): `y` must refer to the simulated 0/1 data array here, not a
# PyMC random variable -- confirm no earlier cell has rebound the name.
%Rpush x y
%%R
library(R2jags)

# Probit regression in the JAGS language, fitted for comparison with PyMC:
# y[i] ~ Bernoulli(theta[i]) with probit(theta[i]) = a + b * x[i].
# JAGS `dnorm` takes a precision, so dnorm(0, 1e-4) is a normal prior with sd 100.
model <- "
model {
# likelihood
for (i in 1:length(y)) {
probit(theta[i]) <- a + b * x[i]
y[i] ~ dbern(theta[i])
}
# priors
a ~ dnorm(0, 1e-4)
b ~ dnorm(0, 1e-4)
}"

# x and y were pushed from Python with %Rpush.
data <- list(x = x, y = y)
parameters <- c("a", "b")
# NOTE(review): with a single chain, R2jags cannot compute the Rhat
# convergence diagnostic; consider n.chains >= 3.
samples <- jags(data = data,
                # inits = inits,
                parameters.to.save = parameters,
                model.file = textConnection(model),
                n.chains = 1,
                n.iter = 1000,
                n.burnin = 200,
                n.thin = 1,
                DIC = TRUE)  # TRUE, not T: `T` is an ordinary, reassignable variable
Loading required package: rjags Loading required package: coda Loading required package: lattice Linked to JAGS 3.4.0 Loaded modules: basemod,bugs Attaching package: ‘R2jags’ The following object is masked from ‘package:coda’: traceplot module glm loaded Compiling model graph Resolving undeclared variables Allocating nodes Graph Size: 1831 Initializing model | | | 0% | |* | 2% | |** | 5% | |**** | 8% | |***** | 10% | |****** | 12% | |******** | 15% | |********* | 18% | |********** | 20% | |*********** | 22% | |************ | 25% | |************** | 28% | |*************** | 30% | |**************** | 32% | |****************** | 35% | |******************* | 38% | |******************** | 40% | |********************* | 42% | |********************** | 45% | |************************ | 48% | |************************* | 50% | |************************** | 52% | |**************************** | 55% | |***************************** | 58% | |****************************** | 60% | |******************************* | 62% | |******************************** | 65% | |********************************** | 68% | |*********************************** | 70% | |************************************ | 72% | |************************************** | 75% | |*************************************** | 78% | |**************************************** | 80% | |***************************************** | 82% | |****************************************** | 85% | |******************************************** | 88% | |********************************************* | 90% | |********************************************** | 92% | |************************************************ | 95% | |************************************************* | 98% | |**************************************************| 100%
%%R
# Autoprint the R2jags fit object to show posterior summaries of the saved parameters.
samples