%matplotlib inline
import sunode
import numpy as np
import matplotlib.pyplot as plt
print("sunode", sunode.__version__)
# output: sunode 0.4.0
# Scalar parameters and state variables of the Lotka-Volterra model;
# an empty tuple as the shape means "scalar" (no array dimensions).
params = dict.fromkeys(('α', 'β', 'γ', 'δ'), ())
states = dict.fromkeys(('hares', 'lynxes'), ())
def lotka_volterra(t, y, p):
    """Right-hand side of the Lotka-Volterra predator-prey equations.

    All inputs are dataclasses of sympy variables, or in the case
    of non-scalar variables numpy arrays of sympy variables.
    Returns the time derivative of each state, keyed by state name.
    """
    hares, lynxes = y.hares, y.lynxes
    d_hares = p.α * hares - p.β * lynxes * hares
    d_lynxes = p.δ * hares * lynxes - p.γ * lynxes
    return {
        'hares': d_hares,
        'lynxes': d_lynxes,
    }
# Build a sympy-backed ODE problem; sunode generates the symbolic
# derivatives with respect to the parameters listed in
# derivative_params (needed for sensitivity/adjoint computations).
problem = sunode.SympyProblem(
    params=params,
    states=states,
    rhs_sympy=lotka_volterra,
    derivative_params=[('α',), ('β',), ('γ',), ('δ',)]
)
# Adjoint solver: computes parameter gradients via a backward solve
# (the forward-sensitivity alternative is left commented out).
# solver = sunode.solver.Solver(problem, sens_mode="simultaneous")
solver = sunode.solver.AdjointSolver(problem)
# Initial condition as a structured (record) scalar with one field per state.
y0 = np.zeros((), dtype=problem.state_dtype)
y0['hares'] = 1
y0['lynxes'] = 0.1
# At which time points do we want to evaluate the solution
t = np.linspace(0, 10)
# "True" parameter values used to generate the reference trajectory.
α, β, γ, δ = 0.1, 0.2, 0.3, 0.4
θ = α, β, γ, δ
solver.set_params_dict({
    'α': α,
    'β': β,
    'γ': γ,
    'δ': δ,
})
#y, sens = solver.make_output_buffers(tvals)
# Allocate output buffers sized for the time grid: y (solution),
# grad (parameter gradient), lam (adjoint state).
y, grad, lam = solver.make_output_buffers(t)
# solver.solve(t0=0, tvals=t, y0=y0, y_out=y, sens0=np.zeros_like(sens[0]), sens_out=sens)
solver.solve_forward(t0=t[0], tvals=t, y0=y0, y_out=y)
plt.plot(t, y)
plt.xlabel('t')
plt.ylabel('y')
plt.xlim(t[0], t[-1])
plt.legend(['hares', 'lynx']);
# Backward (adjoint) solve; seeding grads with all ones yields the
# gradient of sum(y) with respect to the parameters.
solver.solve_backward(t0=t[-1], tend=t[0], tvals=t,
                      grads=np.ones((len(t), y.shape[-1])),
                      grad_out=grad, lamda_out=lam)
# Display the adjoint state and gradient (notebook cell output).
lam, grad
# output: (array([-82.13485284, 25.18562792]), array([ 465.25049911, -103.50680243, -25.88314792, 38.67576041]))
def predict(θ):
    """Solve the ODE forward with parameters θ = (α, β, γ, δ).

    Uses the module-level solver, time grid *t* and initial state *y0*;
    returns the trajectory evaluated at *t*.
    """
    α, β, γ, δ = θ
    solver.set_params_dict({'α': α, 'β': β, 'γ': γ, 'δ': δ})
    trajectory, _grad, _lam = solver.make_output_buffers(t)
    solver.solve_forward(t0=t[0], tvals=t, y0=y0, y_out=trajectory)
    return trajectory
# Initial parameter guess and the corresponding predicted trajectory.
θ_guess = 0.1, 0.1, 0.1, 0.1
yhat = predict(θ_guess)
def loss(y, θ):
    """Sum of squared residuals (unnormalized MSE) between the observed
    trajectory *y* and the model prediction at parameters θ."""
    resid = y - predict(θ)
    return (resid * resid).sum()
# Loss of the initial guess against the reference trajectory.
loss(y, θ_guess)
# output: 14.487093439626495
def gradient(θ):
    """Gradient of the squared-error loss w.r.t. θ via an adjoint solve.

    For L = sum((y - yhat)**2) the backward solve is seeded with
    dL/dyhat = -2 * (y - yhat); the solver accumulates dL/dθ.
    """
    α, β, γ, δ = θ
    solver.set_params_dict({'α': α, 'β': β, 'γ': γ, 'δ': δ})
    y_out, grad_out, lam_out = solver.make_output_buffers(t)
    solver.solve_forward(t0=t[0], tvals=t, y0=y0, y_out=y_out)
    residual = y - y_out
    solver.solve_backward(
        t0=t[-1], tend=t[0], tvals=t,
        grads=-2 * residual,
        grad_out=grad_out, lamda_out=lam_out)
    return grad_out
# Adjoint gradient of the loss at the initial guess.
θ_guess = 0.1, 0.1, 0.1, 0.1
grad = gradient(θ_guess)
grad
# output: array([433.82952474, -50.6250329 , 52.96938998, -75.86014686])
# Validate the adjoint gradient against central finite differences,
# one parameter at a time; here ∂L/∂α.
α, β, γ, δ = θ_guess
Δ = 1e-6
# yhat1 = predict((α+Δα/2, β, γ, δ))
# yhat2 = predict((α-Δα/2, β, γ, δ))
# (yhat1.sum() - yhat2.sum())/Δα, grad[0]
L1 = loss(y, (α+Δ/2, β, γ, δ))
L2 = loss(y, (α-Δ/2, β, γ, δ))
(L1-L2)/Δ, grad[0]
# output: (433.82952247483786, 433.8295247366786)
# Central finite-difference check for ∂L/∂β.
# yhat1 = predict((α, β+Δβ/2, γ, δ))
# yhat2 = predict((α, β-Δβ/2, γ, δ))
# (yhat1.sum() - yhat2.sum())/Δβ, grad[1]
L1 = loss(y, (α, β+Δ/2, γ, δ))
L2 = loss(y, (α, β-Δ/2, γ, δ))
(L1-L2)/Δ, grad[1]
# output: (-50.62503427133436, -50.62503290111857)
# Central finite-difference check for ∂L/∂γ.
L1 = loss(y, (α, β, γ+Δ/2, δ))
L2 = loss(y, (α, β, γ-Δ/2, δ))
(L1-L2)/Δ, grad[2]
# output: (52.96939328225392, 52.969389983518944)
# Central finite-difference check for ∂L/∂δ.
L1 = loss(y, (α, β, γ, δ+Δ/2))
L2 = loss(y, (α, β, γ, δ-Δ/2))
(L1-L2)/Δ, grad[3]
# output: (-75.86015293803428, -75.860146862688)
def average(p, c, β):
    """Exponentially weighted blend of previous value *p* and current
    value *c*: β·p + (1-β)·c (one EMA update step)."""
    weight_curr = 1 - β
    return β * p + weight_curr * c
class AdamOptimizer:
    """Adam optimizer (Kingma & Ba, 2015) as a stateful update generator.

    Parameters
    ----------
    α : float
        Base learning rate (step size).
    β1, β2 : float
        Exponential decay rates for the first- and second-moment
        estimates of the gradient.
    ϵ : float
        Small constant added to the denominator for numerical stability.
    """

    def __init__(self, α=0.001, β1=0.9, β2=0.999, ϵ=1e-8):
        self.α = α
        self.β1 = β1
        self.β2 = β2
        self.ϵ = ϵ
        # First/second moment estimates; left as None until the first
        # gradient arrives so they broadcast against its shape.
        self.m = None
        self.v = None
        self.t = 0  # step counter, used for bias correction

    def send(self, grad):
        """Consume one gradient and return the additive parameter update.

        Returns -αt * m / (sqrt(v) + ϵ), i.e. the step to ADD to the
        parameters. Raises ValueError if the update is not finite
        (e.g. the gradient contained NaN/inf).
        """
        if self.m is None:
            self.m = 0
        if self.v is None:
            self.v = 0
        self.t += 1
        # Bias-corrected step size: folds the 1/(1-β^t) corrections of
        # both moment estimates into the learning rate.
        αt = self.α * np.sqrt(1 - self.β2**self.t) / (1 - self.β1**self.t)
        # Exponential moving averages of the gradient and its square.
        self.m = self.β1 * self.m + (1 - self.β1) * grad
        self.v = self.β2 * self.v + (1 - self.β2) * (grad * grad)
        updates = -αt * self.m / (np.sqrt(self.v) + self.ϵ)
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would let NaNs propagate silently.
        if not np.isfinite(updates).all():
            raise ValueError("non-finite Adam update; check the gradient")
        return updates
# Fit θ_hat to the reference trajectory y by gradient descent with Adam.
θ_hat = np.array([0.1, 0.1, 0.1, 0.1])
η = np.array([0.0001, 0.0001, 0.0001, 0.0001])  # NOTE(review): appears unused below — confirm it can be dropped
imax = 10000
opt = AdamOptimizer()
# Reference line: loss at the true parameters θ (should be ~0).
print("L = {:.2g}, θ = (({:.4f}, {:.4f}, {:.4f}, {:.4f}))".format(loss(y, θ), *θ))
for i in range(imax):
    grad = gradient(θ_hat)
    Δθ = opt.send(grad)
    θ_hat += Δθ  # in-place update of the ndarray
    # Print progress 50 times over the run.
    if i % (imax // 50) == 0:
        print("L = {:.2g}, θ_hat = ({:.4f}, {:.4f}, {:.4f}, {:.4f}) ".format(loss(y, θ_hat), *θ_hat))
# output: L = 0, θ = ((0.1000, 0.2000, 0.3000, 0.4000)) L = 14, θ_hat = (0.0990, 0.1010, 0.0990, 0.1010) L = 0.097, θ_hat = (0.0857, 0.1328, 0.0120, 0.1865) L = 0.054, θ_hat = (0.0925, 0.1573, 0.0141, 0.1865) L = 0.034, θ_hat = (0.0977, 0.1765, 0.0164, 0.1869) L = 0.028, θ_hat = (0.1008, 0.1884, 0.0187, 0.1878) L = 0.026, θ_hat = (0.1024, 0.1943, 0.0211, 0.1891) L = 0.026, θ_hat = (0.1030, 0.1967, 0.0237, 0.1909) L = 0.025, θ_hat = (0.1032, 0.1975, 0.0265, 0.1930) L = 0.025, θ_hat = (0.1032, 0.1978, 0.0297, 0.1953) L = 0.024, θ_hat = (0.1032, 0.1978, 0.0332, 0.1980) L = 0.023, θ_hat = (0.1031, 0.1978, 0.0371, 0.2010) L = 0.022, θ_hat = (0.1030, 0.1978, 0.0415, 0.2042) L = 0.022, θ_hat = (0.1030, 0.1978, 0.0462, 0.2078) L = 0.021, θ_hat = (0.1029, 0.1978, 0.0514, 0.2118) L = 0.02, θ_hat = (0.1028, 0.1978, 0.0570, 0.2160) L = 0.019, θ_hat = (0.1027, 0.1978, 0.0632, 0.2207) L = 0.018, θ_hat = (0.1027, 0.1978, 0.0698, 0.2257) L = 0.017, θ_hat = (0.1026, 0.1978, 0.0769, 0.2310) L = 0.015, θ_hat = (0.1025, 0.1978, 0.0845, 0.2368) L = 0.014, θ_hat = (0.1024, 0.1979, 0.0926, 0.2429) L = 0.013, θ_hat = (0.1022, 0.1979, 0.1012, 0.2495) L = 0.012, θ_hat = (0.1021, 0.1979, 0.1103, 0.2564) L = 0.011, θ_hat = (0.1020, 0.1980, 0.1199, 0.2636) L = 0.0096, θ_hat = (0.1019, 0.1980, 0.1299, 0.2711) L = 0.0084, θ_hat = (0.1018, 0.1981, 0.1402, 0.2790) L = 0.0073, θ_hat = (0.1016, 0.1982, 0.1510, 0.2871) L = 0.0062, θ_hat = (0.1015, 0.1983, 0.1619, 0.2954) L = 0.0053, θ_hat = (0.1014, 0.1984, 0.1731, 0.3039) L = 0.0044, θ_hat = (0.1012, 0.1985, 0.1843, 0.3124) L = 0.0035, θ_hat = (0.1011, 0.1986, 0.1956, 0.3209) L = 0.0028, θ_hat = (0.1010, 0.1987, 0.2067, 0.3293) L = 0.0022, θ_hat = (0.1009, 0.1988, 0.2176, 0.3376) L = 0.0017, θ_hat = (0.1008, 0.1989, 0.2281, 0.3455) L = 0.0012, θ_hat = (0.1006, 0.1991, 0.2382, 0.3532) L = 0.00088, θ_hat = (0.1005, 0.1992, 0.2476, 0.3603) L = 0.00061, θ_hat = (0.1005, 0.1993, 0.2564, 0.3669) L = 0.00041, θ_hat = (0.1004, 0.1994, 0.2643, 0.3730) L = 0.00026,
# output: θ_hat = (0.1003, 0.1995, 0.2715, 0.3784) L = 0.00016, θ_hat = (0.1002, 0.1996, 0.2777, 0.3831) L = 9.2e-05, θ_hat = (0.1002, 0.1997, 0.2830, 0.3871) L = 5.1e-05, θ_hat = (0.1001, 0.1998, 0.2874, 0.3904) L = 2.6e-05, θ_hat = (0.1001, 0.1999, 0.2909, 0.3931) L = 1.3e-05, θ_hat = (0.1001, 0.1999, 0.2937, 0.3952) L = 5.6e-06, θ_hat = (0.1000, 0.1999, 0.2958, 0.3968) L = 2.3e-06, θ_hat = (0.1000, 0.2000, 0.2973, 0.3980) L = 8.4e-07, θ_hat = (0.1000, 0.2000, 0.2984, 0.3988) L = 2.8e-07, θ_hat = (0.1000, 0.2000, 0.2991, 0.3993) L = 8.3e-08, θ_hat = (0.1000, 0.2000, 0.2995, 0.3996) L = 2.2e-08, θ_hat = (0.1000, 0.2000, 0.2997, 0.3998) L = 7.1e-09, θ_hat = (0.1000, 0.2000, 0.2999, 0.3999)