import numpy as np
from astropy import units as u
from poliastro.bodies import Earth
from poliastro.iod import lambert
from poliastro.math import dot
from poliastro.stumpff import c2, c3
# Gravitational parameter of the attractor plus the boundary conditions of
# the transfer: initial / final position vectors and the time of flight.
k = Earth.k
r0 = u.Quantity([15945.34, 0.0, 0.0], u.km)
r = u.Quantity([12214.83399, 10249.46731, 0.0], u.km)
tof = u.Quantity(76.0, u.min)
# The solver is called with bare magnitudes, so every quantity is first
# expressed in km, km^3 / s^2 and s.
lambert(
    k.to(u.km ** 3 / u.s ** 2).value,
    r0.value,
    r.value,
    tof.to(u.s).value,
)
(array([ 2.05891159, 2.91596286, 0. ]), array([-3.45156464, 0.91031284, 0. ]))
Prepare inputs to measure performance:
# Convert the inputs to bare km / km^3 s^-2 / s magnitudes once, up front, so
# the benchmarks below time only the solver calls and not the unit handling.
k_ = k.to(u.km ** 3 / u.s ** 2).value
r0_ = r0.to(u.km).value
r_ = r.to(u.km).value
tof_ = tof.to(u.s).value
# IPython line magic: benchmark poliastro's `lambert` on the bare values.
%timeit lambert(k_, r0_, r_, tof_)
The slowest run took 7.36 times longer than the fastest. This could mean that an intermediate result is being cached 100000 loops, best of 3: 14.9 µs per loop
Let's implement the simplest one: the Bate-Mueller-White universal-variable approach, solved with the bisection method suggested by Vallado.
def lambert_py(k, r0, r, tof, short=True, numiter=35, rtol=1e-6):
    """Solve Lambert's problem with the universal-variable bisection method.

    Pure-Python implementation of the Bate-Mueller-White universal-variable
    formulation, iterating on ``psi`` by bisection as suggested by Vallado.

    Parameters
    ----------
    k : float
        Gravitational parameter, in units consistent with ``r0``, ``r`` and
        ``tof`` (the notebook uses km^3 / s^2, km and s).
    r0, r : numpy.ndarray
        Initial and final position vectors.
    tof : float
        Time of flight between ``r0`` and ``r``.
    short : bool, optional
        If True use the short-way transfer (``t_m = +1``), otherwise the
        long way (``t_m = -1``).
    numiter : int, optional
        Maximum number of iterations, both for the bisection and for the
        ``y < 0`` readjustment.
    rtol : float, optional
        Relative tolerance on the computed time of flight.

    Returns
    -------
    v0, v : numpy.ndarray
        Velocity vectors at ``r0`` and at ``r``.

    Raises
    ------
    RuntimeError
        If the geometry is degenerate (``A == 0``), if ``psi`` cannot be
        readjusted to make ``y`` positive, or if the bisection does not
        converge within ``numiter`` iterations.
    """
    t_m = 1 if short else -1

    norm_r0 = np.dot(r0, r0) ** .5
    norm_r = np.dot(r, r) ** .5
    cos_dnu = np.dot(r0, r) / (norm_r0 * norm_r)

    A = t_m * (norm_r * norm_r0 * (1 + cos_dnu)) ** .5
    if A == 0.0:
        raise RuntimeError("Cannot compute orbit")

    psi = 0.0
    psi_low = -4 * np.pi
    # c2(psi) vanishes at psi = 4 * pi**2, the single-revolution limit. The
    # previous bound of 4 * pi made larger (but valid) psi values unreachable.
    psi_up = 4 * np.pi ** 2

    count = 0
    while count < numiter:
        y = norm_r0 + norm_r + A * (psi * c3(psi) - 1) / c2(psi) ** .5
        if A > 0.0 and y < 0.0:
            # Readjust psi (and psi_low) until y > 0.0, translated from
            # Vallado. Without this, sqrt(y) below is NaN and the bisection
            # cannot recover. Bounded so a non-progressing update cannot
            # loop forever.
            n_adjust = 0
            while y < 0.0:
                if n_adjust >= numiter:
                    raise RuntimeError("Could not readjust psi so that y > 0")
                n_adjust += 1
                psi_low = psi
                psi = (0.8 * (1.0 / c3(psi))
                       * (1.0 - (norm_r0 + norm_r) * np.sqrt(c2(psi)) / A))
                y = norm_r0 + norm_r + A * (psi * c3(psi) - 1) / c2(psi) ** .5

        xi = np.sqrt(y / c2(psi))
        tof_new = (xi ** 3 * c3(psi) + A * np.sqrt(y)) / np.sqrt(k)

        # Convergence check on the relative time-of-flight error
        if np.abs((tof_new - tof) / tof) < rtol:
            break
        count += 1

        # Bisection: a computed time of flight that is too short moves the
        # lower bound up, otherwise the upper bound comes down.
        if tof_new <= tof:
            psi_low = psi
        else:
            psi_up = psi
        psi = (psi_up + psi_low) / 2
    else:
        raise RuntimeError("Convergence could not be achieved under "
                           "%d iterations" % numiter)

    # f and g coefficients, then the velocities at both ends
    f = 1 - y / norm_r0
    g = A * np.sqrt(y / k)
    gdot = 1 - y / norm_r

    v0 = (r - f * r0) / g
    v = (gdot * r - r0) / g
    return v0, v
# Run the pure-Python solver on the same bare inputs used for `lambert`.
lambert_py(k_, r0_, r_, tof_)
(array([ 2.05890855, 2.91596498, 0. ]), array([-3.45156368, 0.91031642, 0. ]))
# IPython line magic: benchmark the pure-Python implementation.
%timeit lambert_py(k_, r0_, r_, tof_)
1000 loops, best of 3: 222 µs per loop
import numba
def lambert_numba(k, r0, r, tof, short=True, numiter=35, rtol=1e-8):
    """Solve Lambert's problem using the numba-compiled ``_lambert`` core.

    Parameters
    ----------
    k : float
        Gravitational parameter, in units consistent with the other inputs.
    r0, r : numpy.ndarray
        Initial and final position vectors.
    tof : float
        Time of flight between ``r0`` and ``r``.
    short : bool, optional
        Short-way (True) or long-way (False) transfer.
    numiter : int, optional
        Maximum number of iterations passed to the core solver.
    rtol : float, optional
        Relative tolerance on the time of flight.

    Returns
    -------
    v0, v : numpy.ndarray
        Velocity vectors at ``r0`` and at ``r``.

    Raises
    ------
    RuntimeError
        Propagated unchanged from ``_lambert``.
    """
    # _lambert returns (f, g, fdot, gdot). fdot is not needed to recover the
    # velocities. Errors propagate directly; catching only to re-raise added
    # nothing.
    f, g, _, gdot = _lambert(k, r0, r, tof, short, numiter, rtol)
    v0 = (r - f * r0) / g
    v = (gdot * r - r0) / g
    return v0, v
@numba.njit
def _lambert(k, r0, r, tof, short, numiter, rtol):
    """Compiled universal-variable Lambert solver (bisection on ``psi``).

    Parameters
    ----------
    k : float
        Gravitational parameter, in units consistent with the other inputs.
    r0, r : numpy.ndarray
        Initial and final position vectors.
    tof : float
        Time of flight between ``r0`` and ``r``.
    short : bool
        Short-way (``t_m = +1``) or long-way (``t_m = -1``) transfer.
    numiter : int
        Maximum number of iterations, both for the bisection and for the
        ``y < 0`` readjustment.
    rtol : float
        Relative tolerance on the computed time of flight.

    Returns
    -------
    f, g, fdot, gdot : float
        Coefficients relating positions and velocities at both ends.

    Raises
    ------
    RuntimeError
        On degenerate geometry (``A == 0``), failed ``y < 0`` readjustment,
        or lack of convergence within ``numiter`` iterations.
    """
    if short:
        t_m = +1
    else:
        t_m = -1

    norm_r0 = dot(r0, r0) ** .5
    norm_r = dot(r, r) ** .5
    cos_dnu = dot(r0, r) / (norm_r0 * norm_r)

    A = t_m * (norm_r * norm_r0 * (1 + cos_dnu)) ** .5
    if A == 0.0:
        # Only constant exception arguments are allowed in nopython mode
        raise RuntimeError("Cannot compute orbit")

    psi = 0.0
    psi_low = -4 * np.pi
    # c2(psi) vanishes at psi = 4 * pi**2, the single-revolution limit. The
    # previous bound of 4 * pi made larger (but valid) psi values unreachable.
    psi_up = 4 * np.pi ** 2

    count = 0
    while count < numiter:
        y = norm_r0 + norm_r + A * (psi * c3(psi) - 1) / c2(psi) ** .5
        if A > 0.0 and y < 0.0:
            # Readjust psi (and psi_low) until y > 0.0, translated directly
            # from Vallado. Bounded so a non-progressing update cannot loop
            # forever.
            n_adjust = 0
            while y < 0.0:
                if n_adjust >= numiter:
                    raise RuntimeError("Could not readjust psi so that y > 0")
                n_adjust += 1
                psi_low = psi
                psi = (0.8 * (1.0 / c3(psi))
                       * (1.0 - (norm_r0 + norm_r) * np.sqrt(c2(psi)) / A))
                y = norm_r0 + norm_r + A * (psi * c3(psi) - 1) / c2(psi) ** .5

        xi = np.sqrt(y / c2(psi))
        tof_new = (xi ** 3 * c3(psi) + A * np.sqrt(y)) / np.sqrt(k)

        # Convergence check on the relative time-of-flight error
        if np.abs((tof_new - tof) / tof) < rtol:
            break
        count += 1

        # Bisection: a computed time of flight that is too short moves the
        # lower bound up, otherwise the upper bound comes down.
        if tof_new <= tof:
            psi_low = psi
        else:
            psi_up = psi
        psi = (psi_up + psi_low) / 2
    else:
        raise RuntimeError("Convergence could not be achieved")

    f = 1 - y / norm_r0
    g = A * np.sqrt(y / k)
    gdot = 1 - y / norm_r
    # fdot follows from the identity f * gdot - fdot * g = 1
    return f, g, (f * gdot - 1) / g, gdot
# Run the numba-backed solver on the same bare inputs (the first call also
# triggers JIT compilation of `_lambert`).
lambert_numba(k_, r0_, r_, tof_)
(array([ 2.05890993, 2.91596401, 0. ]), array([-3.45156412, 0.91031479, 0. ]))
# IPython line magic: benchmark the numba-backed implementation.
%timeit lambert_numba(k_, r0_, r_, tof_)
The slowest run took 6.45 times longer than the fastest. This could mean that an intermediate result is being cached 10000 loops, best of 3: 21.6 µs per loop
import poliastro
import poliastro.iod
# Replace poliastro's Lambert solver with the numba implementation, then run
# poliastro's own test suite against it.
# NOTE(review): this only affects code that looks up `poliastro.iod.lambert`
# at call time. Any module that already executed
# `from poliastro.iod import lambert` (including this notebook's own import)
# keeps the original function. Confirm the tests really exercise the patch.
poliastro.iod.lambert = lambert_numba
poliastro.test()
============================= test session starts ============================== platform linux -- Python 3.4.3 -- py-1.4.26 -- pytest-2.6.4 collected 39 items ../poliastro/tests/test_bodies.py .... ../poliastro/tests/test_iod.py ... ../poliastro/tests/test_maneuver.py ...... ../poliastro/tests/test_math.py .... ../poliastro/tests/test_plotting.py .. ../poliastro/tests/test_twobody.py ................. ../poliastro/tests/test_util.py ... ========================== 39 passed in 0.74 seconds ===========================