QuTiP development: compare solvers

J.R. Johansson and P.D. Nation

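This notebook first compares the expectation values produced by QuTiP's time-evolution solvers (master equation, wave function, exponential series and Monte Carlo) for a single dissipative qubit. It then benchmarks their run times on two coupled modes as the Hilbert-space dimension grows.

For more information about QuTiP see http://qutip.org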

In [ ]:
%pylab inline
In [ ]:
from qutip import *
In [ ]:
import time
In [ ]:
# Example: Rabi oscillations of a single dissipative qubit,
# H = epsilon/2 * sigma_z + delta/2 * sigma_x (there is no cavity mode here).
def qubit_integrate(epsilon, delta, g1, g2, solver, psi_init=None, times=None):
    """Evolve one qubit with the chosen QuTiP solver and return <sx>, <sy>, <sz>.

    Parameters
    ----------
    epsilon : float
        Coefficient of sigmaz()/2 in the Hamiltonian.
    delta : float
        Coefficient of sigmax()/2 in the Hamiltonian.
    g1 : float
        Rate of the sigmam() collapse operator; no operator is added if <= 0.
    g2 : float
        Rate of the sigmaz() collapse operator; no operator is added if <= 0.
    solver : str
        "me"    -- mesolve on the density matrix psi_init * psi_init.dag();
        "wf"    -- mesolve on the state vector with NO collapse operators,
                   i.e. purely unitary evolution (g1, g2 are ignored);
        "es"    -- essolve;
        "mc<N>" -- mcsolve with N > 0 trajectories (e.g. "mc1", "mc250").
    psi_init : Qobj, optional
        Initial ket. Defaults to the notebook-level ``psi0``.
    times : array_like, optional
        Output times. Defaults to the notebook-level ``tlist``.

    Returns
    -------
    tuple
        (<sigma_x>(t), <sigma_y>(t), <sigma_z>(t)) at ``times``.

    Raises
    ------
    ValueError
        If ``solver`` is not one of the names listed above.
    """
    # Fall back to the notebook globals so existing calls keep working, while
    # allowing the state and time grid to be passed explicitly.
    if psi_init is None:
        psi_init = psi0
    if times is None:
        times = tlist

    H = epsilon / 2.0 * sigmaz() + delta / 2.0 * sigmax()
    e_ops = [sigmax(), sigmay(), sigmaz()]

    # collapse operators: relaxation (sigma_-) and dephasing (sigma_z)
    c_op_list = []
    if g1 > 0.0:
        c_op_list.append(sqrt(g1) * sigmam())
    if g2 > 0.0:
        c_op_list.append(sqrt(g2) * sigmaz())

    # "mc<N>" selects Monte Carlo with N trajectories; 0 / malformed -> None.
    mc_ntraj = None
    if isinstance(solver, str) and solver.startswith("mc") and solver[2:].isdigit():
        mc_ntraj = int(solver[2:]) or None

    if solver == "me":
        expt_list = mesolve(H, psi_init * psi_init.dag(), times, c_op_list, e_ops).expect
    elif solver == "wf":
        expt_list = mesolve(H, psi_init, times, [], e_ops).expect
    elif solver == "es":
        result = essolve(H, psi_init, times, c_op_list, e_ops)
        # NOTE(review): depending on the QuTiP version, essolve returns either
        # the expectation values directly or a result object with ``.expect``.
        expt_list = getattr(result, "expect", result)
    elif mc_ntraj is not None:
        expt_list = mcsolve(H, psi_init, times, c_op_list, e_ops, mc_ntraj).expect
    else:
        raise ValueError("unknown solver: {!r}".format(solver))

    return expt_list[0], expt_list[1], expt_list[2]
In [ ]:
epsilon = 0.0 * 2 * pi   # qubit energy bias: coefficient of sigmaz()/2 in H
delta   = 1.0 * 2 * pi   # transverse coupling: coefficient of sigmax()/2 in H
g1 = 0.01                # rate of the sigmam() (relaxation) collapse operator
g2 = 0.01                # rate of the sigmaz() (dephasing) collapse operator

# initial state and output times
psi0 = basis(2,0)
tlist = linspace(0,5,200)
In [ ]:
# One panel per observable and one labelled line per solver, so the solvers
# can be told apart. Note that "wf" ignores the collapse operators, so it is
# expected to deviate from the dissipative solvers.
qubit_solvers = ("me", "wf", "es", "mc1", "mc250")
fig, axes = subplots(1, 3, figsize=(12, 4), sharey=True)

for solver in qubit_solvers:
    start_time = time.time()
    sx1, sy1, sz1 = qubit_integrate(epsilon, delta, g1, g2, solver)
    print(solver + ' time elapsed = ' + str(time.time() - start_time))

    for ax, values in zip(axes, (sx1, sy1, sz1)):
        ax.plot(tlist, real(values), label=solver)

for ax, component in zip(axes, "xyz"):
    ax.set_title(r'$\langle\sigma_%s\rangle$' % component)
    ax.set_xlabel('Time')
axes[0].set_ylabel('Expectation value')
axes[0].legend();

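Performance

Run time of each solver for two exchange-coupled modes a and b, as the truncation Na of mode a is increased.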

In [ ]:
def system_integrate(Na, Nb, wa, wb, wab, ga, gb, solver, times=None):
    """Evolve two exchange-coupled modes a and b and return <na>(t), <nb>(t).

    H = wa*na + wb*nb + wab*(a^dag b + a b^dag), with mode a truncated to Na
    Fock states and mode b truncated to Nb Fock states.

    Parameters
    ----------
    Na, Nb : int
        Truncation of modes a and b. Nb must be >= 2 because b starts in
        Fock state Nb - 2.
    wa, wb : float
        Frequencies of modes a and b.
    wab : float
        Exchange coupling strength.
    ga, gb : float
        Damping rates; each adds collapse operator sqrt(g)*a / sqrt(g)*b
        when > 0.
    solver : str
        "me" (mesolve on the density matrix), "wf" (mesolve on the ket with
        NO collapse operators, i.e. unitary), "es" (essolve), or "mc<N>"
        (mcsolve with N > 0 trajectories, e.g. "mc1", "mc250", "mc500").
    times : array_like, optional
        Output times. Defaults to the notebook-level ``tlist``.

    Returns
    -------
    tuple
        (<na>(t), <nb>(t)) at ``times``.

    Raises
    ------
    ValueError
        If Nb < 2 or ``solver`` is not recognised.
    """
    if times is None:
        times = tlist
    if Nb < 2:
        raise ValueError("Nb must be at least 2, got {!r}".format(Nb))

    # Hamiltonian and number operators
    a = tensor(destroy(Na), qeye(Nb))
    b = tensor(qeye(Na), destroy(Nb))
    na = a.dag() * a
    nb = b.dag() * b
    H = wa * na + wb * nb + wab * (a.dag() * b + a * b.dag())
    e_ops = [na, nb]

    # a starts in its highest Fock state |Na-1>, b in |Nb-2> (the vacuum when
    # Nb == 2). a therefore holds Na - Nb + 1 more excitations than b, which
    # is exactly one more only when Na == Nb.
    psi0 = tensor(basis(Na, Na - 1), basis(Nb, Nb - 2))

    # collapse operators
    c_op_list = []
    if ga > 0.0:
        c_op_list.append(sqrt(ga) * a)
    if gb > 0.0:
        c_op_list.append(sqrt(gb) * b)

    # "mc<N>" selects Monte Carlo with N trajectories; 0 / malformed -> None.
    mc_ntraj = None
    if isinstance(solver, str) and solver.startswith("mc") and solver[2:].isdigit():
        mc_ntraj = int(solver[2:]) or None

    if solver == "me":
        expt_list = mesolve(H, psi0 * psi0.dag(), times, c_op_list, e_ops).expect
    elif solver == "wf":
        expt_list = mesolve(H, psi0, times, [], e_ops).expect
    elif solver == "es":
        result = essolve(H, psi0, times, c_op_list, e_ops)
        # NOTE(review): depending on the QuTiP version, essolve returns either
        # the expectation values directly or a result object with ``.expect``.
        expt_list = getattr(result, "expect", result)
    elif mc_ntraj is not None:
        expt_list = mcsolve(H, psi0, times, c_op_list, e_ops, mc_ntraj).expect
    else:
        raise ValueError("unknown solver: {!r}".format(solver))

    return expt_list[0], expt_list[1]
In [ ]:
wa  = 1.0 * 2 * pi   # frequency of system a
wb  = 1.0 * 2 * pi   # frequency of system b
wab = 0.1 * 2 * pi   # coupling frequency
ga = 0.0             # dissipation rate of system a
gb = 0.0             # dissipation rate of system b
Na = 2               # number of states in system a
Nb = 2               # number of states in system b (must be >= 2)

tlist = linspace(0, 10, 200)

# Set True to plot <na>, <nb> for each run inside the benchmark loops.
show_dynamics = False
# Plot style per solver; keys must match the solver strings accepted by
# system_integrate ("me" is the master-equation solver).
style_map = {"me": "b", "es": "r.", "wf": "m*",
             "mc1": "g", "mc250": "c", "mc500": "y"}

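Unitary dynamics

With ga = gb = 0 there is no dissipation. The wave-function (wf), exponential-series (es) and single-trajectory Monte Carlo (mc1) solvers are compared.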

In [ ]:
# Solvers compared without dissipation, and the sweep over mode-a truncation.
solvers = ("wf", "es", "mc1")
Na_vec = arange(2, 60)                        # Na = 2, 3, ..., 59
times = zeros((Na_vec.size, len(solvers)))    # seconds, indexed [Na index, solver index]
In [ ]:
n_runs = 1

# Re-zero here so that re-running this cell does not add onto (and divide
# again) the timings left over from a previous execution.
times = zeros((len(Na_vec), len(solvers)))

for _run in range(n_runs):
    # n_a is used as the loop variable so the global Na is left untouched.
    for n_idx, n_a in enumerate(Na_vec):
        for s_idx, solver in enumerate(solvers):
            start_time = time.time()
            na, nb = system_integrate(n_a, Nb, wa, wb, wab, ga, gb, solver)
            times[n_idx, s_idx] += time.time() - start_time

            if show_dynamics:
                figure(3)
                plot(tlist, real(na), style_map[solver], tlist, real(nb), style_map[solver])

        if show_dynamics:
            show()

# average wall-clock time per (Na, solver) over the runs
times = times / n_runs
In [ ]:
# Use a fresh figure: figure number 1 may already hold the qubit plot.
figure()
for s_idx, solver in enumerate(solvers):
    # Na_vec * Nb is the total Hilbert-space dimension of the two modes.
    plot(Na_vec * Nb, times[:, s_idx], label=solver)

xlabel('Number of quantum states')
ylabel('Time to evolve system (seconds)')
title('Comparison of solver performance for unitary evolution')
legend();

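Dissipative dynamics

Mode a is now damped (ga = 0.05). The master-equation (me), Monte Carlo with 250 and 500 trajectories (mc250, mc500) and exponential-series (es) solvers are compared.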

In [ ]:
# Damp mode a only; sweep the mode-a truncation in steps of two.
ga, gb = 0.05, 0.0    # dissipation rates of systems a and b

solvers = ("me", "mc250", "mc500", "es")
Na_vec = arange(2, 35, 2)                     # Na = 2, 4, ..., 34
show_dynamics = False
times = zeros((Na_vec.size, len(solvers)))    # seconds, indexed [Na index, solver index]
In [ ]:
n_runs = 1

# Re-zero here so that re-running this cell does not add onto (and divide
# again) the timings left over from a previous execution.
times = zeros((len(Na_vec), len(solvers)))

for _run in range(n_runs):
    # n_a is used as the loop variable so the global Na is left untouched.
    for n_idx, n_a in enumerate(Na_vec):
        for s_idx, solver in enumerate(solvers):
            start_time = time.time()
            na, nb = system_integrate(n_a, Nb, wa, wb, wab, ga, gb, solver)
            times[n_idx, s_idx] += time.time() - start_time

            if show_dynamics:
                figure(3)
                plot(tlist, real(na), 'r', tlist, real(nb), 'b')

        if show_dynamics:
            show()

# average wall-clock time per (Na, solver) over the runs
times = times / n_runs
In [ ]:
# Use a fresh figure rather than a hard-coded figure number.
figure()
for s_idx, solver in enumerate(solvers):
    # Na_vec * Nb is the total Hilbert-space dimension of the two modes.
    plot(Na_vec * Nb, times[:, s_idx], label=solver)

xlabel('Number of quantum states')
ylabel('Time to evolve system (seconds)')
title('Comparison of solver performance for nonunitary evolution')
legend();

Versions

In [ ]:
# Show a table of the software versions used for this run, so the timings
# above can be put in context when the notebook is re-run elsewhere.
from qutip.ipynbtools import version_table

version_table()