%matplotlib inline import numpy as np import matplotlib.pyplot as plt import emcee import triangle import random def make_example(n, beta, sigma, sigma0): x = np.random.random(size = n) * 50 y = beta * x + np.random.normal(scale = sigma, size = n) indices = random.sample(range(n), n//5) y[indices] = np.random.normal(scale = sigma0, size = n//5) return x, y x, y = make_example(50, 0.5, 1, 10) _ = plt.scatter(x,y) def log_likelihood(theta, x, y): beta, sigma, sigma0 = theta[:3] o = theta[3:] linear_part = - np.log(sigma * sigma) / 2 - (y - beta * x)**2 / (2 * sigma * sigma) outlier_part = - np.log(sigma0 * sigma0) / 2 - y**2 / (2 * sigma0 * sigma0) return np.sum( o * outlier_part + (1-o) * linear_part) def log_prior(theta): beta, sigma, sigma0 = theta[:3] return - np.log(abs(sigma)) - np.log(abs(sigma0)) def log_posterior(theta, x, y): return log_likelihood(theta, x, y) + log_prior(theta) def jump_index(index, theta, x, y, old_log_prob): theta_new = np.array(theta, dtype=np.float64) if index < 3: theta_new[index] += np.random.normal(0, 0.2) else: if np.random.random() >= 0.5: theta_new[index] = 1 - theta_new[index] log_prob = log_posterior(theta_new, x, y) alpha = log_prob - old_log_prob if alpha >= 0 or alpha >= np.log(np.random.random()): return theta_new, log_prob return theta, old_log_prob def jump(theta, x, y, old_log_prob): index = random.randrange(len(theta)) return jump_index(index, theta, x, y, old_log_prob) def walk(theta, x, y, length = 1000): old_log_prob = log_posterior(theta,x,y) trace = [] for _ in range(length): theta, old_log_prob = jump(theta, x, y, old_log_prob) trace.append(theta) return np.array(trace) theta = np.array([0,1,1] + [1]*len(x)) trace = walk(theta, x, y, 1000000) trace # Marginalise away the outliers and discard burnin and sample every 100 terms to try to reduce correlations param_trace = trace[10000::100, :3] figure = triangle.corner(param_trace, labels = ["$\\beta$", "$\sigma$", "$\sigma_0$"], truths=[0.5, 1, 10]) fig, axes = 
fig, axes = plt.subplots(figsize=(8, 6))
axes.plot(x, y, "ok")

# Posterior-mean fit line with a one-standard-deviation band.
xfit = np.linspace(0, 50, num=50)
beta = param_trace[:, 0]
yfit = (xfit[:, None] * beta).mean(1)
axes.plot(xfit, yfit)
s = (xfit[:, None] * beta).std(1)
axes.fill_between(xfit, yfit - s, yfit + s, color="lightgray")

# Posterior probability that each point is an outlier, sorted ascending.
outliers = [(index, trace[:, index + 3].mean()) for index in range(len(x))]
outliers.sort(key=lambda pair: pair[1])
outliers
outlier_indices = [i for (i, t) in outliers if t > 0.9]

fig, axes = plt.subplots(figsize=(8, 6))
axes.plot(x, y, "ok", color="grey")
axes.plot(x[outlier_indices], y[outlier_indices], "ok", color="red")


def gibbs(theta, x, y):
    """One sweep of the Metropolis-within-Gibbs sampler.

    beta, sigma and sigma0 each get a single Metropolis update; the
    outlier indicators are then resampled exactly from their conditional
    Bernoulli distributions given the updated parameters.  Returns the
    new state vector.
    """
    # Update beta, sigma, sigma0 with single-coordinate Metropolis steps.
    old_log_prob = log_posterior(theta, x, y)
    theta, old_log_prob = jump_index(0, theta, x, y, old_log_prob)
    theta, old_log_prob = jump_index(1, theta, x, y, old_log_prob)
    theta, old_log_prob = jump_index(2, theta, x, y, old_log_prob)
    # Resample the outlier indicators from their exact conditionals.
    beta, sigma, sigma0 = theta[:3]
    o = theta[3:]
    linear_part = -np.log(sigma * sigma) / 2 - (y - beta * x) ** 2 / (2 * sigma * sigma)
    outlier_part = -np.log(sigma0 * sigma0) / 2 - y ** 2 / (2 * sigma0 * sigma0)
    # prob = P(o_i = 1).  This logistic form is the numerically safer
    # rewriting of exp(outlier) / (exp(linear) + exp(outlier)).
    prob = 1.0 / (1.0 + np.exp(linear_part - outlier_part))
    o = np.array(np.random.random(size=len(prob)) <= prob, dtype=np.float64)
    return np.hstack([np.array([beta, sigma, sigma0]), o])


def walk(theta, length=10000):
    """Run ``length`` Gibbs sweeps (against the module-level x, y) and
    return the sampled states as a (length, dim) array."""
    trace = []
    for _ in range(length):
        theta = gibbs(theta, x, y)
        trace.append(theta)
    return np.array(trace)


theta = np.array([0, 1, 1] + [1] * len(x))
trace = walk(theta, 20000)

# Discard burn-in only; no thinning this time.
param_trace = trace[2000::, :3]
# Raw strings avoid the invalid "\s" escape in the TeX labels.
figure = triangle.corner(param_trace,
                         labels=[r"$\beta$", r"$\sigma$", r"$\sigma_0$"],
                         truths=[0.5, 1, 10])

outliers = [(index, trace[:, index + 3].mean()) for index in range(len(x))]
outliers.sort(key=lambda pair: pair[1])
outlier_indices = [i for (i, t) in outliers if t > 0.9]

fig, axes = plt.subplots(figsize=(8, 6))
axes.plot(x, y, "ok", color="grey")
axes.plot(x[outlier_indices], y[outlier_indices], "ok", color="red")

# Overlay the posterior-mean fit line.  Discard the same 2000-sample
# burn-in used for param_trace so the slope estimate is not biased by
# early, unconverged samples.
xfit = np.linspace(0, 50, 50)
beta = trace[2000:, 0]
axes.plot(xfit, xfit * beta.mean())