Estimating Mixture Models from Moments

Justin Vincent (justinvf)

An algorithmshop.com production

I was inspired by this blog post about Karl Pearson to learn a bit about his method of solving mixture models. I also figured it would be a good way to learn sympy a little better.

The basic gist of this notebook is that the moments of a mixture of gaussians are nice polynomials in the means, standard deviations, and mixing weight. So if you calculate the moments of your data, you can set those polynomials equal to the measured values and solve for the parameters of the mixture model. The catch is that the numerical solver only converges when the initial guess is somewhat near the real solution. Also, I don't know much about stats. Feel free to educate me.

I'm going to start writing more algorithmshop.com stuff just in IPython notebooks. This might make its way into a post someday, but I would rather just put things out there and talk to folks. I'm justinvf on Twitter. If you live in the bay (or the internet) and want to talk math, please reach out! I think I made my site nicer than I want to commit to regularly. I would rather be learning stuff than making my site pretty.

In [1]:
from sympy.stats import Normal, moment
from sympy import Symbol, simplify, init_printing, symbols, pretty_print
from sympy.solvers import nsolve
init_printing()

import numpy as np
from numpy.linalg import norm
np.random.seed(43)

from matplotlib import pyplot as plt
%pylab inline
Populating the interactive namespace from numpy and matplotlib
In [2]:
# The moments to use in the calculations
MOMENTS_USED = [1,2,3,4,5]

# The parameters of the 2 gaussians, and the mixture param (mu_1, sigma_1, mu_2, sigma_2, r)
REAL_PARAMS = (0, 2, 10, 3, .7)
NUM_POINTS = 5000
In [3]:
mu = Symbol('mu')
sigma = Symbol('sigma', positive=True)
X = Normal('X', mu, sigma)

moments = [simplify(moment(X,i)) for i in MOMENTS_USED]

for (i, m) in zip(MOMENTS_USED, moments):
    print('\nMoment {}'.format(i))
    pretty_print(m)
Moment 1
μ

Moment 2
 2    2
μ  + σ 

Moment 3
  ⎛ 2      2⎞
μ⋅⎝μ  + 3⋅σ ⎠

Moment 4
 4      2  2      4
μ  + 6⋅μ ⋅σ  + 3⋅σ 

Moment 5
  ⎛ 4       2  2       4⎞
μ⋅⎝μ  + 10⋅μ ⋅σ  + 15⋅σ ⎠
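
The polynomials above can be cross-checked without sympy. This is a minimal sketch, assuming the standard recurrence for the raw moments of a normal distribution, $E[X^n] = \mu E[X^{n-1}] + (n-1)\sigma^2 E[X^{n-2}]$. The helper name `normal_raw_moments` and the test values of `mu` and `sigma` are made up for illustration and are not part of the notebook.

```python
def normal_raw_moments(mu, sigma, max_order):
    """Return [E[X^0], ..., E[X^max_order]] for X ~ Normal(mu, sigma)."""
    moments = [1.0, float(mu)]
    for n in range(2, max_order + 1):
        # E[X^n] = mu * E[X^(n-1)] + (n - 1) * sigma^2 * E[X^(n-2)]
        moments.append(mu * moments[n - 1] + (n - 1) * sigma ** 2 * moments[n - 2])
    return moments[: max_order + 1]


mu, sigma = 1.5, 2.0  # arbitrary illustrative values
m = normal_raw_moments(mu, sigma, 5)

# Each pair should agree with the closed forms sympy printed above.
print(m[3], mu * (mu ** 2 + 3 * sigma ** 2))
print(m[5], mu * (mu ** 4 + 10 * mu ** 2 * sigma ** 2 + 15 * sigma ** 4))
```

The recurrence also gives moments beyond the fifth cheaply, which would matter if more unknowns (and so more equations) were needed.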

Now let's generate data from a gaussian mixture

In [4]:
def generate_data(mu_1, sigma_1, mu_2, sigma_2, r, n):
    """mu and sigmas are for the two gaussians.
       r is the probability of picking from the first gaussian.
       n is the number of points to generate
    """
    s = np.random.rand(n) < r
    return np.select([s, ~s],
                [np.random.normal(mu_1, sigma_1, n),
                 np.random.normal(mu_2, sigma_2, n)])
In [5]:
fake_data = generate_data(*REAL_PARAMS, n=NUM_POINTS)
In [6]:
plt.hist(fake_data, bins=50);

If we assume our data is a mixture of two gaussians, N_1 and N_2, with r the mixture parameter, then we know what the moments should look like for the data.
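
Written out, assuming $N_1$ is picked with probability $r$ and $N_2$ with probability $1 - r$ (the convention `generate_data` uses), the $k$-th raw moment of the mixture is the same weighted average of the component moments:

$$E[X^k] = r\,E[N_1^k] + (1 - r)\,E[N_2^k]$$

For $k = 1$ this is $r\mu_1 + (1 - r)\mu_2$, and for $k = 2$ it is $r(\mu_1^2 + \sigma_1^2) + (1 - r)(\mu_2^2 + \sigma_2^2)$.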

In [7]:
mu_1, sigma_1, mu_2, sigma_2, r = symbols('mu_1 sigma_1 mu_2 sigma_2 r')
In [8]:
def mixture_moment(m):
    return (        r * (m.subs({mu:mu_1, sigma: sigma_1}))
            + (1 - r) * (m.subs({mu:mu_2, sigma: sigma_2})))

mixture_moments = [mixture_moment(m) for m in moments]

for (i, m) in zip(MOMENTS_USED, mixture_moments):
    print('\nMixture Moment {}'.format(i))
    pretty_print(m)
Mixture Moment 1
μ₁⋅r + μ₂⋅(-r + 1)

Mixture Moment 2
  ⎛  2     2⎞   ⎛  2     2⎞         
r⋅⎝μ₁  + σ₁ ⎠ + ⎝μ₂  + σ₂ ⎠⋅(-r + 1)

Mixture Moment 3
     ⎛  2       2⎞      ⎛  2       2⎞         
μ₁⋅r⋅⎝μ₁  + 3⋅σ₁ ⎠ + μ₂⋅⎝μ₂  + 3⋅σ₂ ⎠⋅(-r + 1)

Mixture Moment 4
  ⎛  4       2   2       4⎞            ⎛  4       2   2       4⎞
r⋅⎝μ₁  + 6⋅μ₁ ⋅σ₁  + 3⋅σ₁ ⎠ + (-r + 1)⋅⎝μ₂  + 6⋅μ₂ ⋅σ₂  + 3⋅σ₂ ⎠

Mixture Moment 5
     ⎛  4        2   2        4⎞               ⎛  4        2   2        4⎞
μ₁⋅r⋅⎝μ₁  + 10⋅μ₁ ⋅σ₁  + 15⋅σ₁ ⎠ + μ₂⋅(-r + 1)⋅⎝μ₂  + 10⋅μ₂ ⋅σ₂  + 15⋅σ₂ ⎠
In [9]:
def numerical_moment(a, n):
    return np.sum(a ** n) / len(a)

actual_moments = [numerical_moment(fake_data, i) for i in MOMENTS_USED]

for (i, m) in zip(MOMENTS_USED, actual_moments):
    print('\nActual Moment {}: {}'.format(i, m))
Actual Moment 1: 3.033243007556774

Actual Moment 2: 35.429816643767715

Actual Moment 3: 378.3815897586502

Actual Moment 4: 4697.59347129467

Actual Moment 5: 60793.208071695706
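
As a sanity check on these numbers, the true parameters can be plugged into the closed-form moment polynomials. This is a small sketch in plain Python; the function name `mixture_raw_moments` is mine, and the `sampled` list is just the five values printed above copied in by hand.

```python
def mixture_raw_moments(mu_1, sigma_1, mu_2, sigma_2, r):
    """First five raw moments of r*N(mu_1, sigma_1) + (1 - r)*N(mu_2, sigma_2)."""
    def normal_moments(mu, sigma):
        s2 = sigma ** 2
        return [
            mu,
            mu ** 2 + s2,
            mu * (mu ** 2 + 3 * s2),
            mu ** 4 + 6 * mu ** 2 * s2 + 3 * s2 ** 2,
            mu * (mu ** 4 + 10 * mu ** 2 * s2 + 15 * s2 ** 2),
        ]
    first = normal_moments(mu_1, sigma_1)
    second = normal_moments(mu_2, sigma_2)
    return [r * a + (1 - r) * b for a, b in zip(first, second)]


expected = mixture_raw_moments(0, 2, 10, 3, 0.7)
# Sample moments copied from the cell output above.
sampled = [3.033243007556774, 35.429816643767715, 378.3815897586502,
           4697.59347129467, 60793.208071695706]
for k, (e, s) in enumerate(zip(expected, sampled), start=1):
    print('Moment {}: expected {:.1f}, sampled {:.1f}, relative gap {:+.2%}'.format(
        k, e, s, (s - e) / e))
```

The expected moments come out to roughly 3, 35.5, 381, 4726.5 and 60645, so each sampled moment is within about 1.1% of its expected value. That gap is the sampling error the solver inherits.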

We have 5 moments and 5 unknowns (r and the 2 parameters for each of the 2 normal distributions). So now we should be able to use the numerical moments to solve for the parameters.

In [10]:
# These should all be solved for zero
equations = [m - value for (m,value) in zip(mixture_moments, actual_moments)]
for (i, m) in zip(MOMENTS_USED, equations):
    print('\nEqn for moment {}'.format(i))
    pretty_print(m)
Eqn for moment 1
μ₁⋅r + μ₂⋅(-r + 1) - 3.03324300755677

Eqn for moment 2
  ⎛  2     2⎞   ⎛  2     2⎞                            
r⋅⎝μ₁  + σ₁ ⎠ + ⎝μ₂  + σ₂ ⎠⋅(-r + 1) - 35.4298166437677

Eqn for moment 3
     ⎛  2       2⎞      ⎛  2       2⎞                           
μ₁⋅r⋅⎝μ₁  + 3⋅σ₁ ⎠ + μ₂⋅⎝μ₂  + 3⋅σ₂ ⎠⋅(-r + 1) - 378.38158975865

Eqn for moment 4
  ⎛  4       2   2       4⎞            ⎛  4       2   2       4⎞
r⋅⎝μ₁  + 6⋅μ₁ ⋅σ₁  + 3⋅σ₁ ⎠ + (-r + 1)⋅⎝μ₂  + 6⋅μ₂ ⋅σ₂  + 3⋅σ₂ ⎠ - 4697.59347129467

Eqn for moment 5
     ⎛  4        2   2        4⎞               ⎛  4        2   2        4⎞
μ₁⋅r⋅⎝μ₁  + 10⋅μ₁ ⋅σ₁  + 15⋅σ₁ ⎠ + μ₂⋅(-r + 1)⋅⎝μ₂  + 10⋅μ₂ ⋅σ₂  + 15⋅σ₂ ⎠ - 60793.2080716957
In [11]:
print("True Parameters for mu_1, sigma_1, mu_2, sigma_2, r: {}".format(REAL_PARAMS))
True Parameters for mu_1, sigma_1, mu_2, sigma_2, r: (0, 2, 10, 3, 0.7)
In [12]:
def solve_numerically(equations, initial_guess):
    # If I start off near-ish to the point, then we can solve it numerically.
    solved = nsolve(equations, (mu_1, sigma_1, mu_2, sigma_2, r), list(initial_guess))
    print('Numeric solution')
    pretty_print(solved.T)

    # Blarg. I get arbitrary precision stuff out and I don't need that
    to_float_array = lambda mfp_array: np.array(list(map(float, mfp_array)))

    distance_to_real_array = norm(np.array(REAL_PARAMS) - to_float_array(solved.T))

    # Not sure how I should really be talking about distance...
    print('L2 distance from true solution (after sampling {} sample points): {}'.format(
        NUM_POINTS, distance_to_real_array ))

If we just solve numerically starting at the real parameters, then the remaining error is just the sampling error in the moments.

In [13]:
solve_numerically(equations, np.array(REAL_PARAMS))
Numeric solution
[-0.0964001259983833  1.93947261516827  9.45613788320574  3.31424965505908  0.672375746577542]
L2 distance from true solution (after sampling 5000 sample points): 0.6389510902822756

If we move the initial guess slightly we are fine

In [14]:
solve_numerically(equations, np.array(REAL_PARAMS) + np.array([.3, .2, .2, .2, 0]))
Numeric solution
[-0.0964001259983833  1.93947261516827  9.45613788320574  3.31424965505908  0.672375746577542]
L2 distance from true solution (after sampling 5000 sample points): 0.6389510902822756

But then if we get a little farther....

In [15]:
solve_numerically(equations, np.array(REAL_PARAMS) + np.array([2, 1, -1, 1, .2]))
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-15-353227e314b9> in <module>()
----> 1 solve_numerically(equations, np.array(REAL_PARAMS) + np.array([2, 1, -1, 1, .2]))

<ipython-input-12-dd1f346d468d> in solve_numerically(equations, initial_guess)
      1 def solve_numerically(equations, initial_guess):
      2     # If I start off near-ish to the point, then we can solve it numerically.
----> 3     solved = nsolve(equations, (mu_1, sigma_1, mu_2, sigma_2, r), list(initial_guess))
      4     print('Numeric solution')
      5     pretty_print(solved.T)

/home/justinvf/anaconda3/lib/python3.4/site-packages/sympy/solvers/solvers.py in nsolve(*args, **kwargs)
   2473     J = lambdify(fargs, J, modules)
   2474     # solve the system numerically
-> 2475     x = findroot(f, x0, J=J, **kwargs)
   2476     return x
   2477 

/home/justinvf/anaconda3/lib/python3.4/site-packages/sympy/mpmath/calculus/optimization.py in findroot(ctx, f, x0, solver, tol, verbose, verify, **kwargs)
    973                              '(%g > %g)\n'
    974                              'Try another starting point or tweak arguments.'
--> 975                              % (norm(f(*xl))**2, tol))
    976         return x
    977     finally:

ValueError: Could not find root within given tolerance. (1278.68 > 2.1684e-19)
Try another starting point or tweak arguments.
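
My understanding is that `nsolve` hands a system like this to a Newton-type iteration inside mpmath's `findroot`, and Newton's method is only guaranteed to behave near a root. The toy below is not the mixture problem. It is a one-dimensional sketch on $f(x) = \arctan(x)$, a textbook case where the same iteration converges from one start and blows up from another. The function name and the bail-out threshold are arbitrary choices for illustration.

```python
import math


def newton_atan(x0, tol=1e-12, max_iter=50):
    """Newton's method on f(x) = atan(x), whose only root is x = 0."""
    x = x0
    for _ in range(max_iter):
        fx = math.atan(x)
        if abs(fx) < tol:
            return x
        # f'(x) = 1 / (1 + x^2), so the Newton step is x - f(x) * (1 + x^2)
        x = x - fx * (1 + x * x)
        if abs(x) > 1e6:
            break  # the iterates are running away from the root
    raise ValueError('Could not find root starting from {}'.format(x0))


print(newton_atan(1.0))  # a value very close to 0
try:
    newton_atan(2.0)
except ValueError as err:
    print(err)
```

Starting at 1.0 the steps shrink toward zero, while starting at 2.0 each step overshoots by more than the last. The five-dimensional moment equations presumably have a similar basin of attraction around the solution, just one that is harder to picture.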

If we look at the data, we can read off reasonable initial guesses for the mus and sigmas. Here is the histogram from above again:

In [16]:
plt.hist(fake_data, bins=50);

The left hump looks a little more massive, so I would guess .7 for $r$. Then $\mu_1$ looks to be around 0, and $\mu_2$ around 10. Then I would guess a larger standard deviation for the second.

In [17]:
solve_numerically(equations, np.array([0, 3, 10, 5, .7]))
Numeric solution
[-0.0964001259983833  1.93947261516827  9.45613788320574  3.31424965505908  0.672375746577542]
L2 distance from true solution (after sampling 5000 sample points): 0.6389510902822756

And that works... But it's not satisfying. I should read more papers about this :).
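
One way to make the starting point less hand-picked is to let the data split itself in two. The sketch below is a one-dimensional 2-means clustering, a different technique from the method of moments that is used here only to produce a rough guess. It is written in plain Python with its own synthetic data so that it stands alone; `initial_guess` is a hypothetical helper, not something from the notebook.

```python
import math
import random


def initial_guess(data, max_iter=100):
    """Guess (mu_1, sigma_1, mu_2, sigma_2, r) by splitting the data in two."""
    threshold = (min(data) + max(data)) / 2
    for _ in range(max_iter):
        left = [x for x in data if x < threshold]
        right = [x for x in data if x >= threshold]
        # Move the threshold to the midpoint of the two cluster means.
        new_threshold = (sum(left) / len(left) + sum(right) / len(right)) / 2
        if abs(new_threshold - threshold) < 1e-9:
            break
        threshold = new_threshold

    def mean_std(xs):
        mean = sum(xs) / len(xs)
        return mean, math.sqrt(sum((x - mean) ** 2 for x in xs) / len(xs))

    mu_1, sigma_1 = mean_std(left)
    mu_2, sigma_2 = mean_std(right)
    return mu_1, sigma_1, mu_2, sigma_2, len(left) / len(data)


# Synthetic data with the same true parameters as the notebook: (0, 2, 10, 3, .7)
random.seed(0)
data = [random.gauss(0, 2) if random.random() < 0.7 else random.gauss(10, 3)
        for _ in range(5000)]
print(initial_guess(data))
```

Because the humps overlap, a hard split like this biases the estimates (the right-hand standard deviation comes out too small, for instance), so it is only meant as a starting point to pass along as `initial_guess` to `solve_numerically`, not as an answer in itself.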

My main questions left from this exercise (many are probably obvious; I know like zero stats):

  1. What distance metric is appropriate when talking about the distance between distribution parameters?
  2. How does the number of gaussians increase the difficulty for the numerical solver?
  3. I could probably pick means by doing a sliding window over the points. As long as the gaussians have some distance it shouldn't be too hard to get good guesses for standard deviation either.
  4. I just picked the first 5 moments. I could have used later moments (by changing the MOMENTS_USED constant). Using higher order moments increases the error on the parameter estimates, though (4x for using 2-6 as opposed to 1-5). Extending this technique to 3 gaussians would necessitate using these higher order moments, which seems damning.
  5. How will NUM_POINTS affect the error?
  6. Could the central moments be of use? How do the central moments of the individual gaussians relate to the central moments of the full distribution?
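
For question 5, at least the error in the moments themselves has a standard answer: the $k$-th sample moment averages $n$ independent draws of $X^k$, so its standard error is $\sqrt{(E[X^{2k}] - E[X^k]^2)/n}$. The sketch below evaluates that for the true parameters, assuming the usual recurrence for normal raw moments. The helper names are mine, and it says nothing about how this error propagates through the solve to the parameters.

```python
import math


def normal_raw_moments(mu, sigma, max_order):
    """Return [E[X^0], ..., E[X^max_order]] for X ~ Normal(mu, sigma)."""
    moments = [1.0, float(mu)]
    for n in range(2, max_order + 1):
        moments.append(mu * moments[n - 1] + (n - 1) * sigma ** 2 * moments[n - 2])
    return moments


def moment_standard_error(k, n, mu_1, sigma_1, mu_2, sigma_2, r):
    """Standard error of the k-th sample raw moment from n independent draws."""
    first = normal_raw_moments(mu_1, sigma_1, 2 * k)
    second = normal_raw_moments(mu_2, sigma_2, 2 * k)
    mix = [r * a + (1 - r) * b for a, b in zip(first, second)]
    # Var(X^k) = E[X^(2k)] - E[X^k]^2, and averaging n draws divides it by n
    return math.sqrt((mix[2 * k] - mix[k] ** 2) / n)


params = (0, 2, 10, 3, 0.7)
for n in (5000, 20000):
    print(n, [round(moment_standard_error(k, n, *params), 3) for k in (1, 2)])
```

With 5000 points the first moment has a standard error of about 0.073, so the observed 3.033 is within half a standard error of the true value 3. Quadrupling the number of points halves every one of these errors, which is the usual $1/\sqrt{n}$ scaling.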