%pylab inline from scipy import sparse # to define sparse matrices from scipy.sparse.linalg import spsolve # to solve sparse systems def solve_BVP_direct(f,x,u_left,u_right): """ On input: f should be the function defining the ODE u''(x) = f(x), x should be an array of equally spaced points where the solution is to be computed, u_left should be the boundary condition at the left edge x[0], u_right should be the boundary condition at the right edge x[-1], Returns: u, vector of approximate values to solution at the points x. """ n = len(x) - 2 # number of interior points print "Solving tridiagonal system with n =%4i" % n dx = x[1]-x[0] # Form the matrix: d1 = ones(n) d0 = -2 * ones(n) A = sparse.spdiags([d1,d0,d1], [-1,0,1],n,n,format='csc') # Form the right-hand side fx = f(x) rhs = -dx**2 * fx[1:-1] rhs[0] = rhs[0] - u_left rhs[-1] = rhs[-1] - u_right # Solve the system for the interior points: u_int = spsolve(A,rhs) # vector of length n-2 # Append the known boundary values: u = hstack((u_left, u_int, u_right)) return u def check_solution(x,u,u_true): # u should be a vector of approximate solution values at the points x # u_true should be a function defining the true solution. plot(x,u,'ro') plot(x,u_true(x),'b') error = u - u_true(x) error_max = abs(abs(error)).max() print "Maximum error is %10.3e" % error_max f = lambda x: 100. * exp(x) u_left = 20. u_right = 60. u_true = lambda x: (100.*exp(1.) - 60.)*x + 120 - 100.*exp(x) assert abs(u_left - u_true(0.)) < 1e-14, "u_true doesn't match boundary condition u_left" assert abs(u_right - u_true(1.)) < 1e-14, "u_true doesn't match boundary condition u_right" n = 8 # number of interior points x = linspace(0, 1, n+2) u = solve_BVP_direct(f,x,u_left,u_right) check_solution(x,u,u_true) nvals = array([10,20,40,80,160]) errors = zeros(nvals.shape) for j,n in enumerate(nvals): x = linspace(0, 1, n+2) u = solve_BVP_direct(f,x,u_left,u_right) errors[j] = abs(u - u_true(x)).max() # maximum abs error over interval print "\n n error" for j,n in enumerate(nvals): print "%4i %15.9f" % (n,errors[j]) loglog(nvals,errors,'o-') title("Log-log plot of error vs. n") def solve_BVP_split(f,x,u_left,u_right): n2 = len(x) nmid = int(floor(n2/2.)) xhalf1 = x[:nmid+1] xhalf2 = x[nmid:] u_mid = u_true(x[nmid]) # Assumes we know true solution!! print "Using u_mid = ",u_mid uhalf1 = solve_BVP_direct(f,xhalf1,u_left,u_mid) uhalf2 = solve_BVP_direct(f,xhalf2,u_mid,u_right) u = hstack((uhalf1, uhalf2[1:])) return u x = linspace(0, 1, 20) u = solve_BVP_split(f,x,u_left,u_right) check_solution(x,u,u_true) def solve_BVP_split_mismatch(f,x,u_left,u_right): n2 = len(x) dx = x[1]-x[0] nmid = int(floor(n2/2.)) xhalf1 = x[:nmid+1] xhalf2 = x[nmid:] x_mid = x[nmid] u_mid_guesses = linspace(30,80,6) mismatch = zeros(u_mid_guesses.shape) for j,u_mid in enumerate(u_mid_guesses): uhalf1 = solve_BVP_direct(f,xhalf1,u_left,u_mid) uhalf2 = solve_BVP_direct(f,xhalf2,u_mid,u_right) u = hstack((uhalf1, uhalf2[1:])) plot(x,u) mismatch[j] = (uhalf1[-2] - 2.*u_mid + uhalf2[1]) + dx**2 * f(x_mid) print "With u_mid = %g, the mismatch is %g" % (u_mid, mismatch[j]) figure() plot(u_mid_guesses, mismatch,'o-') x = linspace(0, 1, 20) solve_BVP_split_mismatch(f,x,u_left,u_right) def solve_BVP_split(f,x,u_left,u_right): n2 = len(x) dx = x[1]-x[0] nmid = int(floor(n2/2.)) xhalf1 = x[:nmid+1] xhalf2 = x[nmid:] x_mid = x[nmid] # solve the sub-problems twice with different values of u_mid # Note that any two distinct values can be used. u_mid = 0. uhalf1 = solve_BVP_direct(f,xhalf1,u_left,u_mid) uhalf2 = solve_BVP_direct(f,xhalf2,u_mid,u_right) G0 = (uhalf1[-2] - 2.*u_mid + uhalf2[1]) + dx**2 * f(x_mid) v0 = hstack((uhalf1, uhalf2[1:])) u_mid = 1. uhalf1 = solve_BVP_direct(f,xhalf1,u_left,u_mid) uhalf2 = solve_BVP_direct(f,xhalf2,u_mid,u_right) G1 = (uhalf1[-2] - 2.*u_mid + uhalf2[1]) + dx**2 * f(x_mid) v1 = hstack((uhalf1, uhalf2[1:])) z = G1 / (G1 - G0) u = z*v0 + (1-z)*v1 return u x = linspace(0, 1, 20) u = solve_BVP_split(f,x,u_left,u_right) print x.shape, u.shape check_solution(x,u,u_true) def solve_BVP_split_recursive(f,x,u_left,u_right): n2 = len(x) print "Entering solve_BVP_split_recursive with n2 =%4i, for x from %5.3f to %5.3f" \ % (n2, x[0], x[-1]) if n2 < 20: # Stop recursing if the problem is small enough: u = solve_BVP_direct(f,x,u_left,u_right) return u else: # recursive nmid = int(floor(n2/2.)) x_mid = x[nmid] xhalf1 = x[:nmid+1] xhalf2 = x[nmid:] dx = x[1]-x[0] u_mid = 0. uhalf1 = solve_BVP_split_recursive(f,xhalf1,u_left,u_mid) uhalf2 = solve_BVP_split_recursive(f,xhalf2,u_mid,u_right) G0 = (uhalf1[-2] - 2.*u_mid + uhalf2[1]) + dx**2 * f(x_mid) v0 = hstack((uhalf1, uhalf2[1:])) u_mid = 1. uhalf1 = solve_BVP_split_recursive(f,xhalf1,u_left,u_mid) uhalf2 = solve_BVP_split_recursive(f,xhalf2,u_mid,u_right) G1 = (uhalf1[-2] - 2.*u_mid + uhalf2[1]) + dx**2 * f(x_mid) v1 = hstack((uhalf1, uhalf2[1:])) z = G1 / (G1 - G0) u = z*v0 + (1-z)*v1 return u x = linspace(0, 1, 21) u = solve_BVP_split_recursive(f,x,u_left,u_right) check_solution(x,u,u_true)