確率論的プログラミングにおいて,プログラマがやることは,データを用意し,事前分布と尤度をモデリングすることだけである.
あとは計算機がMCMCサンプリングをして,パラメータの推定値を返す.
結局,事前分布や尤度を上手にモデリングするには,個々の確率分布の性質と,確率分布どうしの関係をよく知っていなければならない.
"Univariate Distribution Relationships" 2008amstat.pdf
# coding: utf-8
# Notebook-wide setup: imports, IPython display magics, plot styling and RNG seed.
from __future__ import division
import os
import sys
import glob
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# IPython magics (not plain Python): render figures inline in the notebook
# and display floats with 4 decimal places.
%matplotlib inline
%precision 4
#plt.style.use('ggplot')
import seaborn as sns
# Use seaborn's plain white background with the 'paper' scaling context.
sns.set_style('white')
sns.set_context('paper')
# Seed NumPy's global random number generator with a fixed value.
np.random.seed(1234)
import pymc3 as pm
import scipy.stats as stats
Using gpu device 0: GeForce GTX TITAN (CNMeM is disabled)
# Observed data: h heads out of n coin tosses.
n = 100
h = 61
# Hyperparameters of the Beta(alpha, beta) prior on the head probability p.
alpha = 2
beta = 2
# Number of MCMC draws requested from the sampler.
niter = 1000
# The random variables, the MAP search and the sampler are all created inside
# the `with` block so that they attach to `model` (context management).
with pm.Model() as model:
    # Prior: p ~ Beta(alpha, beta).
    p = pm.Beta('p', alpha=alpha, beta=beta)
    # Likelihood: the observed head count h ~ Binomial(n, p).
    y = pm.Binomial('y', n=n, p=p, observed=h)
    # -- Parameter estimation --
    # Use the MAP estimate as the initial point; MAP estimation and MCMC
    # work well together.
    start = pm.find_MAP()
    # Choose the MCMC transition step: NUTS, scaled with the MAP point.
    step = pm.NUTS(scaling=start)
    # `step` and `start` are passed by keyword so the call does not depend
    # on the positional order of pm.sample's optional parameters.
    trace = pm.sample(niter, step=step, start=start,
                      random_seed=123, progressbar=True)
# Plot the sampled trace; this call only takes the finished `trace` as input.
pm.traceplot(trace)
[-----------------100%-----------------] 1000 of 1000 complete in 0.2 sec
# Compare the sampled values of p ('post') with the Beta(alpha, beta) prior.
x = np.linspace(0, 1, 100)  # evaluation grid on [0, 1] for the prior pdf
# The histogram is drawn before the prior curve, as in the original ordering.
# NOTE: `normed=True` is the histogram-normalisation keyword of the Matplotlib
# release this notebook was written for; newer releases call it `density`.
plt.hist(trace['p'], bins=15, histtype='step', normed=True, label='post');
plt.plot(x, stats.beta.pdf(x, a=alpha, b=beta), label='prior');
plt.legend(loc='best');
import pystan
# PyStan does not accept Japanese text inside Stan code, because the code is
# forcibly converted to ASCII right before it is handed to Stan.
# Hence every comment inside the Stan program below is written in English.
coin_code = """
// The order of following declarations matters.
//
// data declaration here
data {
int<lower=0> n; // number of tosses
int<lower=0> y; // number of heads
}
// in case you want to tweak data...
transformed data {}
// parameters declaration here
parameters {
real<lower=0, upper=1> p;
}
// in case you want to tweak parameters...
transformed parameters {}
// your model here
model {
p ~ beta(2, 2);
y ~ binomial(n, p);
}
// tweaking samples from posterior
generated quantities {}
"""
# Data passed to Stan: 'n' tosses and 'y' heads, the two names declared in
# the Stan program's `data` block.
coin_dat = {'n': 100, 'y': 61}
# Estimate the parameter with Stan; NUTS is used by default.
# Alternative: keep the Stan program in a file and pass
# file='coin_code.stan' in place of model_code=coin_code.
fit = pystan.stan(
    model_code=coin_code,
    data=coin_dat,
    iter=1000,
    chains=1,
)
# Print the fit's summary.
print(fit)
Inference for Stan model: anon_model_ccec8cd16799dee6754657976ed54632. 1 chains, each with iter=1000; warmup=500; thin=1; post-warmup draws per chain=500, total post-warmup draws=500. mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat p 0.61 4.1e-3 0.04 0.53 0.57 0.61 0.64 0.69 116 1.0 lp__ -70.17 0.05 0.53 -71.72 -70.33 -69.96 -69.79 -69.74 105 1.0 Samples were drawn using NUTS(diag_e) at Mon Oct 12 17:12:11 2015. For each parameter, n_eff is a crude measure of effective sample size, and Rhat is the potential scale reduction factor on split chains (at convergence, Rhat=1).
# Plot the Stan fit for parameter 'p' (presumably its trace and density --
# confirm against the PyStan version in use); the trailing ';' suppresses the
# notebook's echo of the return value.
fit.plot('p');
# Tighten subplot spacing so labels do not overlap.
plt.tight_layout()
# Extract the posterior draws from the fit; the result is keyed by 'p' and
# by 'lp__'.
coin_dict = fit.extract()
# Bare expression: in a notebook this displays the available keys.
coin_dict.keys()
KeysView(OrderedDict([('p', array([ 0.6733, 0.6355, 0.6015, 0.4983, 0.5673, 0.581 , 0.6186, 0.6396, 0.6825, 0.6358, 0.6275, 0.6613, 0.6212, 0.6565, 0.6061, 0.6174, 0.5713, 0.6161, 0.6866, 0.7314, 0.6033, 0.6543, 0.6247, 0.576 , 0.5743, 0.6277, 0.6414, 0.6314, 0.5536, 0.6202, 0.608 , 0.6617, 0.633 , 0.5872, 0.5726, 0.5913, 0.5961, 0.6331, 0.6072, 0.6093, 0.6239, 0.6514, 0.6752, 0.5669, 0.5842, 0.5564, 0.5307, 0.6246, 0.6357, 0.6155, 0.6236, 0.603 , 0.5403, 0.6051, 0.6551, 0.6276, 0.5874, 0.6055, 0.6326, 0.5996, 0.6868, 0.5455, 0.5904, 0.6164, 0.6165, 0.6526, 0.6199, 0.6443, 0.5072, 0.6581, 0.5424, 0.5549, 0.5856, 0.6371, 0.5668, 0.6019, 0.6197, 0.4842, 0.577 , 0.6113, 0.7361, 0.6323, 0.528 , 0.6502, 0.5851, 0.6437, 0.6635, 0.5715, 0.5957, 0.6777, 0.6018, 0.628 , 0.6254, 0.5756, 0.6012, 0.5754, 0.6756, 0.5872, 0.5943, 0.6081, 0.6198, 0.6052, 0.5814, 0.6074, 0.7017, 0.5837, 0.6362, 0.712 , 0.5936, 0.6321, 0.7061, 0.6251, 0.5931, 0.6175, 0.6262, 0.6022, 0.5864, 0.6215, 0.6597, 0.6516, 0.5504, 0.6766, 0.5745, 0.6786, 0.6926, 0.638 , 0.6484, 0.602 , 0.5442, 0.577 , 0.5889, 0.5606, 0.6367, 0.6132, 0.6192, 0.6762, 0.6062, 0.6578, 0.6104, 0.6071, 0.5202, 0.5723, 0.5601, 0.6476, 0.6134, 0.6396, 0.5965, 0.6093, 0.5613, 0.6473, 0.6157, 0.647 , 0.5657, 0.6446, 0.606 , 0.7116, 0.5883, 0.6395, 0.6008, 0.5829, 0.5984, 0.6608, 0.5812, 0.599 , 0.5275, 0.6656, 0.5899, 0.6178, 0.5824, 0.6299, 0.4969, 0.6083, 0.6306, 0.6855, 0.6705, 0.5874, 0.5912, 0.6409, 0.6081, 0.66 , 0.6756, 0.5997, 0.6549, 0.5908, 0.5435, 0.6239, 0.6668, 0.6753, 0.5489, 0.6054, 0.6144, 0.6024, 0.5342, 0.5892, 0.5928, 0.6136, 0.5981, 0.6006, 0.6809, 0.6061, 0.5472, 0.6041, 0.6454, 0.6501, 0.5508, 0.5883, 0.5938, 0.639 , 0.6237, 0.6218, 0.6309, 0.6245, 0.5978, 0.6024, 0.6016, 0.5777, 0.5585, 0.5273, 0.6362, 0.578 , 0.5412, 0.599 , 0.5404, 0.5705, 0.5647, 0.5794, 0.6855, 0.6468, 0.6853, 0.5514, 0.6554, 0.552 , 0.6528, 0.5542, 0.5959, 0.5968, 0.6037, 0.5846, 0.6402, 0.5744, 0.6163, 0.6502, 0.5558, 0.5557, 0.5021, 
0.6763, 0.6413, 0.692 , 0.5818, 0.5856, 0.6171, 0.651 , 0.7221, 0.7064, 0.5937, 0.5941, 0.6324, 0.6349, 0.5426, 0.6808, 0.5759, 0.6081, 0.6034, 0.6418, 0.6004, 0.643 , 0.593 , 0.5472, 0.6032, 0.5949, 0.6231, 0.6635, 0.7009, 0.6204, 0.6741, 0.6131, 0.6614, 0.6181, 0.5501, 0.6108, 0.6061, 0.581 , 0.613 , 0.5625, 0.6262, 0.6073, 0.5909, 0.5828, 0.5962, 0.6514, 0.6061, 0.583 , 0.5462, 0.6014, 0.6021, 0.6672, 0.5225, 0.5594, 0.6174, 0.5425, 0.6743, 0.6148, 0.6412, 0.6278, 0.58 , 0.6239, 0.608 , 0.5817, 0.7235, 0.6076, 0.6183, 0.6452, 0.5761, 0.5971, 0.5411, 0.6543, 0.5344, 0.6295, 0.6125, 0.6357, 0.6193, 0.6387, 0.5562, 0.7402, 0.6455, 0.6896, 0.598 , 0.6711, 0.5922, 0.6197, 0.6503, 0.5655, 0.6216, 0.591 , 0.6764, 0.603 , 0.579 , 0.5738, 0.6023, 0.712 , 0.5279, 0.5866, 0.5754, 0.5954, 0.647 , 0.5167, 0.681 , 0.5586, 0.6211, 0.6165, 0.5943, 0.61 , 0.6452, 0.6192, 0.584 , 0.6318, 0.6555, 0.676 , 0.5282, 0.5343, 0.5885, 0.6253, 0.6492, 0.653 , 0.6483, 0.5694, 0.5885, 0.6787, 0.6277, 0.6815, 0.6292, 0.5981, 0.6326, 0.5869, 0.578 , 0.5895, 0.7 , 0.6182, 0.6142, 0.613 , 0.6666, 0.6081, 0.6612, 0.5715, 0.5114, 0.5719, 0.608 , 0.5549, 0.5942, 0.6575, 0.6021, 0.6335, 0.5928, 0.6021, 0.6411, 0.6108, 0.5928, 0.6671, 0.5335, 0.5913, 0.6565, 0.6549, 0.5651, 0.6018, 0.63 , 0.6803, 0.5744, 0.5903, 0.5977, 0.5455, 0.6056, 0.6287, 0.614 , 0.5701, 0.5254, 0.5969, 0.5634, 0.6528, 0.5792, 0.592 , 0.6141, 0.5814, 0.5424, 0.6361, 0.5647, 0.6239, 0.5764, 0.5557, 0.7218, 0.6072, 0.5908, 0.6888, 0.65 , 0.6988, 0.5934, 0.6599, 0.6598, 0.6035, 0.5668, 0.6188, 0.6287, 0.5087, 0.5249, 0.5812, 0.6109, 0.6749, 0.5981, 0.6292, 0.5349, 0.5068, 0.6693, 0.6028, 0.6305, 0.6125, 0.5385, 0.5721, 0.5963, 0.536 , 0.5583, 0.5878, 0.5744, 0.6339, 0.6093, 0.6944, 0.6525, 0.5418, 0.6014, 0.6693, 0.6391, 0.6374, 0.6629, 0.5899, 0.7328, 0.6087, 0.6081, 0.6776, 0.6594, 0.6606, 0.5498, 0.6557, 0.6002, 0.5704, 0.6054, 0.6385, 0.6411, 0.5432, 0.5717, 0.6442, 0.5683, 0.6416, 0.6206, 0.6337, 0.601 , 0.6401, 0.5829, 
0.6475, 0.6316, 0.577 , 0.6718, 0.6759])), ('lp__', array([-70.789 , -69.9386, -69.7466, -72.1649, -70.0587, -69.8743, -69.7789, -69.9971, -71.1022, -69.9437, -69.8466, -70.4429, -69.7951, -70.3229, -69.7427, -69.7722, -69.9962, -69.7662, -71.2613, -73.605 , -69.744 , -70.2731, -69.8221, -69.9328, -69.955 , -69.8491, -70.0262, -69.8884, -70.3198, -69.7882, -69.7438, -70.4517, -69.9069, -69.8173, -69.9779, -69.788 , -69.763 , -69.9078, -69.7432, -69.7455, -69.8153, -70.2111, -70.8482, -70.0657, -69.8431, -70.2617, -70.93 , -69.8207, -69.9414, -69.7635, -69.8131, -69.7444, -70.6501, -69.7428, -70.291 , -69.8478, -69.8153, -69.7427, -69.9023, -69.7509, -71.2686, -70.5123, -69.7936, -69.7673, -69.7681, -70.2361, -69.7865, -70.0746, -71.7799, -70.363 , -70.5927, -70.2927, -69.83 , -69.9612, -70.0673, -69.746 , -69.7854, -72.8334, -69.9196, -69.7495, -73.9251, -69.8992, -71.0178, -70.1862, -69.8351, -70.0638, -70.4994, -69.9947, -69.7648, -70.9342, -69.7461, -69.8514, -69.8278, -69.9384, -69.7473, -69.9398, -70.8612, -69.817 , -69.7713, -69.7439, -69.7858, -69.7428, -69.8708, -69.7433, -71.9126, -69.8476, -69.9486, -72.4338, -69.7745, -69.8968, -72.1276, -69.825 , -69.7776, -69.7729, -69.8346, -69.7455, -69.8237, -69.7973, -70.4002, -70.2159, -70.393 , -70.8962, -69.9523, -70.9661, -71.5054, -69.9744, -70.1497, -69.7458, -70.5442, -69.9198, -69.8041, -70.1781, -69.9551, -69.7548, -69.7823, -70.8815, -69.7428, -70.3552, -69.7474, -69.7431, -71.2844, -69.9823, -70.187 , -70.1341, -69.7555, -69.9982, -69.7613, -69.7454, -70.1646, -70.1297, -69.7645, -70.1229, -70.0858, -70.0796, -69.7427, -72.4172, -69.8083, -69.9969, -69.7481, -69.8548, -69.7546, -70.4297, -69.8727, -69.7525, -71.0324, -70.557 , -69.7973, -69.7744, -69.8597, -69.8713, -72.2266, -69.7441, -69.8793, -71.217 , -70.7002, -69.8158, -69.7886, -70.0186, -69.7439, -70.4088, -70.8629, -69.7508, -70.286 , -69.7909, -70.5635, -69.815 , -70.5908, -70.8522, -70.4283, -69.7427, -69.7589, -69.7451, -70.8238, -69.8017, 
-69.7791, -69.7562, -69.7554, -69.7486, -71.0467, -69.7427, -70.4697, -69.7433, -70.0943, -70.1838, -70.3848, -69.8083, -69.7735, -69.9882, -69.8138, -69.7991, -69.8819, -69.8203, -69.7566, -69.7452, -69.7465, -69.9117, -70.2175, -71.0417, -69.9481, -69.9082, -70.6236, -69.7527, -70.6455, -70.0087, -70.1031, -69.8916, -71.2169, -70.12 , -71.2099, -70.3694, -70.2989, -70.3559, -70.2402, -70.3068, -69.7637, -69.76 , -69.7436, -69.8396, -70.0065, -69.9531, -69.7669, -70.1867, -70.2735, -70.2748, -71.9951, -70.8867, -70.0251, -71.4806, -69.8666, -69.8306, -69.771 , -70.2027, -73.0129, -72.1445, -69.7742, -69.7722, -69.9002, -69.9314, -70.5869, -71.0408, -69.9335, -69.7439, -69.7439, -70.0319, -69.7489, -70.0529, -69.7779, -70.4709, -69.7442, -69.7684, -69.8092, -70.4994, -71.8769, -69.7899, -70.8125, -69.7544, -70.4446, -69.776 , -70.3997, -69.7483, -69.7427, -69.8746, -69.7542, -70.1417, -69.835 , -69.7432, -69.7904, -69.8561, -69.7625, -70.2111, -69.7427, -69.8547, -70.4953, -69.7468, -69.7456, -70.602 , -71.201 , -70.2001, -69.7722, -70.5908, -70.8191, -69.7604, -70.023 , -69.8494, -69.8858, -69.815 , -69.7438, -69.8668, -73.0971, -69.7434, -69.7769, -70.0904, -69.9316, -69.7588, -70.6268, -70.274 , -70.8168, -69.8673, -69.7526, -69.9413, -69.7829, -69.9841, -70.265 , -74.2161, -70.0952, -71.3782, -69.7559, -70.7198, -69.7826, -69.7854, -70.188 , -70.0889, -69.7979, -69.7897, -70.8883, -69.7443, -69.8964, -69.9619, -69.7453, -72.437 , -71.0211, -69.8216, -69.9405, -69.766 , -70.123 , -71.4121, -71.0496, -70.2155, -69.7946, -69.7679, -69.771 , -69.7466, -70.0909, -69.7822, -69.8447, -69.8931, -70.3016, -70.8758, -71.0105, -70.82 , -69.8072, -69.8266, -70.1666, -70.2459, -70.1479, -70.026 , -69.8068, -70.9688, -69.8491, -71.0691, -69.8645, -69.7555, -69.9021, -69.8195, -69.9085, -69.7996, -71.8357, -69.7764, -69.7583, -69.754 , -70.5868, -69.7439, -70.4382, -69.9947, -71.6133, -69.9887, -69.7438, -70.2936, -69.7715, -70.3478, -69.7456, -69.9129, -69.7789, -69.7456, 
-70.0218, -69.7482, -69.7789, -70.5987, -70.845 , -69.7879, -70.3232, -70.2862, -70.0958, -69.7461, -69.8724, -71.0242, -69.954 , -69.7944, -69.757 , -70.5123, -69.7427, -69.8584, -69.7576, -70.014 , -71.1048, -69.7598, -70.1258, -70.2406, -69.8941, -69.7835, -69.758 , -69.8708, -70.5917, -69.9473, -70.1031, -69.8153, -69.9279, -70.2766, -72.9958, -69.7432, -69.7909, -71.348 , -70.1812, -71.7785, -69.7757, -70.4074, -70.4036, -69.7438, -70.0673, -69.7801, -69.8594, -71.7188, -71.1211, -69.8727, -69.7484, -70.841 , -69.7554, -69.8644, -70.8029, -71.7973, -70.6646, -69.7446, -69.8776, -69.7527, -70.6994, -69.9858, -69.7621, -70.7702, -70.2216, -69.8126, -69.9538, -69.9187, -69.7454, -71.5826, -70.234 , -70.6075, -69.7469, -70.6656, -69.9907, -69.9651, -70.4844, -69.7971, -73.6988, -69.7445, -69.7439, -70.9306, -70.3927, -70.4234, -70.4077, -70.3063, -69.7494, -70.0104, -69.7427, -69.9808, -70.0218, -70.5706, -69.9905, -70.0725, -70.0425, -70.0287, -69.7913, -69.9155, -69.7477, -70.0051, -69.8548, -70.1338, -69.8909, -69.92 , -70.7394, -70.8725]))]))