%matplotlib inline
import matplotlib.pyplot as plt
import numpy.random as npr
import jax.numpy as jnp
from jax import jit, grad
from jax import vmap
from scipy.optimize import curve_fit
from jax.experimental.ode import odeint
from jax.scipy.optimize import minimize
from scipy.optimize import minimize as _minimize_
from scipy.integrate import odeint as _odeint_
# Toy 1D dataset.
x = jnp.reshape(jnp.linspace(-2.0, 2.0, 10), (10, 1))
y = x**3 + 0.1 * x
plt.scatter(x, y)
plt.xlabel('x')
plt.ylabel('y');
def mlp(params, x):
# A multi-layer perceptron, i.e. a fully-connected neural network.
for w, b in params:
y = jnp.dot(x, w) + b # Linear transform
x = jnp.tanh(y) # Nonlinearity
return y
param_scale = 1.0
def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
return [
(scale * rng.randn(m, n), scale * rng.randn(n))
for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])
]
def resnet(params, x, depth):
for i in range(depth):
y = mlp(params, x) + x
return y
resnet_depth = 3
def resnet_loss(params, x, y):
yhat = resnet(params, x, resnet_depth)
return jnp.mean((yhat - y)**2)
def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
return [(scale * rng.randn(m, n), scale * rng.randn(n))
for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])]
# A simple gradient-descent optimizer.
@jit
def resnet_update(params, x, y):
grads = grad(resnet_loss)(params, x, y)
return [
(w - η * dw, b - η * db)
for (w, b), (dw, db) in zip(params, grads)
]
# Hyperparameters.
layer_sizes = [1, 20, 1]
η = 0.01
train_iters = 1000
# Initialize and train.
resnet_params = init_random_params(param_scale, layer_sizes)
for i in range(train_iters):
resnet_params = resnet_update(resnet_params, x, y)
def nn_dynamics(y, t, params):
y_and_t = jnp.hstack([y, jnp.array(t)])
return mlp(params, y_and_t)
def odenet(params, y0, t=jnp.array([0.0, 1.0])):
_, y1 = odeint(nn_dynamics, y0, t, params)
return y1
odenet = vmap(odenet, in_axes=(None, 0))
def odenet_loss(params, x, y):
yhat = odenet(params, x)
return jnp.mean((yhat - y)**2)
@jit
def odenet_update(params, x, y):
grads = grad(odenet_loss)(params, x, y)
return [
(w - η * dw, b - η * db)
for (w, b), (dw, db) in zip(params, grads)
]
# We need to change the input dimension to 2, to allow time-dependent dynamics.
odenet_layer_sizes = [2, 20, 1]
# Initialize
odenet_params = init_random_params(param_scale, odenet_layer_sizes)
# train
for i in range(train_iters):
odenet_params = odenet_update(odenet_params, x, y)
# train
for i in range(train_iters):
odenet_params = odenet_update(odenet_params, x, y)
print(
'resent = ', resnet_loss(resnet_params, x, y),
'odenet = ', odenet_loss(odenet_params, x, y)
)
# Plot resulting model.
fig = plt.figure(figsize=(6, 4), dpi=150)
ax = fig.gca()
ax.scatter(x, y, lw=0.5, label='data')
fine_x = jnp.reshape(jnp.linspace(-3.0, 3.0, 100), (100, 1))
ax.plot(fine_x, resnet(resnet_params, fine_x, resnet_depth), lw=0.5, label='resent')
ax.plot(fine_x, odenet(odenet_params, fine_x), lw=0.5, label='odenet')
ax.set_xlabel('x')
ax.set_ylabel('y')
plt.legend();
resent = 0.10412481 odenet = 0.0023549802
def logistic(t, y0, r, K):
return K / (1 - (1 - K/y0) * jnp.exp(-r * t))
def logistic_loss(params, t, y):
yhat = logistic(t, *params)
return jnp.mean((yhat - y)**2)
K = 1.2
r = 0.5
y0 = 0.2
reps = 5
t = jnp.linspace(0, 24, 24*2)
t = jnp.array([t]*reps).squeeze()
y = logistic(t, y0, r, K) + npr.normal(0, 0.02, size=t.shape)
y = y.squeeze()
plt.plot(t.T, y.T, '-k', alpha=1/reps)
plt.xlabel('t')
plt.ylabel('y')
plt.xlim(0, 24);
params_guess = jnp.array([y.min(), 1, y.max()])
params, pcov = curve_fit(logistic, t.ravel(), y.ravel(), params_guess)
print(y0, r, K, '|', logistic_loss((y0, r, K), t, y))
print(*params, '|', logistic_loss(params, t, y))
plt.plot(t, y, '.k')
plt.plot(t.T, logistic(t, *params).T, '-r');
0.2 0.5 1.2 | 0.00040311948 0.20128386981070637 0.4991952381803526 1.2042213722496522 | 0.0003895881
def logistic_ode(y, t, r, K):
return r * y * (1 - y/K)
def logistic_odeint(t, y0, r, K):
return odeint(logistic_ode, y0, t, r, K)
# plt.plot(t[0], logistic_odeint(t[0], y0, r, K));
# logistic_odeint = vmap(logistic_odeint, in_axes=(0, None, None, None))
# plt.plot(t.T, logistic_odeint(t, y0, r, K).T);
def logistic_odeint_loss(params, t, y):
yhat = logistic_odeint(t, *params).reshape((1, -1))
return jnp.mean((yhat - y)**2)
print(logistic_odeint_loss([y0, r, K], t[0], y))
0.0004031221
res = minimize(logistic_odeint_loss, params_guess, args=(t[0], y), method='BFGS')
print(res.success)
False
if not res.success:
res = minimize(logistic_odeint_loss, res.x, args=(t[0], y), method='BFGS')
print(res.success)
True
print(y0, r, K, '|', logistic_odeint_loss([y0, r, K], t[0], y))
print(res.x, '|', logistic_odeint_loss(res.x, t[0], y))
plt.plot(t, y, '.k')
plt.plot(t[0], logistic_odeint(t[0], *res.x), '-r');
0.2 0.5 1.2 | 0.0004031221 [0.20128083 0.49919215 1.2042261 ] | 0.00038958865
def mlp_ode(y, t, params):
y_and_t = jnp.hstack([y, t])
return mlp(params, y_and_t)
def odenet_odeint(t, params):
return odeint(mlp_ode, y0, t, params)
# odenet_odeint = vmap(odenet_odeint, in_axes=(0, None))
def odenet_loss(params, t, y):
yhat = odenet_odeint(t, params).reshape((1, -1))
return jnp.mean((yhat - y)**2)
@jit
def odenet_update(params, t, y):
grads = grad(odenet_loss)(params, t, y)
return [
(w - η * dw, b - η * db)
for (w, b), (dw, db) in zip(params, grads)
]
# We need to change the input dimension to 2, to allow time-dependent dynamics.
odenet_layer_sizes = [2, 6, 6, 1] # t,y(t) -> ... -> y(t+dt)
# Initialize
odenet_params = init_random_params(param_scale, odenet_layer_sizes)
print(sum(__.size for _ in odenet_params for __ in _))
67
# train
η = 0.01
train_iters = 1000000
for i in range(train_iters):
odenet_params = odenet_update(odenet_params, t[0], y)
if i % 5000 == 0:
print(i, odenet_loss(odenet_params, t[0], y))
0 131.81291 5000 0.0044554514 10000 0.0023229162 15000 0.0017886197 20000 0.0014564791 25000 0.001261144 30000 0.0011124271 35000 0.0009967838 40000 0.0009056021 45000 0.0008329434 50000 0.0007745297 55000 0.00072706793 60000 0.000688042 65000 0.00065569556 70000 0.0006284404 75000 0.0006052384 80000 0.0005898033 85000 0.0005779843 90000 0.000567442 95000 0.00055784016 100000 0.0005489524 105000 0.0005406805 110000 0.00053310697 115000 0.00052631256 120000 0.0005199418 125000 0.0005138974 130000 0.0005082313 135000 0.00050291576 140000 0.000497857 145000 0.00049304706 150000 0.00048859866 155000 0.0004845434 160000 0.00048070843 165000 0.000477074 170000 0.0004735961 175000 0.000470258 180000 0.0004670442 185000 0.00046400787 190000 0.00046108747 195000 0.00045828344 200000 0.0004555908 205000 0.00045304032 210000 0.00045063236 215000 0.00044832114 220000 0.00044608646 225000 0.00044407006 230000 0.00044219842 235000 0.00044043234 240000 0.00043872616 245000 0.00043710327 250000 0.0004355227 255000 0.00043401113 260000 0.00043253863 265000 0.00043111105 270000 0.00042972778 275000 0.00042839904 280000 0.00042711833 285000 0.00042588738 290000 0.00042471252 295000 0.0004235837 300000 0.00042249475 305000 0.00042144422 310000 0.00042043335 315000 0.00041946143 320000 0.00041850653 325000 0.0004175952 330000 0.00041671604 335000 0.00041587185 340000 0.00041507528 345000 0.0004143063 350000 0.0004135817 355000 0.00041288196 360000 0.00041219706 365000 0.00041152272 370000 0.00041087248 375000 0.00041026316 380000 0.0004096793 385000 0.00040911845 390000 0.0004085675 395000 0.0004080341 400000 0.00040753031 405000 0.00040706305 410000 0.00040661887 415000 0.00040618234 420000 0.00040575973 425000 0.00040534142 430000 0.00040495928 435000 0.0004045956 440000 0.00040423733 445000 0.00040389664 450000 0.00040355587 455000 0.0004032241 460000 0.0004028958 465000 0.00040257222 470000 0.0004022519 475000 0.00040194346 480000 0.0004016443 485000 0.00040136385 490000 0.00040109124 495000 0.0004008208 500000 0.00040055448 505000 0.00040029723 510000 0.00040003966 515000 0.00039979097 520000 0.00039954597 525000 0.00039930097 530000 0.00039906503 535000 0.00039883744 540000 0.0003986405 545000 0.00039845533 550000 0.00039827317 555000 0.00039809532 560000 0.00039791677 565000 0.00039774604 570000 0.00039757046 575000 0.00039739884 580000 0.00039722773 585000 0.0003970602 590000 0.0003968896 595000 0.00039672482 600000 0.00039656038 605000 0.00039639513 610000 0.00039624234 615000 0.0003960917 620000 0.00039594478 625000 0.00039579306 630000 0.00039564815 635000 0.00039550275 640000 0.00039535842 645000 0.00039521712 650000 0.00039507277 655000 0.00039493854 660000 0.00039480077 665000 0.00039466325 670000 0.00039452498 675000 0.00039438976 680000 0.00039425638 685000 0.00039412332 690000 0.0003939876 695000 0.00039385815 700000 0.00039372992 705000 0.0003936001 710000 0.00039347247 715000 0.00039334624 720000 0.00039321993 725000 0.0003931031 730000 0.00039298754 735000 0.00039287453 740000 0.00039276117 745000 0.0003926476 750000 0.0003925359 755000 0.00039242642 760000 0.0003923155 765000 0.00039220497 770000 0.00039209434 775000 0.00039198526 780000 0.00039187877 785000 0.00039176882 790000 0.00039166294 795000 0.00039155514 800000 0.00039144678 805000 0.00039134227 810000 0.00039123622 815000 0.00039113208 820000 0.00039102524 825000 0.0003909265 830000 0.00039082926 835000 0.0003907328 840000 0.00039063595 845000 0.0003905417 850000 0.00039044925 855000 0.000390357 860000 0.00039026348 865000 0.00039017072 870000 0.00039007896 875000 0.0003899867 880000 0.0003898956 885000 0.00038980338 890000 0.00038971406 895000 0.0003896256 900000 0.00038953562 905000 0.00038944578 910000 0.00038935905 915000 0.00038926926 920000 0.00038917988 925000 0.00038909324 930000 0.00038900712 935000 0.00038891908 940000 0.00038882918 945000 0.0003887425 950000 0.00038865968 955000 0.00038857406 960000 0.00038849347 965000 0.00038841218 970000 0.00038832726 975000 0.00038824376 980000 0.00038816285 985000 0.00038808427 990000 0.00038799897 995000 0.00038792053
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
tt = jnp.linspace(t[0,0], t[0,-1], 500)
ax = axes[0]
ax.plot(t, y, '.k', markersize=2)
yhat = odenet_odeint(tt, odenet_params).squeeze()
ax.plot(tt, yhat, '-r', label='ODENet');\
ax.plot(tt, logistic(tt, y0, r, K), '-g', label='Logistic');
ax.set_xlabel('$t$')
ax.set_ylabel('$y$')
ax.legend()
ax = axes[1]
dyhatdt = [mlp_ode(yhat[i], tt[i], odenet_params) for i in range(len(tt))]
dydt = [logistic_ode(yhat[i], tt[i], r, K) for i in range(len(tt))]
ax.plot(yhat, dyhatdt, 'o', label='ODENet')
ax.plot(yhat, dydt, 'o', label='Logistic')
ax.set_ylim(0, max(dyhatdt))
ax.set_xlabel('$y$')
ax.set_ylabel('$dy/dt$')
ax.legend();