EM as a Message Passing Algorithm

Preliminaries

  • Goals
    • Describe Expectation-Maximization (EM) as a message passing algorithm on a Forney-style factor graph
  • Materials

A Problem for the Multiplier Node

  • Consider the multiplier factor $f(x,y,\theta) = \delta(y-\theta x)$ with incoming Gaussian messages $\overrightarrow{\mu}_X(x) = \mathcal{N}(x|m_x,v_x)$ and $\overleftarrow{\mu}_Y(y) = \mathcal{N}(y|m_y,v_y)$. For simplicity's sake, we assume all variables are scalar.

  • In a system identification setting, we are interested in computing the outgoing message $\overleftarrow{\mu}_\Theta(\theta)$.

  • Let's compute the sum-product message:

$$\begin{align*} \overleftarrow{\mu}_\Theta(\theta) &= \int \overrightarrow{\mu}_X(x) \, \overleftarrow{\mu}_Y(y) \, f(x,y,\theta) \, \mathrm{d}x \mathrm{d}y \\ &= \int \mathcal{N}(x\,|\,m_x,v_x) \, \mathcal{N}(y\,|\,m_y,v_y) \, \delta(y-\theta x)\, \mathrm{d}x \mathrm{d}y \\ &= \int \mathcal{N}(x\,|\,m_x,v_x) \,\mathcal{N}(\theta x\,|\,m_y,v_y) \, \mathrm{d}x \\ &= \frac{1}{|\theta|}\int \mathcal{N}(x\,|\,m_x,v_x) \,\mathcal{N}\left(x \,\bigg|\, \frac{m_y}{\theta},\frac{v_y}{\theta^2}\right) \, \mathrm{d}x \\ &= \frac{1}{|\theta|}\,\mathcal{N}\left(\frac{m_y}{\theta} \,\bigg|\, m_x, v_x + \frac{v_y}{\theta^2}\right) \cdot \int \mathcal{N}(x\,|\,m_*,v_*)\, \mathrm{d}x \tag{SRG-6} \\ &= \frac{1}{|\theta|}\,\mathcal{N}\left(\frac{m_y}{\theta} \,\bigg|\, m_x, v_x + \frac{v_y}{\theta^2}\right) \\ &= \mathcal{N}\left(m_y \,\big|\, \theta m_x, \theta^2 v_x + v_y\right) \end{align*}$$
  • This is not a Gaussian message for $\Theta$: as a function of $\theta$, the parameter appears in both the mean and the variance of the result. Passing this message into the graph leads to serious problems, because the sum-product messages for other factors that receive it generally no longer have closed-form Gaussian expressions.

    • (We saw before in the lesson on Working with Gaussians that the product of two Gaussian-distributed variables is not Gaussian-distributed.)
  • The same problem occurs in a forward message passing schedule when we try to compute a message for $Y$ from incoming Gaussian messages for both $X$ and $\Theta$.
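The non-Gaussian character of this message can also be checked numerically. The sketch below uses plain Python with only the standard library, chosen for illustration rather than the lesson's Julia tooling, and the parameter values are arbitrary.

It does three things:

- It integrates $x$ out on a grid.
- It compares the result with the closed form $\mathcal{N}(m_y \mid \theta m_x, \theta^2 v_x + v_y)$. This form already includes the Jacobian factor $1/|\theta|$ that appears when $\mathcal{N}(\theta x \mid m_y, v_y)$ is rewritten as a density in $x$.
- It looks at second differences of $\log \overleftarrow{\mu}_\Theta(\theta)$. If the message were Gaussian in $\theta$, its logarithm would be a concave quadratic with the same negative second difference at every $\theta$.

```python
import math

def normal_pdf(x, m, v):
    """Gaussian density N(x | m, v) with mean m and variance v."""
    return math.exp(-0.5 * (x - m) ** 2 / v) / math.sqrt(2.0 * math.pi * v)

def sp_theta_numeric(theta, m_x, v_x, m_y, v_y, n=20001, width=12.0):
    """SP message for Theta: trapezoid-rule integral over x of N(x|m_x,v_x) N(theta*x|m_y,v_y)."""
    lo = m_x - width * math.sqrt(v_x)
    hi = m_x + width * math.sqrt(v_x)
    h = (hi - lo) / (n - 1)
    total = 0.0
    for i in range(n):
        x = lo + i * h
        weight = 0.5 if i in (0, n - 1) else 1.0
        total += weight * normal_pdf(x, m_x, v_x) * normal_pdf(theta * x, m_y, v_y)
    return total * h

def sp_theta_closed(theta, m_x, v_x, m_y, v_y):
    """Closed form of the same message: N(m_y | theta*m_x, theta^2*v_x + v_y)."""
    return normal_pdf(m_y, theta * m_x, theta ** 2 * v_x + v_y)

def log_second_difference(theta, h, params):
    """Second finite difference of log mu(theta). For a Gaussian in theta this
    would be the same negative number (-h^2 / variance) at every theta."""
    def log_mu(t):
        return math.log(sp_theta_closed(t, *params))
    return log_mu(theta + h) - 2.0 * log_mu(theta) + log_mu(theta - h)

params = (1.0, 1.0, 2.0, 0.5)  # (m_x, v_x, m_y, v_y): arbitrary example values
for theta in (0.0, 1.0, 2.0):
    print(theta, sp_theta_numeric(theta, *params), sp_theta_closed(theta, *params))
print(log_second_difference(0.0, 0.5, params), log_second_difference(2.0, 0.5, params))
```

With these values the second difference is positive at $\theta=0$ and negative at $\theta=2$. The log-message is therefore not a concave quadratic, and no mean/variance pair can represent it exactly.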

Limitations of Sum-Product Messages

  • The foregoing example shows that the sum-product (SP) message update rule is sometimes impractical. For example:

    • On high-dimensional discrete domains, the SP update rule may be computationally intractable.
    • On continuous domains, the SP update rule may not have a closed-form solution or the rule may lead to a function that is incompatible with Gaussian message passing.
  • There are various ways to cope with 'intractable' SP update rules. In this lesson, we discuss how the EM-algorithm can be written as a message passing algorithm on factor graphs. Then, we will solve the 'multiplier node problem' with EM messages (rather than with SP messages).

EM as Message Passing

  • Consider first a general setting with likelihood function $f(x,\theta)$, hidden variables $x$ and tuning parameters $\theta$. Assume that we are interested in the maximum likelihood estimate
$$\begin{align*} \hat{\theta} &= \arg\max_\theta \int f(x,\theta) \mathrm{d}x\,. \end{align*}$$
  • If $\int f(x,\theta) \mathrm{d}x$ is intractable, we can try to apply the EM-algorithm to estimate $\hat{\theta}$, which leads to the following iterations (cf. lesson on the EM algorithm):
$$ \hat{\theta}^{(k+1)} = \underbrace{\arg\max_\theta}_{\text{M-step}} \left( \underbrace{\int_x f(x,\hat{\theta}^{(k)})\,\log f(x,\theta)\,\mathrm{d}x}_{\text{E-step}} \right) $$
  • It turns out that for factorized functions $f(x,\theta)$, the EM-algorithm can be executed as a message passing algorithm on the factor graph.

  • As a simple example, we consider the factorization

$$ f(x,\theta) = f_a(\theta)f_b(x,\theta) $$

  • Applying the EM-algorithm to this graph leads to the following forward and backward messages over the $\theta$ edge $$\begin{align*} \textbf{E-step}&: \quad \eta(\theta) = \int p_b(x|\hat{\theta}^{(k)}) \log f_b(x,\theta) \,\mathrm{d}x \\ \textbf{M-step}&: \quad \hat{\theta}^{(k+1)} = \arg\max_\theta \left( f_a(\theta)\, e^{\eta(\theta)}\right) \end{align*}$$ where $p_b(x|\hat{\theta}^{(k)}) \triangleq \frac{f_b(x,\hat{\theta}^{(k)})}{\int f_b(x^\prime,\hat{\theta}^{(k)}) \,\mathrm{d}x^\prime}$.
    Proof:

    $$\begin{align*} \hat{\theta}^{(k+1)} &= \arg\max_\theta \, \int_x f(x,\hat{\theta}^{(k)}) \,\log f(x,\theta)\,\mathrm{d}x \\ &= \arg\max_\theta \, \int_x f_a(\hat{\theta}^{(k)})f_b(x,\hat{\theta}^{(k)}) \,\log \left( f_a(\theta)f_b(x,\theta) \right) \,\mathrm{d}x \\ &= \arg\max_\theta \, \int_x f_b(x,\hat{\theta}^{(k)}) \cdot \left( \log f_a(\theta) + \log f_b(x,\theta) \right) \,\mathrm{d}x \\ &= \arg\max_\theta \left( \log f_a(\theta) + \frac{\int f_b(x,\hat{\theta}^{(k)}) \log f_b(x,\theta) \,\mathrm{d}x }{\int f_b(x^\prime,\hat{\theta}^{(k)}) \,\mathrm{d}x^\prime} \right) \\ &= \arg\max_\theta \left( \log f_a(\theta) + \underbrace{\int p_b(x|\hat{\theta}^{(k)}) \log f_b(x,\theta) \,\mathrm{d}x}_{\eta(\theta)} \right) \\ &= \underbrace{\arg\max_\theta}_{\text{M-step}} \left( f_a(\theta)\,\underbrace{e^{\eta(\theta)}}_{\text{E-step}} \right) \end{align*}$$
  • The messages represent the 'E' and 'M' steps, respectively.

  • The quantity $\eta(\theta)$ (a.k.a. the E-log message) may be interpreted as a log-domain summary of $f_b$. The message $e^{\eta(\theta)}$ is the corresponding 'probability domain' message that is consistent with the semantics of messages as summaries of factors. In a software implementation, you can use either domain, as long as a consistent method is chosen.

  • Note that the denominator $\int f_b(x^\prime,\hat{\theta}^{(k)}) \,\mathrm{d}x^\prime$ in $p_b$ is just a scaling factor that can usually be ignored, leading to a simpler E-log message $$\eta(\theta) = \int f_b(x,\hat{\theta}^{(k)}) \log f_b(x,\theta) \,\mathrm{d}x \,.$$
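To make the E-step/M-step pair concrete, here is a minimal plain-Python sketch for a hypothetical conjugate toy model. The model is our assumption and does not come from the lesson:

- prior $f_a(\theta)=\mathcal{N}(\theta\mid m_0,v_0)$;
- $f_b(x,\theta)=\mathcal{N}(x\mid\theta,v_1)\,\mathcal{N}(y\mid x,v_2)$ for an observed scalar $y$.

For this model:

- $p_b(x\mid\hat\theta^{(k)})$ is Gaussian, so the E-step only needs $\mathbb{E}[X]$.
- The E-log message is $\eta(\theta)=\text{const.}-\frac{1}{2v_1}\left(\theta^2-2\theta\,\mathbb{E}[X]\right)$, i.e. $e^{\eta(\theta)}\propto\mathcal{N}(\theta\mid\mathbb{E}[X],v_1)$.
- The M-step picks the mode of a product of two Gaussians.

Because $\int f_b(x,\theta)\,\mathrm{d}x=\mathcal{N}(y\mid\theta,v_1+v_2)$ is available here, we can check that the iterations reach the exact maximizer of $f_a(\theta)\int f_b(x,\theta)\,\mathrm{d}x$.

```python
def e_step(theta_hat, y, v1, v2):
    """E[X] under p_b(x | theta_hat), proportional to N(x | theta_hat, v1) N(y | x, v2)."""
    return (theta_hat / v1 + y / v2) / (1.0 / v1 + 1.0 / v2)

def m_step(ex, m0, v0, v1):
    """argmax_theta f_a(theta) e^{eta(theta)}, with f_a = N(theta | m0, v0) and
    e^{eta(theta)} proportional to N(theta | E[X], v1): mode of a Gaussian product."""
    return (m0 / v0 + ex / v1) / (1.0 / v0 + 1.0 / v1)

def em(theta0, y, m0, v0, v1, v2, iters=100):
    """Alternate E-step (posterior mean of x) and M-step (update of theta)."""
    theta = theta0
    for _ in range(iters):
        theta = m_step(e_step(theta, y, v1, v2), m0, v0, v1)
    return theta

def exact_maximizer(y, m0, v0, v1, v2):
    """argmax_theta N(theta | m0, v0) N(y | theta, v1 + v2),
    i.e. of f_a(theta) times the integral of f_b(x, theta) over x."""
    return (m0 / v0 + y / (v1 + v2)) / (1.0 / v0 + 1.0 / (v1 + v2))

# Hypothetical numbers: observation y = 3, prior N(0, 4), unit noise variances
theta_em = em(theta0=-5.0, y=3.0, m0=0.0, v0=4.0, v1=1.0, v2=1.0)
print(theta_em, exact_maximizer(y=3.0, m0=0.0, v0=4.0, v1=1.0, v2=1.0))
```

In this linear-Gaussian toy case EM is not needed, since the marginal is tractable. That is exactly why it is a convenient check of the message formulas.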

EM vs SP and MP Message Passing

  • Consider again the likelihood model $f(x,\theta)$ with $x$ a set of hidden variables. We are interested in the ML estimate
$$ \hat{\theta} = \arg\max_\theta \int f(x,\theta) \mathrm{d}x\,. $$
  • Recall that in a 'regular' (not message passing) setting, the EM-algorithm is particularly useful when the expectation (E-step) $$ \eta(\theta) = \int_x f(x,\hat{\theta}^{(k)})\,\log f(x,\theta)\,\mathrm{d}x $$ leads to easier expressions than the marginalization (which is what we really want) $$ \bar f(\theta) = \int f(x,\theta) \mathrm{d}x . $$

  • Similarly, in a message passing framework with connected nodes $f_a$ and $f_b$, EM messages are particularly useful when the expectation (represented by the E-log message) $$ \eta(\theta) = \int p_b(x|\hat{\theta}^{(k)}) \log f_b(x,\theta) \,\mathrm{d}x $$ leads to easier expressions than the marginalization (represented by the sum-product message, which is also what we really want) $$ \mu(\theta) = \int f_b(x,\theta) \mathrm{d}x . $$

  • Just as for the sum-product (SP) and max-product (MP) messages, we can work out the outgoing E-log message on the $Y$ edge for a general node $f(x_1,\ldots,x_M,y)$ with given message inputs $\overrightarrow{\mu}_{X_m}(x_m)$ (see also Dauwels et al. (2009), Table-1, pg.4):

$$\begin{align*} \textbf{SP}:&\;\;\overrightarrow{\mu}(y) = \int \overrightarrow{\mu}_{X_1}(x_1) \cdots \overrightarrow{\mu}_{X_M}(x_M)\, f(x_1,\ldots,x_M,y) \, \mathrm{d}x_1 \ldots \mathrm{d}x_M \\ \textbf{MP}:&\;\;\overrightarrow{\mu}(y) = \max_{x_1,\ldots,x_M} \overrightarrow{\mu}_{X_1}(x_1) \cdots \overrightarrow{\mu}_{X_M}(x_M)\, f(x_1,\ldots,x_M,y) \\ \textbf{E-log}:&\;\;\overrightarrow{\eta}(y) = \int p(x_1,\ldots,x_M | \hat{y}^{(k)})\,\log f(x_1,\ldots,x_M,y) \, \mathrm{d}x_1 \ldots \mathrm{d}x_M \end{align*}$$

where $p(x_1,\ldots,x_M | \hat{y}^{(k)}) \triangleq \frac{\overrightarrow{\mu}_{X_1}(x_1) \cdots \overrightarrow{\mu}_{X_M}(x_M)\, f(x_1,\ldots,x_M,\hat{y}^{(k)})}{\int \overrightarrow{\mu}_{X_1}(x_1) \cdots \overrightarrow{\mu}_{X_M}(x_M)\, f(x_1,\ldots,x_M,\hat{y}^{(k)}) \, \mathrm{d}x_1 \ldots \mathrm{d}x_M}$ and $\hat{y}^{(k)}$ is the current estimate of $y$.

  • **Exercise**: prove the generic E-log message update rule.
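For discrete variables the integrals become sums, and all three rules can be evaluated by brute-force enumeration of the joint configurations of $x_1,\ldots,x_M$. The cost of this enumeration is exponential in $M$.

The plain-Python sketch below is illustrative only, under these assumptions:

- Messages are lists indexed by the state of each variable.
- The factor table, incoming message and current estimate $\hat{y}^{(k)}$ are made-up values.
- The factor is assumed strictly positive, so that every $\log f$ term in the E-log sum is finite.

The MP message is computed as a maximum over $x_1,\ldots,x_M$ for every value of $y$.

```python
import itertools
import math

def _configs(in_msgs):
    """All joint configurations (x_1, ..., x_M) of the incoming variables."""
    return list(itertools.product(*[range(len(m)) for m in in_msgs]))

def _weight(in_msgs, xs):
    """Product of incoming messages mu_{X_1}(x_1) ... mu_{X_M}(x_M)."""
    return math.prod(m[x] for m, x in zip(in_msgs, xs))

def sp_message(f, in_msgs, y_dom):
    """Sum-product: sum over all x-configurations for every y."""
    return [sum(_weight(in_msgs, xs) * f(xs, y) for xs in _configs(in_msgs)) for y in y_dom]

def mp_message(f, in_msgs, y_dom):
    """Max-product: maximum over all x-configurations for every y."""
    return [max(_weight(in_msgs, xs) * f(xs, y) for xs in _configs(in_msgs)) for y in y_dom]

def elog_message(f, in_msgs, y_dom, y_hat):
    """E-log: expected log-factor under p(x_1..x_M | y_hat). Assumes f > 0 everywhere."""
    configs = _configs(in_msgs)
    # Posterior over x-configurations: incoming messages times factor at y_hat, normalized
    unnorm = [_weight(in_msgs, xs) * f(xs, y_hat) for xs in configs]
    z = sum(unnorm)
    post = [u / z for u in unnorm]
    return [sum(p * math.log(f(xs, y)) for p, xs in zip(post, configs)) for y in y_dom]

# Hypothetical single-input example (M = 1): x in {0, 1}, y in {0, 1, 2}
table = [[0.5, 0.3, 0.2],
         [0.1, 0.6, 0.3]]
f = lambda xs, y: table[xs[0]][y]
mu_x = [0.3, 0.7]
y_dom = range(3)
print(sp_message(f, [mu_x], y_dom))
print(mp_message(f, [mu_x], y_dom))
print(elog_message(f, [mu_x], y_dom, y_hat=1))
```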

A Snag for EM Message Passing on Deterministic Nodes

  • The factors for deterministic nodes are (Dirac) delta functions, e.g., $\delta(y-\theta x)$ for the multiplier.

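  • Note that the outgoing E-log message for a deterministic node is degenerate. The expectation of $\log \delta(\cdot)$ equals $-\infty$ for every value except the current estimate, so the corresponding probability-domain message $e^{\eta}$ is again a delta function located at that estimate. For details, consult Dauwels et al. (2009) pg.5, section F.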

  • This would stall the iterative estimation process at the current estimate since the outgoing E-log message would express complete certainty about the estimate.

  • This issue can be resolved by closing a box around a subgraph that includes (the deterministic node) $f$ and at least one non-deterministic factor. EM message passing can now proceed with the newly created node.

A Solution for the Multiplier Node with Unknown Coefficient

  • We now return to the original problem of this lesson. Consider again the (scalar) multiplier with unknown coefficient $f(x,y,\theta) = \delta(y-\theta x)$ and incoming messages $\overrightarrow{\mu}_X(x) = \mathcal{N}(x|m_x,v_x)$ and $\overleftarrow{\mu}_Y(y) = \mathcal{N}(y|m_y,v_y)$. We will now compute the outgoing E-log message for $\Theta$.

  • Since $f(x,y,\theta)$ is deterministic, we first group $f$ with the (non-deterministic) node that sends the message $\overleftarrow{\mu}_Y(y) = \mathcal{N}(y|m_y,v_y)$. This leads (through the sum-product rule) to $$\begin{align*} g(x,\theta) &\triangleq \int \overleftarrow{\mu}_Y(y)\, f(x,y,\theta) \,\mathrm{d}y \\ &= \int \mathcal{N}(y|m_y,v_y)\, \delta(y-\theta x) \,\mathrm{d}y \\ &= \mathcal{N}(\theta x\mid m_y,v_y)\,. \end{align*}$$

  • The problem now is to pass an E-log message out of $g(x,\theta)$. Assume that $g$ has received an estimate $\hat{\theta}$ from the incoming message over the $\Theta$ edge. The E-log update rule then prescribes $$\begin{align*} \eta(\theta) &= \mathbb{E}\left[ \log g(x,\theta) \right] \\ &= \mathbb{E}\left[ \log \mathcal{N}(\theta x|m_y,v_y) \right] \\ &= \text{const.} - \frac{1}{2v_y}\, \left( \mathbb{E}[X^2] \theta^2 - 2 m_y \mathbb{E}[X] \theta + m_y^2\right) \\ &= \log \mathcal{N}_{\xi} \left( \theta \,\bigg|\, \frac{m_y \mathbb{E}\left[X\right]}{v_y}, \frac{\mathbb{E}\left[X^2\right]}{v_y} \right) + \text{const.} \end{align*}$$ where we used the 'canonical' parametrization of the Gaussian $\mathcal{N}_{\xi}(\theta \mid\xi,w) \propto \exp \left( \xi \theta- \frac{1}{2} w \theta^2\right)$. Hence the probability-domain message $e^{\eta(\theta)}$ is proportional to a Gaussian in $\theta$.

  • In the E-log message update rule, the expectations $\mathbb{E}\left[X\right]$ and $\mathbb{E}\left[X^2\right]$ have to be taken w.r.t. $p(x|\hat{\theta}) \propto \overrightarrow{\mu}_X(x)\,g(x,\hat{\theta})$ (consult the generic E-log update rule). A straightforward (but rather painful) derivation leads to $$\begin{align*} p(x \mid \hat{\theta}) &\propto \overrightarrow{\mu}_X(x)\,g(x,\hat{\theta}) \\ &= \mathcal{N}(x \mid m_x,v_x)\cdot \mathcal{N}(\hat{\theta} x \mid m_y,v_y) \\ &\propto \mathcal{N}(x \mid m_x,v_x)\cdot \mathcal{N}\left(x \,\bigg| \,\frac{m_y}{\hat{\theta} },\frac{v_y}{\hat{\theta}^2} \right) \\ &\propto \mathcal{N}_\xi( x \mid \xi_g , w_g) \end{align*}$$ where $w_g = \frac{1}{v_x} + \frac{\hat{\theta}^2}{v_y}$ and $\xi_g \triangleq w_g m_g = \frac{m_x}{v_x}+\frac{\hat{\theta}m_y}{v_y}$. It follows that
    $$\begin{align*} \mathbb{E}\left[X\right] &= m_g \\ \mathbb{E}\left[X^2\right] &= m_g^2 + w_g^{-1} \end{align*}$$

  • $\Rightarrow$ The E-log update formula may not be fun to derive, but the result is very pleasing: the E-log message for the multiplier with unknown coefficient is a Gaussian message with closed-form expressions for its parameters! See also Dauwels et al. (2009) Table-2, pg.6.
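The resulting update is short to implement. Below is a minimal plain-Python sketch; the function names are our own.

- `multiplier_elog_message` computes the canonical parameters $(\xi, w) = \left(m_y\mathbb{E}[X]/v_y,\ \mathbb{E}[X^2]/v_y\right)$ of the outgoing message from $(m_x, v_x, m_y, v_y)$ and the current estimate $\hat\theta$.
- `posterior_moments_numeric` is a sanity check of $\mathbb{E}[X]$ and $\mathbb{E}[X^2]$ by numerical integration.
- `em_theta` assumes a flat prior on $\theta$ (our assumption) and iterates the M-step $\hat\theta \leftarrow \xi/w$, the mode of $\mathcal{N}_\xi(\theta\mid\xi,w)$.

The fixed point should be a stationary point of the exact marginal likelihood $\int \mathcal{N}(x\mid m_x,v_x)\,\mathcal{N}(\theta x\mid m_y,v_y)\,\mathrm{d}x = \mathcal{N}(m_y\mid\theta m_x,\theta^2 v_x + v_y)$. EM reaches it while passing only Gaussian messages.

```python
import math

def multiplier_elog_message(theta_hat, m_x, v_x, m_y, v_y):
    """Canonical parameters (xi, w) of exp(eta(theta)) for the node y = theta * x."""
    # p(x | theta_hat) is proportional to N(x|m_x,v_x) N(theta_hat*x|m_y,v_y) = N_xi(x | xi_g, w_g)
    w_g = 1.0 / v_x + theta_hat ** 2 / v_y
    xi_g = m_x / v_x + theta_hat * m_y / v_y
    m_g = xi_g / w_g
    ex, ex2 = m_g, m_g ** 2 + 1.0 / w_g      # E[X], E[X^2]
    return m_y * ex / v_y, ex2 / v_y          # (xi, w)

def posterior_moments_numeric(theta_hat, m_x, v_x, m_y, v_y, n=20001, width=12.0):
    """E[X], E[X^2] under p(x | theta_hat) by trapezoid integration (for checking only)."""
    lo, hi = m_x - width * math.sqrt(v_x), m_x + width * math.sqrt(v_x)
    h = (hi - lo) / (n - 1)
    z = s1 = s2 = 0.0
    for i in range(n):
        x = lo + i * h
        c = 0.5 if i in (0, n - 1) else 1.0
        # Unnormalized density; constant factors cancel in the ratios below
        p = c * math.exp(-0.5 * (x - m_x) ** 2 / v_x - 0.5 * (theta_hat * x - m_y) ** 2 / v_y)
        z, s1, s2 = z + p, s1 + p * x, s2 + p * x * x
    return s1 / z, s2 / z

def log_sp_message(theta, m_x, v_x, m_y, v_y):
    """log N(m_y | theta*m_x, theta^2*v_x + v_y): the exact, non-Gaussian SP message."""
    var = theta ** 2 * v_x + v_y
    return -0.5 * math.log(2.0 * math.pi * var) - 0.5 * (m_y - theta * m_x) ** 2 / var

def em_theta(theta0, m_x, v_x, m_y, v_y, iters=500):
    """EM with a flat prior on theta: the M-step takes the mode xi / w of the E-log message."""
    theta = theta0
    for _ in range(iters):
        xi, w = multiplier_elog_message(theta, m_x, v_x, m_y, v_y)
        theta = xi / w
    return theta

params = (1.0, 0.5, 2.0, 0.5)   # (m_x, v_x, m_y, v_y): arbitrary example values
theta_em = em_theta(1.0, *params)
print(theta_em, log_sp_message(theta_em, *params))
```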

Automating Inference

  • It follows that, for a dynamical system with unknown coefficients, both state estimation and parameter learning can be achieved through Gaussian message passing based on SP and EM message update rules.

  • These (SP and EM) message update rules can be tabulated and implemented in software for a large set of factors that are common in probabilistic models. (See the tables in Loeliger et al. (2007) and Dauwels et al. (2009).)

  • Tabulated SP and EM messages for frequently occurring factors facilitate the automated derivation of nontrivial inference algorithms.

  • This makes it possible to automate inference for state and parameter estimation in very complex probabilistic models. Here (in the SPS group at TU/e), we are developing such a factor graph toolbox in Julia.

  • There is a lot more to say about factor graphs. This is a very exciting area of research that promises both

    1. to consolidate a wide range of signal processing and machine learning algorithms in one elegant framework
    2. to automate inference and learning in new models that have previously been intractable for existing machine learning methods.

Example: Linear Dynamical Systems

As before, let us consider the linear dynamical system (LDS)

$$\begin{align*} z_n &= A z_{n-1} + w_n \\ x_n &= C z_n + v_n \\ w_n &\sim \mathcal{N}(0,\Sigma_w) \\ v_n &\sim \mathcal{N}(0,\Sigma_v) \end{align*}$$

Again, we will consider the case where $x_n$ is observed and $z_n$ is a hidden state. $C$, $\Sigma_w$ and $\Sigma_v$ are given parameters, but, in contrast to the previous section, we will assume that the value of the parameter $A$ is unknown.
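For intuition about what such an algorithm computes, here is a plain-Python sketch of the scalar special case. The following are our simplifications, not part of the lesson:

- $z_n$, $x_n$, $A=a$, $C=c$, $\Sigma_w=q$ and $\Sigma_v=r$ are all scalars.
- An assumed Gaussian prior $z_1\sim\mathcal{N}(m_1,p_1)$ is placed on the first state.
- The simulated data, seed and parameter values are arbitrary.

The E-step uses the classical Kalman filter and Rauch-Tung-Striebel smoother, which yield the same Gaussian state marginals as forward-backward sum-product message passing on this chain. The M-step is $\hat a \leftarrow \sum_n \mathbb{E}[z_n z_{n-1}] / \sum_n \mathbb{E}[z_{n-1}^2]$. This is not the factor graph toolbox mentioned above.

```python
import math
import random

def simulate(n, a, c, q, r, seed=0):
    """Sample x_1..x_n from z_t = a z_{t-1} + w_t, x_t = c z_t + v_t, starting from z_0 = 0."""
    rng = random.Random(seed)
    z, xs = 0.0, []
    for _ in range(n):
        z = a * z + rng.gauss(0.0, math.sqrt(q))
        xs.append(c * z + rng.gauss(0.0, math.sqrt(r)))
    return xs

def smooth(xs, a, c, q, r, m1, p1):
    """Kalman filter + RTS smoother for a scalar LDS with prior z_1 ~ N(m1, p1).
    Returns smoothed means, variances, lag-one covariances Cov(z_{t+1}, z_t | x), and log p(x | a)."""
    n = len(xs)
    m_pred, p_pred, m_filt, p_filt = [0.0] * n, [0.0] * n, [0.0] * n, [0.0] * n
    loglik = 0.0
    for t in range(n):
        if t == 0:
            m_pred[t], p_pred[t] = m1, p1
        else:
            m_pred[t], p_pred[t] = a * m_filt[t - 1], a * a * p_filt[t - 1] + q
        s = c * c * p_pred[t] + r                  # predictive variance of x_t
        resid = xs[t] - c * m_pred[t]
        loglik -= 0.5 * (math.log(2.0 * math.pi * s) + resid ** 2 / s)
        k = p_pred[t] * c / s                      # Kalman gain
        m_filt[t] = m_pred[t] + k * resid
        p_filt[t] = (1.0 - k * c) * p_pred[t]
    m_s, p_s, cross = m_filt[:], p_filt[:], [0.0] * n
    for t in range(n - 2, -1, -1):
        j = p_filt[t] * a / p_pred[t + 1]          # RTS smoother gain
        m_s[t] = m_filt[t] + j * (m_s[t + 1] - m_pred[t + 1])
        p_s[t] = p_filt[t] + j * j * (p_s[t + 1] - p_pred[t + 1])
        cross[t] = j * p_s[t + 1]
    return m_s, p_s, cross, loglik

def em_for_a(xs, a0, c, q, r, m1, p1, iters=50):
    """EM for the scalar transition coefficient a; also records log p(x | a) per iteration."""
    a, history = a0, []
    for _ in range(iters):
        m_s, p_s, cross, loglik = smooth(xs, a, c, q, r, m1, p1)
        history.append(loglik)
        num = sum(m_s[t + 1] * m_s[t] + cross[t] for t in range(len(xs) - 1))  # sum of E[z_{t+1} z_t]
        den = sum(m_s[t] ** 2 + p_s[t] for t in range(len(xs) - 1))            # sum of E[z_t^2]
        a = num / den
    return a, history

xs = simulate(1000, a=0.9, c=1.0, q=1.0, r=0.5, seed=0)
a_hat, history = em_for_a(xs, a0=0.1, c=1.0, q=1.0, r=0.5, m1=0.0, p1=1.0)
print(a_hat, history[0], history[-1])
```

Because both the E-step and the M-step are exact here, the recorded log-likelihood should never decrease from one iteration to the next.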


The cell below loads the style file

open("../../styles/aipstyle.html") do f
    display("text/html", read(f,String))
end