I was inspired by this blog post about James Pearson to learn a bit about his method of solving mixture models. I also figured it would be a good way to learn sympy a little better.
The basic gist of this notebook is that the moments of a mixture of Gaussians are nice polynomials in the means, the standard deviations, and the mixing weight. So if you calculate the moments of your data, you can set those polynomials equal to the measured values and solve for the parameters of the mixture model. The catch is that the numerical solver only converges if the initial guess is somewhat near the real solution. Also, I don't know much about stats. Feel free to educate me.
I'm going to start writing more algorithmshop.com stuff just in IPython notebooks. This might make its way into a post someday, but I would rather just put things out there and talk to folks. I'm justinvf on Twitter. If you live in the bay (or the internet) and want to talk math, please reach out! I think I made my site nicer than I want to commit to regularly. I would rather be learning stuff than making my site pretty.
from sympy.stats import Normal, moment
from sympy import Symbol, simplify, init_printing, symbols, pretty_print
from sympy.solvers import nsolve
init_printing()
import numpy as np
from numpy.linalg import norm
np.random.seed(43)
from matplotlib import pyplot as plt
%pylab inline
Populating the interactive namespace from numpy and matplotlib
# The moments to use in the calculations
MOMENTS_USED = [1,2,3,4,5]
# The parameters of the 2 gaussians, and the mixture param (mu_1, sigma_1, mu_2, sigma_2, r)
REAL_PARAMS = (0, 2, 10, 3, .7)
NUM_POINTS = 5000
mu = Symbol('mu')
sigma = Symbol('sigma', positive=True)
X = Normal('X', mu, sigma)
moments = [simplify(moment(X,i)) for i in MOMENTS_USED]
for (i, m) in zip(MOMENTS_USED, moments):
    print('\nMoment {}'.format(i))
    pretty_print(m)
Moment 1
μ

Moment 2
μ² + σ²

Moment 3
μ⋅(μ² + 3⋅σ²)

Moment 4
μ⁴ + 6⋅μ²⋅σ² + 3⋅σ⁴

Moment 5
μ⋅(μ⁴ + 10⋅μ²⋅σ² + 15⋅σ⁴)
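As a cross-check on the polynomials above, here is a minimal sketch that derives the same raw moments a different way: by differentiating the moment generating function of a normal distribution, exp(μt + σ²t²/2), and evaluating at t = 0. The names `mgf` and `mgf_moment` are just for this illustration and are not part of the notebook.

```python
from sympy import symbols, exp, diff, expand

mu, t = symbols('mu t', real=True)
sigma = symbols('sigma', positive=True)

# Moment generating function of Normal(mu, sigma)
mgf = exp(mu * t + sigma**2 * t**2 / 2)

def mgf_moment(k):
    """k-th raw moment: k-th derivative of the MGF at t = 0."""
    return expand(diff(mgf, t, k).subs(t, 0))

mgf_moments = {k: mgf_moment(k) for k in range(1, 6)}
for k, m in mgf_moments.items():
    print('Moment {}: {}'.format(k, m))
```

If the two approaches agree, expanding the factored forms printed by `sympy.stats.moment` should give exactly these polynomials.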
Now let's generate data from a gaussian mixture
def generate_data(mu_1, sigma_1, mu_2, sigma_2, r, n):
    """mu and sigmas are for the two gaussians.
    r is the probability of picking from the first gaussian.
    n is the number of points to generate
    """
    s = np.random.rand(n) < r
    return np.select([s, ~s],
                     [np.random.normal(mu_1, sigma_1, n),
                      np.random.normal(mu_2, sigma_2, n)])
fake_data = generate_data(*REAL_PARAMS, n=NUM_POINTS)
plt.hist(fake_data, bins=50);
If we assume our data is a mixture of two gaussians, N_1 and N_2, with r the mixture parameter, then we know what the moments should look like for the data.
mu_1, sigma_1, mu_2, sigma_2, r = symbols('mu_1 sigma_1 mu_2 sigma_2 r')
def mixture_moment(m):
    return (  r       * (m.subs({mu: mu_1, sigma: sigma_1}))
            + (1 - r) * (m.subs({mu: mu_2, sigma: sigma_2})))
mixture_moments = [mixture_moment(m) for m in moments]
for (i, m) in zip(MOMENTS_USED, mixture_moments):
    print('\nMixture Moment {}'.format(i))
    pretty_print(m)
Mixture Moment 1
μ₁⋅r + μ₂⋅(-r + 1)

Mixture Moment 2
r⋅(μ₁² + σ₁²) + (μ₂² + σ₂²)⋅(-r + 1)

Mixture Moment 3
μ₁⋅r⋅(μ₁² + 3⋅σ₁²) + μ₂⋅(μ₂² + 3⋅σ₂²)⋅(-r + 1)

Mixture Moment 4
r⋅(μ₁⁴ + 6⋅μ₁²⋅σ₁² + 3⋅σ₁⁴) + (-r + 1)⋅(μ₂⁴ + 6⋅μ₂²⋅σ₂² + 3⋅σ₂⁴)

Mixture Moment 5
μ₁⋅r⋅(μ₁⁴ + 10⋅μ₁²⋅σ₁² + 15⋅σ₁⁴) + μ₂⋅(-r + 1)⋅(μ₂⁴ + 10⋅μ₂²⋅σ₂² + 15⋅σ₂⁴)
def numerical_moment(a, n):
    return np.sum(a ** n) / len(a)
actual_moments = [numerical_moment(fake_data, i) for i in MOMENTS_USED]
for (i, m) in zip(MOMENTS_USED, actual_moments):
    print('\nActual Moment {}: {}'.format(i, m))
Actual Moment 1: 3.033243007556774

Actual Moment 2: 35.429816643767715

Actual Moment 3: 378.3815897586502

Actual Moment 4: 4697.59347129467

Actual Moment 5: 60793.208071695706
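For comparison, here is a small sketch that plugs the true parameters (0, 2, 10, 3, .7) into the closed-form moment polynomials, using plain Python arithmetic instead of sympy. The helper names `normal_raw_moments` and `mixture_raw_moments` are made up for this example.

```python
def normal_raw_moments(mu, sigma):
    """First five raw moments of Normal(mu, sigma), from the polynomials above."""
    s2 = sigma ** 2
    return [
        mu,
        mu**2 + s2,
        mu * (mu**2 + 3 * s2),
        mu**4 + 6 * mu**2 * s2 + 3 * s2**2,
        mu * (mu**4 + 10 * mu**2 * s2 + 15 * s2**2),
    ]

def mixture_raw_moments(mu_1, sigma_1, mu_2, sigma_2, r):
    """Weight the two components' moments by r and (1 - r)."""
    first = normal_raw_moments(mu_1, sigma_1)
    second = normal_raw_moments(mu_2, sigma_2)
    return [r * a + (1 - r) * b for a, b in zip(first, second)]

expected_moments = mixture_raw_moments(0, 2, 10, 3, 0.7)
for k, m in enumerate(expected_moments, start=1):
    print('Expected Moment {}: {:.6g}'.format(k, m))
```

Worked by hand, these come to 3, 35.5, 381, 4726.5 and 60645, so the sample moments above are each within about one percent of what the true parameters predict.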
We have 5 moments and 5 unknowns (r and the 2 parameters for each of the 2 normal distributions). So now we should be able to use the numerical moments to solve for the parameters.
# These should all be solved for zero
equations = [m - value for (m,value) in zip(mixture_moments, actual_moments)]
for (i, m) in zip(MOMENTS_USED, equations):
    print('\nEqn for moment {}'.format(i))
    pretty_print(m)
Eqn for moment 1
μ₁⋅r + μ₂⋅(-r + 1) - 3.03324300755677

Eqn for moment 2
r⋅(μ₁² + σ₁²) + (μ₂² + σ₂²)⋅(-r + 1) - 35.4298166437677

Eqn for moment 3
μ₁⋅r⋅(μ₁² + 3⋅σ₁²) + μ₂⋅(μ₂² + 3⋅σ₂²)⋅(-r + 1) - 378.38158975865

Eqn for moment 4
r⋅(μ₁⁴ + 6⋅μ₁²⋅σ₁² + 3⋅σ₁⁴) + (-r + 1)⋅(μ₂⁴ + 6⋅μ₂²⋅σ₂² + 3⋅σ₂⁴) - 4697.59347129467

Eqn for moment 5
μ₁⋅r⋅(μ₁⁴ + 10⋅μ₁²⋅σ₁² + 15⋅σ₁⁴) + μ₂⋅(-r + 1)⋅(μ₂⁴ + 10⋅μ₂²⋅σ₂² + 15⋅σ₂⁴) - 60793.2080716957
print("True Parameters for mu_1, sigma_1, mu_2, sigma_2, r: {}".format(REAL_PARAMS))
True Parameters for mu_1, sigma_1, mu_2, sigma_2, r: (0, 2, 10, 3, 0.7)
def solve_numerically(equations, initial_guess):
    # If I start off near-ish to the point, then we can solve it numerically.
    solved = nsolve(equations, (mu_1, sigma_1, mu_2, sigma_2, r), list(initial_guess))
    print('Numeric solution')
    pretty_print(solved.T)
    # Blarg. I get arbitrary precision stuff out and I don't need that
    to_float_array = lambda mfp_array: np.array(list(map(float, mfp_array)))
    distance_to_real_array = norm(np.array(REAL_PARAMS) - to_float_array(solved.T))
    # Not sure how I should really be talking about distance...
    print('L2 distance from true solution (after sampling {} sample points): {}'.format(
        NUM_POINTS, distance_to_real_array ))
If we just solve numerically starting at the real parameters, then the error is just the sampling error
solve_numerically(equations, np.array(REAL_PARAMS))
Numeric solution
[-0.0964001259983833  1.93947261516827  9.45613788320574  3.31424965505908  0.672375746577542]
L2 distance from true solution (after sampling 5000 sample points): 0.6389510902822756
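To get a feel for how much sampling error sits in each moment, here is a rough sketch that estimates the standard error of each sample moment as the standard deviation of x**k divided by sqrt(n). It draws its own data with `np.random.default_rng(0)` (an arbitrary seed, not the one used in this notebook), so the numbers will not match the cells above exactly.

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed, assumption for this sketch
n = 5000

# Same mixture as REAL_PARAMS: N(0, 2) with probability 0.7, else N(10, 3)
from_first = rng.random(n) < 0.7
data = np.where(from_first, rng.normal(0, 2, n), rng.normal(10, 3, n))

orders = np.arange(1, 6)
powers = data[:, None] ** orders            # column k-1 holds x**k
sample_moments = powers.mean(axis=0)
standard_errors = powers.std(axis=0, ddof=1) / np.sqrt(n)
relative_errors = standard_errors / np.abs(sample_moments)

for k, m, se in zip(orders, sample_moments, standard_errors):
    print('Moment {}: {:.5g} +/- {:.2g}'.format(k, m, se))
```

For this mixture the fifth moment comes out relatively noisier than the second, which is one plausible reason the recovered parameters are only good to a few percent with 5000 points.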
If we move the initial guess slightly we are fine
solve_numerically(equations, np.array(REAL_PARAMS) + np.array([.3, .2, .2, .2, 0]))
Numeric solution
[-0.0964001259983833  1.93947261516827  9.45613788320574  3.31424965505908  0.672375746577542]
L2 distance from true solution (after sampling 5000 sample points): 0.6389510902822756
But then if we get a little farther....
solve_numerically(equations, np.array(REAL_PARAMS) + np.array([2, 1, -1, 1, .2]))
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-15-353227e314b9> in <module>()
----> 1 solve_numerically(equations, np.array(REAL_PARAMS) + np.array([2, 1, -1, 1, .2]))

<ipython-input-12-dd1f346d468d> in solve_numerically(equations, initial_guess)
      1 def solve_numerically(equations, initial_guess):
      2     # If I start off near-ish to the point, then we can solve it numerically.
----> 3     solved = nsolve(equations, (mu_1, sigma_1, mu_2, sigma_2, r), list(initial_guess))
      4     print('Numeric solution')
      5     pretty_print(solved.T)

/home/justinvf/anaconda3/lib/python3.4/site-packages/sympy/solvers/solvers.py in nsolve(*args, **kwargs)
   2473     J = lambdify(fargs, J, modules)
   2474     # solve the system numerically
-> 2475     x = findroot(f, x0, J=J, **kwargs)
   2476     return x
   2477

/home/justinvf/anaconda3/lib/python3.4/site-packages/sympy/mpmath/calculus/optimization.py in findroot(ctx, f, x0, solver, tol, verbose, verify, **kwargs)
    973                              '(%g > %g)\n'
    974                              'Try another starting point or tweak arguments.'
--> 975                              % (norm(f(*xl))**2, tol))
    976         return x
    977     finally:

ValueError: Could not find root within given tolerance. (1278.68 > 2.1684e-19)
Try another starting point or tweak arguments.
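Since `nsolve` raises when it fails to converge, one simple workaround is to try several starting points and keep the first one that works. Below is a minimal, self-contained sketch of that idea. To keep it independent of the sampled data, it builds the equations from the exact moments of the true mixture, and the helper `first_converging_start` plus the list of starts are made up for this illustration. Note that a start that does converge may land on a different root of the polynomial system than the one you want.

```python
from sympy import Rational, symbols
from sympy.solvers import nsolve

mu_1, sigma_1, mu_2, sigma_2, r = symbols('mu_1 sigma_1 mu_2 sigma_2 r')
unknowns = (mu_1, sigma_1, mu_2, sigma_2, r)

def normal_moments(m, s):
    """First five raw moments of a normal, as polynomials in m and s."""
    return [m, m**2 + s**2, m * (m**2 + 3 * s**2),
            m**4 + 6 * m**2 * s**2 + 3 * s**4,
            m * (m**4 + 10 * m**2 * s**2 + 15 * s**4)]

mixture = [r * a + (1 - r) * b
           for a, b in zip(normal_moments(mu_1, sigma_1),
                           normal_moments(mu_2, sigma_2))]

# Exact moments of the true mixture stand in for the sampled ones here.
true_values = {mu_1: 0, sigma_1: 2, mu_2: 10, sigma_2: 3, r: Rational(7, 10)}
demo_equations = [m - m.subs(true_values) for m in mixture]

def first_converging_start(equations, starts):
    """Return (start, root) for the first start where nsolve succeeds."""
    for guess in starts:
        try:
            root = nsolve(equations, unknowns, list(guess))
        except (ValueError, ZeroDivisionError):
            continue  # no convergence or singular Jacobian: try the next one
        return guess, [float(v) for v in root]
    return None, None

starts = [(2, 3, 9, 4, 0.9), (0, 3, 10, 5, 0.7), (0, 2, 10, 3, 0.7)]
used_start, root = first_converging_start(demo_equations, starts)
print('Converged from', used_start)
print('Root:', root)
```

This only papers over the problem. It still needs at least one start in a basin of attraction, and it says nothing about which root is the statistically sensible one (for example, a root with r outside [0, 1] should be rejected).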
If we look at the data, we can come up with reasonable initial guesses for the mus and sigmas. Here is the histogram from above again:
plt.hist(fake_data, bins=50);
The left hump looks a little more massive, so I would guess .7 for $r$. Then $\mu_1$ looks to be around 0, and $\mu_2$ around 10. Then I would guess a larger standard deviation for the second.
solve_numerically(equations, np.array([0, 3, 10, 5, .7]))
Numeric solution
[-0.0964001259983833  1.93947261516827  9.45613788320574  3.31424965505908  0.672375746577542]
L2 distance from true solution (after sampling 5000 sample points): 0.6389510902822756
And that works... But it's not satisfying. I should read more papers about this :).
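One way to make the eyeballed guess less manual is to split the data into a left and a right cluster and read the starting values off each cluster. The sketch below does that with a few rounds of 1-D two-means, which is a swapped-in heuristic and not part of the method of moments. The function name `guess_initial_params` is hypothetical, it assumes the two humps are reasonably well separated, and it generates its own demo data with an arbitrary seed.

```python
import numpy as np

def guess_initial_params(data, iterations=20):
    """Crude (mu_1, sigma_1, mu_2, sigma_2, r) guess from 1-D two-means."""
    centers = np.percentile(data, [25, 75])   # starting centers: the quartiles
    for _ in range(iterations):
        # Assign each point to its nearest center, then recompute the centers
        in_left = np.abs(data - centers[0]) <= np.abs(data - centers[1])
        centers = np.array([data[in_left].mean(), data[~in_left].mean()])
    left, right = data[in_left], data[~in_left]
    return (left.mean(), left.std(), right.mean(), right.std(), in_left.mean())

# Demo data from the same mixture as REAL_PARAMS (seed is arbitrary)
rng = np.random.default_rng(0)
n = 5000
from_first = rng.random(n) < 0.7
demo_data = np.where(from_first, rng.normal(0, 2, n), rng.normal(10, 3, n))

initial_guess = guess_initial_params(demo_data)
print('Initial guess (mu_1, sigma_1, mu_2, sigma_2, r):')
print(np.round(initial_guess, 3))
```

The result could then be passed to `solve_numerically(equations, np.array(initial_guess))` in place of the hand-picked values. Because the hard split cuts off the overlapping tails, the guessed sigmas are likely to be biased low, but a starting point only has to be close enough for the solver to converge.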
My main questions left from this exercise (many are probably obvious; I know like zero stats):