Probabilistic Programming 4: Latent Variable & Dynamic Models


  • Understand how to estimate latent variables in models.
  • Understand how to estimate states in dynamical models.


Note that none of the material below is new. The point of the Probabilistic Programming sessions is to solve practical problems so that the concepts from Bert's lectures become less abstract.

In [1]:
using Pkg
In [2]:
using JLD
using Statistics
using StatsBase
using LinearAlgebra
using ProgressMeter
using LaTeXStrings
using ForneyLab
using Plots

import LinearAlgebra: I
import ForneyLab: unsafeMean

Problem: Stone Tools

Archeologists have asked for your help in analyzing data on stone tools. It is believed that primitive humans created tools by striking stones with others. During this process, the stone loses flakes, which have been preserved. The archeologists have recovered these flakes from various locations and time periods and want to know whether this stone tool shaping process has improved over the centuries.


The data is available from the UCI Machine Learning Repository. Each instance represents summary information of the stone flakes for a particular site. We will be using the attributes flaking angle (FLA) and the proportion of the dorsal surface worked (PROZD) for now.

In [3]:
dataset = load("../datasets/stoneflakes.jld");

I've done some pre-processing on the data set, namely z-scoring and removing two outliers. This reduces the scale of the attributes which helps numerical stability during optimization.

Now let's visualize the data with a scatterplot.

In [4]:
        xlabel="Proportion of worked dorsal surface (PROZD)",
        ylabel="Flaking angle (FLA)")

Model specification

We will be clustering this data with a Gaussian mixture model, to see if we can identify clear types of stone tools. The generative model for a Gaussian mixture consists of:

$$ p(X, z, \phi, \mu, \Lambda) =\ \underbrace{p(X \mid z, \mu, \Lambda)}_{\text{likelihood}}\ \times \ \underbrace{p(z \mid \phi)}_{\text{prior latent variables}} \ \times \ \underbrace{p(\mu \mid \Lambda)\ p(\Lambda)\ p(\phi)}_{\text{prior parameters}}$$

with the likelihood of observation $X_i$ being a Gaussian raised to the power of the latent assignment variables $z$

$$ p(X_i \mid z, \mu, \Lambda) = \prod_{k=1}^{K} \mathcal{N}(X_i \mid \mu_k, \Lambda_k^{-1})^{z_i = k}$$

the prior for each latent variable $z_i$ being a Categorical distribution

$$ p(z_i \mid \phi) = \text{Categorical}(z_i \mid \phi) $$

and priors for the parameters being

$$ \begin{align*} p(\mu_k \mid \Lambda_k) =&\ \mathcal{N}(\mu_k \mid m_0, l_0^{-1}\Lambda_k^{-1}) \qquad &\text{for all}\ k \\ p(\Lambda_k) =&\ \text{Wishart}(\Lambda_k \mid V_0, n_0) \qquad &\text{for all}\ k \\ p(\phi) =&\ \text{Dirichlet}(\phi \mid a_0) \, , \end{align*}$$

We will be implementing this model directly in ForneyLab. If you're unfamiliar with these distributions or with the Gaussian mixture model, have another look at Bert's lectures.

First, we will do a bit of bookkeeping.

In [5]:
# Data dimensionality
num_features = size(dataset["data"],2)

# Sample size
num_samples = size(dataset["data"],1)

# Number of mixture components
num_components = 3;

# Identity matrix (convenience variable)
Id = Matrix{Float64}(I, num_features, num_features);

Mixture models can be sensitive to initialization, so we are going to specify the prior parameters explicitly.

In [6]:
# Prior means
m0 = [ 1.0 0.0 -1.0;
      -1.0 0.0  1.0];

# Prior scale matrices
V0 = cat(Id, Id, Id, dims=3)

# Prior degrees of freedom 
n0 = num_features

# Prior concentration parameters
a0 = ones(num_components);

Now to start the factor graph.

In [7]:
# Start a graph
graph1 = FactorGraph()

# Initialize vector variables
z = Vector{Variable}(undef, num_samples)
X = Vector{Variable}(undef, num_samples)
Λ = Vector{Variable}(undef, num_components)
μ = Vector{Variable}(undef, num_components)

# Mixture weights are drawn from a Dirichlet distribution
@RV ϕ ~ Dirichlet(a0)

θ = []
for k = 1:num_components
    # Parameters of k-th component
    @RV Λ[k] ~ Wishart(V0[:,:,k], n0)
    @RV μ[k] ~ GaussianMeanPrecision(m0[:,k], Λ[k])
    push!(θ, μ[k], Λ[k])

for i = 1:num_samples
    # Assignment variable
    @RV z[i] ~ Categorical(ϕ)
    # Gaussian mixture component
    @RV X[i] ~ GaussianMixture(z[i], θ...)
    # Add data 
    placeholder(X[i], :X, dims=(num_features,), index=i)

This is another iid setting, which means the graph will be too large to visualize.

The next step is to compile an inference algorithm.

In [8]:
# Specify recognition factorization (mean-field)
q = PosteriorFactorization(ϕ, μ[1], Λ[1], μ[2], Λ[2], μ[3], Λ[3], z, 
                           ids=[, :μ_1, :Λ_1, :μ_2, :Λ_2, :μ_3, :Λ_3, :z])

# Generate the algorithm
algorithm = messagePassingAlgorithm(free_energy=true)
source_code = algorithmSourceCode(algorithm, free_energy=true);

After that, we feed in data, initialize recognition factors and run the inference procedure.

In [9]:
# Convert data to a format suited to ForneyLab
observations = [dataset["data"][i,:] for i in 1:num_samples]

# Add to data dictionary
data = Dict(:X => observations)

# Prepare recognition distributions
marginals = Dict()
marginals[] = ProbabilityDistribution(Dirichlet, a=ones(num_components,))
for k = 1:num_components
    marginals[:μ_*k] = ProbabilityDistribution(Multivariate, GaussianMeanPrecision, m=m0[:,k], w=Id)
    marginals[:Λ_*k] = ProbabilityDistribution(Wishart, v=Id, nu=num_features)
for i = 1:num_samples
    marginals[:z_*i] = ProbabilityDistribution(Categorical, p=ones(num_components,)./num_components)

# Number of iterations
num_iterations = 20

# Preallocate free energy tracking array
F = Float64[]

# Execute algorithm
@showprogress for i = 1:num_iterations
    # Update assignments
    stepz!(data, marginals)
    # Update parameters
    stepϕ!(data, marginals)
    stepμ_1!(data, marginals)
    stepΛ_1!(data, marginals)
    stepμ_2!(data, marginals)
    stepΛ_2!(data, marginals)
    stepμ_3!(data, marginals)
    stepΛ_3!(data, marginals)
    # Store variational free energy for visualization
    push!(F, freeEnergy(data, marginals))
Progress: 100%|█████████████████████████████████████████| Time: 0:00:28

Alright, we're done. Let's track the evolution of free energy.

In [10]:
# Plot free energy to check for convergence
plot(1:num_iterations, F, color="black", label="")
xlabel!("Number of iterations")
ylabel!("Free Energy")

That looks like it is nicely decreasing. We might want to increase the number of iterations a bit more.

Let's now visualize the cluster on top of the observations.

In [11]:
# Estimated means (unsafeMean retrieves parameters directly)
μ1_estimated = unsafeMean(marginals[:μ_1])
μ2_estimated = unsafeMean(marginals[:μ_2])
μ3_estimated = unsafeMean(marginals[:μ_3])

# Estimated precisions
Λ1_estimated = unsafeMean(marginals[:Λ_1])
Λ2_estimated = unsafeMean(marginals[:Λ_2])
Λ3_estimated = unsafeMean(marginals[:Λ_3])

# Invert to covariances
Σ1_estimated = inv(Λ1_estimated)
Σ2_estimated = inv(Λ2_estimated)
Σ3_estimated = inv(Λ3_estimated)

# Select dimensions to plot
dims_plot = [1, 2]
dim_limsx = [minimum(dataset["data"][:,dims_plot[1]])-1, maximum(dataset["data"][:,dims_plot[1]])+1]
dim_limsy = [minimum(dataset["data"][:,dims_plot[2]])-1, maximum(dataset["data"][:,dims_plot[2]])+1]

# Plot data and overlay estimated posterior probabilities
plot_clusters(dataset["data"][:, dims_plot], 
              μ=[μ1_estimated[dims_plot], μ2_estimated[dims_plot], μ3_estimated[dims_plot]], 
              Σ=[Σ1_estimated[dims_plot,dims_plot], Σ2_estimated[dims_plot,dims_plot], Σ3_estimated[dims_plot,dims_plot]], 
              x1=range(dim_limsx[1], step=0.01, stop=dim_limsx[2]), 
              x2=range(dim_limsy[1], step=0.01, stop=dim_limsy[2]), 

That doesn't look bad. The three Gaussians nicely cover all samples.

$\ast$ Try for yourself

Play around with the number of components. Can you get an equally good coverage with just 2 components? What if you had 4?

Problem: Alpine Railways

The Swiss Federal Railways company operates a series of mountain railways bringing hikers (in the summer) and skiers (in the winter) up the Alps. They are setting up a new fallback security system where they intend to track trains through cameras and remote sensors. They want you to design a system to keep track of the trains' positions.


A train going uphill updates its position according to: new position = old position + velocity x length of time-step + noise. The noise represents the train randomly slipping and sliding back down. We observe the train through a remote sensor, producing noisy observations of its position.

You receive a data set with past recordings. Your job is to set up an online filtering system, which can be deployed later on to process the incoming signal in real-time.

In [12]:
signal = load("../datasets/alpinerails_filtering.jld");

# Unpack data
states = signal["X"]
observations = signal["Y"]
transition = signal["A"]
emission = signal["C"]
process_noise = signal["Q"]
measurement_noise = signal["R"]
T = signal["T"]
Δt = signal["Δt"];

# Size
M = size(states,1)
N = size(observations,1)

# Visualize
plot(1:T, states[1,:], color="red", label="states", grid=false, xlabel="time (t)", ylabel="train position")
scatter!(1:T, observations[1,:], color="black", label="observations")

Model specification

We define noisy observations $y_k \in \mathbb{R}^{1}$ with latent states $x_k \in \mathbb{R}^{2}$. Observations $y_k$ are generated through a Gaussian likelihood centered on an emission matrix $C$ times the current state $x_k$ perturbed by measurement noise with covariance matrix $R$. State transitions follow a Gaussian distribution centered on a transition matrix $A$ times the previous state perturbed by process noise with covariance matrix $Q$. In equation form, these are:

$$\begin{align} p(x_k \mid x_{k-1}) =&\ \mathcal{N}(x_k \mid A x_{k-1}, Q)\\ p(y_k \mid x_k) =&\ \mathcal{N}(y_k \mid C x_k, R) \, . \end{align}$$

We have a prior for the previous state $x_{k-1} \sim \mathcal{N}(m_{k-1}, V_{k-1})$. In filtering problems, we feed the estimates of the current states as the parameters for the previous state in the next time-step.

In [13]:
# Initialize a graph
graph2 = FactorGraph()

# Define initial state prior
@RV x_kmin1 ~ GaussianMeanVariance(placeholder(:m_kmin1, dims=(M,)), 
                                   placeholder(:V_kmin1, dims=(M,M)))
# State transition
@RV x_k ~ GaussianMeanVariance(transition * x_kmin1, process_noise)
# Observation likelihood
@RV y_k ~ GaussianMeanVariance(dot(emission, x_k), measurement_noise)
# Tell FL that y is observed
placeholder(y_k, :y_k);

# Visualize subgraph
G 11961784858993670058 placeholder_V_kmin1 13525923077041460389 𝒩 gaussianmeanvariance_1 13525923077041460389--11961784858993670058 V_kmin1 1 out 3 v 17337918697909013489 placeholder_m_kmin1 13525923077041460389--17337918697909013489 m_kmin1 1 out 2 m 12542845186555412082 × multiplication_1 12542845186555412082--13525923077041460389 x_kmin1 1 out 2 in1 11282890469948126819 clamp_1 12542845186555412082--11282890469948126819 clamp_1 1 out 3 a 8256772125782865436 𝒩 gaussianmeanvariance_3 14488786852968061199 clamp_4 8256772125782865436--14488786852968061199 clamp_4 1 out 3 v 10313299326062430023 dot dotproduct_1 8256772125782865436--10313299326062430023 variable_2 1 out 2 m 4493200663958554577 placeholder_y_k 4493200663958554577--8256772125782865436 y_k 1 out 1 out 4290618869832858310 clamp_3 11365605713984537538 clamp_2 17953359042533603836 𝒩 gaussianmeanvariance_2 17953359042533603836--12542845186555412082 variable_1 1 out 2 m 17953359042533603836--11365605713984537538 clamp_2 1 out 3 v 10313299326062430023--4290618869832858310 clamp_3 1 out 3 in2 10313299326062430023--17953359042533603836 x_k 1 out 2 in1

We'll now construct the algorithm and infer results.

In [14]:
# Generate inference algorithm
algorithm = messagePassingAlgorithm(x_k)
source_code = algorithmSourceCode(algorithm)

For filtering, we use the same graph in each time-step.

In [15]:
# Initialize recognition distribution marginals
marginals = Dict(:x_k => vague(GaussianMeanVariance, M))

# Initialize message array
messages = Array{Message}(undef, 5)

# Keep track of estimates
m_x = 10*ones(M,T+1)
V_x = repeat(10*Matrix{Float64}(I,M,M), outer=(1,1,T+1))

@showprogress for k = 1:T
    # Initialize data
    data = Dict(:y_k => observations[k],
                :m_kmin1 => m_x[:,k],
                :V_kmin1 => V_x[:,:,k])
    # Update states
    step!(data, marginals, messages)
    # Store estimates
    m_x[:,k+1] = mean(marginals[:x_k])
    V_x[:,:,k+1] = cov(marginals[:x_k])
Progress: 100%|█████████████████████████████████████████| Time: 0:00:01

Let's check whether everything went ok. We'll visualize the state estimations.

In [16]:
# Visualization
plot(1:T, states[1,:], color="red", label="states", grid=false)
plot!(1:T, m_x[1,2:end], ribbon=[sqrt.(V_x[1,1,2:end]), sqrt.(V_x[1,1,2:end])], fillalpha=0.2, color="blue", label="inferred")
scatter!(1:T, observations[1,:], color="black", label="observations")

We're going to inspect some messages. Let's open up the algorithm and look up the marginal computation for the final $x_k$. It will be the multiplication of two messages, one consisting of the state transition prediction and the other consisting of the measurement likelihood.

In [17]:

function step!(data::Dict, marginals::Dict=Dict(), messages::Vector{Message}=Array{Message}(undef, 5))

messages[1] = ruleSPGaussianMeanVarianceOutNPP(nothing, Message(Multivariate, PointMass, m=data[:m_kmin1]), Message(MatrixVariate, PointMass, m=data[:V_kmin1]))
messages[2] = ruleSPMultiplicationOutNGP(nothing, messages[1], Message(MatrixVariate, PointMass, m=[1.0 0.1; 0.0 1.0]))
messages[3] = ruleSPGaussianMeanVarianceOutNGP(nothing, messages[2], Message(MatrixVariate, PointMass, m=[0.01 0.0; 0.0 0.1]))
messages[4] = ruleSPGaussianMeanVarianceMPNP(Message(Univariate, PointMass, m=data[:y_k]), nothing, Message(Univariate, PointMass, m=0.1))
messages[5] = ruleSPDotProductIn1GNP(messages[4], nothing, Message(Multivariate, PointMass, m=[1.0, 0.0]))

marginals[:x_k] = messages[3].dist * messages[5].dist

return marginals


end # block

Alright, we need messages 3 and 5. Let's visualize them along with the state marginal.

In [18]:
# Collect state transition prediction
m_statepred = convert(ProbabilityDistribution{Multivariate, GaussianMeanVariance}, messages[3].dist)

# Collect observation likelihood
m_likelihood = convert(ProbabilityDistribution{Multivariate, GaussianMeanVariance}, messages[5].dist)

# Collect corrected prediction
state_marginal = convert(ProbabilityDistribution{Multivariate, GaussianMeanVariance}, marginals[:x_k])

# # Extract x-coordinates
m_statepred_x = ProbabilityDistribution(Univariate, GaussianMeanVariance, m=m_statepred.params[:m][1], v=m_statepred.params[:v][1,1])
m_likelihood_x = ProbabilityDistribution(Univariate, GaussianMeanVariance, m=m_likelihood.params[:m][1], v=m_likelihood.params[:v][1,1])
state_marginal_x = ProbabilityDistribution(Univariate, GaussianMeanVariance, m=state_marginal.params[:m][1], v=state_marginal.params[:v][1,1])

# Plot of the prediction, noisy measurement, and corrected prediction for x-coordinate
plot_messages(m_statepred_x, m_likelihood_x, state_marginal_x)

As you can see, the state estimate is a combination of the state prediction, produced by the message from the state transition node, and the observation likelihood, produced by the message from the likelihood.

$\ast$ Try for yourself

Re-run the inference procedure and stop at an earlier time-step, for example $k$=2. How does the balance between the state prediction and the observation likelihood differ?