using Pkg
Pkg.activate("../../../lessons/")
Pkg.instantiate();

using JLD
using Statistics
using LinearAlgebra
using Distributions
using RxInfer
using ColorSchemes
using LaTeXStrings
using Plots
default(label="", grid=false, linewidth=3, margin=10Plots.pt)

"""
    stack_moments(marginals)

Return `(means, variances)` for a vector of multivariate marginals, each as a matrix with
one column per element of `marginals` (i.e. one column per time step).
"""
function stack_moments(marginals)
    # reduce(hcat, ...) builds the same 2-D array as cat(xs...; dims=2) without splatting
    # a length-T collection into a varargs call.
    means = reduce(hcat, mean.(marginals))
    variances = reduce(hcat, var.(marginals))
    return means, variances
end

"""
    plot_inferred_states(time, states, m_z, v_z, observations; legend=:topleft)

Plot the first component of the true states, the posterior mean of the first inferred
state component with a ±1 standard-deviation ribbon, and the observations.
`m_z` / `v_z` hold posterior means / variances with one column per time step.
"""
function plot_inferred_states(time, states, m_z, v_z, observations; legend=:topleft)
    plot(time, states[1, :], color="red", label="states",
         xlabel="time (sec)", ylabel="building position")
    # v_z holds variances (position^2); a ribbon is a half-width in position units,
    # so the band is drawn with the standard deviation.
    plot!(time, m_z[1, :], color="blue", ribbon=sqrt.(v_z[1, :]), label="inferred")
    return scatter!(time, observations, color="black", alpha=0.2, label="observations",
                    legend=legend, size=(800, 300))
end

# Load data from file
data = load("../datasets/shaking_buildings.jld")

# Data
states = data["states"]
observations = data["observations"]

# Parameters
mass = data["m"]
friction = data["c"]
stiffness = data["k"]

# Measurement noise scale. Both models below use σ^2 as the observation variance,
# so σ is treated as a standard deviation.
σ = data["σ"]

# Time
Δt = data["Δt"]
T = length(observations)
time = range(1, step=Δt, length=T)

plot(time, states[1, :], color="red", label="states",
     xlabel="time (sec)", ylabel="building position")
scatter!(time, observations, color="black", label="observations",
         legend=:topleft, size=(800, 300))

# Transition matrix: one forward-Euler step of mass*x'' = -stiffness*x - friction*x'
# for the state z = [x; x'].
A = [1 Δt; (-stiffness / mass * Δt) (-friction / mass * Δt + 1)]

# Emission matrix: only the first state component is observed.
C = [1.0, 0.0]

# Set process noise covariance matrix
Q = diagm(ones(2))

@model function LGDS(prior_params, A, C, Q, σ; T=1)
    # State estimation in a linear Gaussian dynamical system.
    z = randomvar(T)
    y = datavar(Float64, T)

    # Prior state
    z_0 ~ MvNormalMeanCovariance(prior_params[:z0][1], prior_params[:z0][2])

    z_kmin1 = z_0
    for k in 1:T
        # State transition
        z[k] ~ MvNormalMeanCovariance(A * z_kmin1, Q)

        # Likelihood (σ^2 is the observation variance)
        y[k] ~ NormalMeanVariance(dot(C, z[k]), σ^2)

        # Update recursive aux
        z_kmin1 = z[k]
    end
    return y, z
end

# Initial state prior
prior_params = Dict(:z0 => (zeros(2), diageye(2)))

result = inference(
    model = LGDS(prior_params, A, C, Q, σ, T=T),
    data = (y = [observations[k] for k in 1:T],),
    free_energy = true,
)
posteriors = result.posteriors

m_z, v_z = stack_moments(posteriors[:z])
plot_inferred_states(time, states, m_z, v_z, observations; legend=:bottomright)

@model function LGDS_Q(prior_params, A, C, σ; T=1)
    # State estimation in a linear Gaussian dynamical system with unknown process noise.
    z = randomvar(T)
    y = datavar(Float64, T)

    # Prior state
    z_0 ~ MvNormalMeanCovariance(prior_params[:z0][1], prior_params[:z0][2])

    # Process noise covariance matrix
    Q ~ InverseWishart(prior_params[:Q][1], prior_params[:Q][2])

    z_kmin1 = z_0
    for k in 1:T
        # State transition
        z[k] ~ MvNormalMeanCovariance(A * z_kmin1, Q)

        # Likelihood (σ^2 is the observation variance)
        y[k] ~ NormalMeanVariance(dot(C, z[k]), σ^2)

        # Update recursive aux
        z_kmin1 = z[k]
    end
    return y, z, Q
end

# Define prior parameters
prior_params = Dict(:z0 => (zeros(2), diageye(2)), :Q => (10, diageye(2)))

# Iterations of variational inference
num_iters = 100

# Initialize variational marginal distributions and messages
inits = Dict(:z => MvNormalMeanCovariance(zeros(2), diageye(2)),
             :Q => InverseWishart(10, diageye(2)))

# Define variational distribution factorization (states factorize from Q)
constraints = @constraints begin
    q(z_0, z, Q) = q(z_0, z)q(Q)
end

# Variational inference procedure
results = inference(
    model = LGDS_Q(prior_params, A, C, σ, T=T),
    data = (y = [observations[k] for k in 1:T],),
    constraints = constraints,
    iterations = num_iters,
    options = (limit_stack_depth = 100,),
    initmarginals = inits,
    initmessages = inits,
    free_energy = true,
    showprogress = true,
)

plot(1:num_iters, results.free_energy, color="black", xscale=:log10,
     xlabel="Number of iterations", ylabel="Free Energy", size=(800, 300))

# Posterior over the states after the final variational iteration
m_z, v_z = stack_moments(last(results.posteriors[:z]))
plot_inferred_states(time, states, m_z, v_z, observations; legend=:topleft)

# Point estimate of Q from the final variational iteration.
# NOTE: despite the `_MAP` name, this is the posterior *mean*, not the mode.
Q_MAP = mean(last(results.posteriors[:Q]))

# True data
Q_true = data["Q"]

# Colorbar limits shared by both covariance heatmaps: (min, max) over both matrices
clims = extrema([vec(Q_MAP); vec(Q_true)])
# Side-by-side heatmaps of the estimated and true process-noise covariance matrices,
# drawn on a common color scale (`clims`) so their entries can be compared directly.
p401, p402 = map(((Q_MAP, "Estimated"), (Q_true, "True"))) do (Qmat, panel_title)
    heatmap(Qmat, axis=([], false), yflip=true, title=panel_title, clims=clims)
end
plot(p401, p402, layout=(1, 2), size=(900, 300))