%matplotlib inline from make_partition_plots import make_butterfly make_butterfly() from IPython import display display.Audio(filename='resources/bluewhale.wav') import numpy as np def load_whale(): # http://www.mathworks.com/help/matlab/math/fast-fourier-transform-fft.html data = np.load('resources/bluewhale.npz') X = data['X'] sampling_rate = int(data['rate']) blue_whale_begin = 24500 - 1 blue_whale_end = 31000 - 1 blue_whale_call = X[blue_whale_begin:blue_whale_end + 1] size_call = len(blue_whale_call) # Pad signal with zeros up to the next power of 2. N = int(2**np.ceil(np.log2(size_call))) blue_whale_call = np.hstack([ blue_whale_call, np.zeros(N - len(blue_whale_call)), ]) return sampling_rate, blue_whale_call sampling_rate, blue_whale_call = load_whale() N = len(blue_whale_call) time_base = np.arange(N) * 10.0 / sampling_rate from matplotlib import pyplot as plt plt.plot(time_base, blue_whale_call) plt.title('Blue Whale B Call') plt.xlim((time_base[0], time_base[-1])) plt.xlabel('Time (seconds)') plt.ylabel('Amplitude') plt.show() import time start = time.time() dft_whale_call = np.fft.fft(blue_whale_call, n=N) fft_duration = time.time() - start message = r'\text{Duration with FFT optimized: } %2.10f' % (fft_duration,) display.display(display.Math(message)) re_sampled_time = np.arange(len(dft_whale_call)) * sampling_rate / (10.0 * N) amplitude = np.abs(dft_whale_call) / N plt.plot(re_sampled_time[:N/2], amplitude[:N/2]) plt.xlabel('Frequency (Hz)') plt.ylabel('Amplitude') plt.title('Component Frequencies') plt.show() from make_partition_plots import get_random_intervals from make_partition_plots import naive_interaction S_VALUES, T_VALUES = get_random_intervals(16) naive_interaction(S_VALUES, T_VALUES, N=16) def compute_f_hat(f, t, s, kernel_func): f_hat = np.zeros(f.shape, dtype=np.complex128) for k in xrange(len(f)): # Vectorized update. f_hat[k] = np.sum(kernel_func(t[k], s) * f) return f_hat # We expect a "slowdown factor" of N / log N expected_quadratic = fft_duration * N / np.log2(N) message = r'\text{Expected Quadratic Run-time: } %2.10f' % (expected_quadratic,) display.display(display.Math(message)) N_vals = np.arange(N, dtype=np.float64) t = 2 * np.pi * N_vals s = N_vals / N def dft_kernel(t, s): return np.exp(- 1.0j * t * s) start = time.time() f_hat = compute_f_hat(blue_whale_call, t, s, dft_kernel) naive_duration = time.time() - start message = r'\text{Actual Naive Duration: } %2.10f' % (naive_duration,) display.display(display.Math(message)) error = np.linalg.norm(f_hat - dft_whale_call, ord=2) sig, exponent = str(error).split('e') expr = r'\|e\|_2 = %s \cdot 10^{%s}' % (sig, exponent) display.display(display.Math(expr)) naive_interaction(S_VALUES, T_VALUES, N=16) from make_partition_plots import binned_interaction binned_interaction(S_VALUES, T_VALUES, L=2) from make_partition_plots import make_1D_centers make_1D_centers(L=2, s_values=S_VALUES) make_butterfly(L=4, level=0) make_butterfly(L=4, level=1) make_butterfly(L=4, level=2) make_butterfly(L=4, level=3) make_butterfly(L=4, level=4) import sympy t, s, tau, sigma = sympy.symbols('t s tau sigma') RHS = tau * sigma + tau * (s - sigma) + sigma * (t - tau) + (s - sigma) * (t - tau) display.display(display.Math(sympy.latex(RHS))) display.display(display.Math(sympy.latex(RHS.simplify()))) make_butterfly(L=4, level=0) make_butterfly(L=4, level=4) def get_bins_and_deltas(vals, min_val, bin_width, num_bins): bin_indices = np.floor((vals - min_val) / bin_width) # max(vals) falls in the last bin bin_indices = np.minimum(bin_indices, num_bins - 1) bin_centers = min_val + (0.5 + bin_indices) * bin_width return bin_indices.astype(int), vals - bin_centers def create_initial_data(s_values, min_s, max_s, tau, actual_data, num_bins, M): bin_width = (max_s - min_s) / float(num_bins) bin_indices, s_deltas = get_bins_and_deltas(s_values, min_s, bin_width, num_bins) sum_parts = np.zeros((len(s_values), M), dtype=np.complex128) sum_parts[:, 0] = dft_kernel(tau, s_values) * actual_data for alpha in xrange(1, M): sum_parts[:, alpha] = (sum_parts[:, alpha - 1] * s_deltas * (-1.0j) / alpha) result = [] curr_sigma = min_s + 0.5 * bin_width for bin_index in xrange(num_bins): sum_across_bin = np.sum( sum_parts[np.where(bin_indices == bin_index)[0], :], axis=0) result.append((tau, curr_sigma, sum_across_bin.reshape(M, 1))) curr_sigma += bin_width return result def compute_t_by_bins(t_values, min_t, max_t, coeff_vals, num_bins, M): bin_width = (max_t - min_t) / float(num_bins) bin_indices, t_deltas = get_bins_and_deltas(t_values, min_t, bin_width, num_bins) # We assume all sigma values are the same and do not check. sigma = coeff_vals[0][1] exponents = np.arange(M, dtype=np.float64) t_delta_powers = t_deltas[:, np.newaxis]**exponents[np.newaxis, :] # Use hstack since the vectors are 2D, then transpose. coefficients = np.hstack([triple[2] for triple in coeff_vals]).T t_matched_coeffs = coefficients[bin_indices, :] fhat_t = np.sum(t_delta_powers * t_matched_coeffs, axis=1, dtype=np.complex128) fhat_t *= dft_kernel(t_deltas, sigma) return fhat_t def A1(M, delta, eye_func=np.eye): result = eye_func(M) result[0, 1] = delta for col in xrange(2, M): prev_val = result[0, col - 1] # Pascal's triangle does not apply at ends. The zero term # already set on the diagonal. result[0, col] = delta * prev_val for row in xrange(1, col): curr_val = result[row, col - 1] result[row, col] = prev_val + delta * curr_val prev_val = curr_val return result A1_val = A1(5, -1.0) print A1_val A1_symb = A1(5, sympy.Symbol('Delta'), eye_func=sympy.eye) display.display(display.Math(sympy.latex(A1_symb))) from make_partition_plots import show_tau_refine show_tau_refine() def mult_diag(val, A, M, diag): first_row = max(0, -diag) last_row = min(M, M - diag) for row in xrange(first_row, last_row): A[row, row + diag] *= val def A_update(A_val, scale_multiplier=0.5, upper_diags=True): M = A_val.shape[0] # If not `upper_diags` we want the lower diagonal. diag_mult = 1 if upper_diags else -1 # We don't need to update the main diagonal since exponent=0. scale_factor = 1 for diagonal in xrange(1, M): scale_factor *= scale_multiplier mult_diag(scale_factor, A_val, M, diagonal * diag_mult) A1_negative = A1_val.copy() A_update(A1_negative, scale_multiplier=-1.0) print A1_negative all_ones = sympy.ones(4, 4) A_update(all_ones, scale_multiplier=sympy.Symbol('x'), upper_diags=True) A_update(all_ones, scale_multiplier=sympy.Symbol('y'), upper_diags=False) display.display(display.Math(sympy.latex(all_ones))) from make_partition_plots import show_sigma_coarsen show_sigma_coarsen() def complex_eye(M): return np.eye(M, dtype=np.complex128) def set_diag(val, A, M, diag): first_row = max(0, -diag) last_row = min(M, M - diag) for row in xrange(first_row, last_row): A[row, row + diag] = val def A2(M, delta, eye_func=complex_eye, imag=1.0j): result = eye_func(M) new_delta = -imag * delta diagonal_value = 1 for sub_diagonal in xrange(1, M): diagonal_value = diagonal_value * new_delta / sub_diagonal set_diag(diagonal_value, result, M, -sub_diagonal) return result A2_val = A2(5, 1.0j) print 'Frobenius norm of imaginary part:', np.linalg.norm(np.imag(A2_val), ord='fro') print '-' * 40 A2_val = np.real(A2_val) print A2_val A2_symb = A2(5, sympy.Symbol('Delta'), eye_func=sympy.eye, imag=sympy.I) display.display(display.Math(sympy.latex(A2_symb))) A_update(A2_val, scale_multiplier=2.0, upper_diags=False) print A2_val A_update(A2_val, scale_multiplier=-1.0, upper_diags=False) print A2_val a, b, c, d, e, f = sympy.symbols('a b c d e f') M_left = sympy.Matrix([[a, b, 0, 0], [0, 0, a, b]]) M_right = sympy.Matrix([[c, 0], [0, d], [e, 0], [0, f]]) message = sympy.latex(M_left) + sympy.latex(M_right) display.display(display.Math(message)) M_prod = M_left * M_right message = sympy.latex(M_prod) display.display(display.Math(message)) from itertools import izip def increase_tau_refinement(values, num_tau, num_sigma, update_func): result = [] for tau_index in xrange(num_tau): begin = tau_index * num_sigma end = begin + num_sigma # We need to hold the right values until all the left values # have been added. right_updated = [] # Assumes num_sigma is even. left_vals, right_vals = values[begin:end:2], values[1 + begin:end:2] for left_val, right_val in izip(left_vals, right_vals): new_left, new_right = update_func(left_val, right_val) result.append(new_left) right_updated.append(new_right) result.extend(right_updated) return result def custom_update(left_val, right_val): tau, sigma1 = left_val tau2, sigma2 = right_val new_sigma = sigma1 >> 1 if tau != tau2 or new_sigma != (sigma2 >> 1): raise ValueError(left_val, right_val) new_left = (2 * tau, new_sigma) new_right = (2 * tau + 1, new_sigma) return new_left, new_right num_tau = 2 num_sigma = 4 index_pairs1 = [(tau, sigma) for tau in xrange(num_tau) for sigma in xrange(num_sigma)] index_pairs2 = increase_tau_refinement( index_pairs1, num_tau, num_sigma, custom_update) index_pairs3 = increase_tau_refinement( index_pairs2, 2 * num_tau, num_sigma / 2, custom_update) row_template = (r'\left(\tau_{%d}, \sigma_{%d}\right) & \rightarrow & ' r'\left(\tau_{%d}, \sigma_{%d}\right) & \rightarrow & ' r'\left(\tau_{%d}, \sigma_{%d}\right) \\') latex_rows = [ r'\begin{array}{ccccc}', r'\ell = 1 & & \ell = 2 & & \ell = 3 \\', r'\hline', ] for (i1, j1), (i2, j2), (i3, j3) in zip(index_pairs1, index_pairs2, index_pairs3): latex_rows.append(row_template % (i1, j1, i2, j2, i3, j3)) latex_rows.append(r'\end{array}') display.display(display.Math('\n'.join(latex_rows))) def make_update_func(A1_minus, A1_plus, A2_minus, A2_plus, delta_T): top_left, top_right = A2_minus.dot(A1_minus), A2_plus.dot(A1_minus) bottom_left, bottom_right = A2_minus.dot(A1_plus), A2_plus.dot(A1_plus) def update_func(left_val, right_val): tau, sigma, alpha_vals_left = left_val # We expect the pair to share tau, and don't check to avoid slowdown. tau_same, sigma_prime, alpha_vals_right = right_val sigma_minus = 0.5 * (sigma + sigma_prime) tau_left = tau - delta_T tau_right = tau + delta_T new_alpha_vals_left = ( dft_kernel(-delta_T, sigma) * top_left.dot(alpha_vals_left) + dft_kernel(-delta_T, sigma_prime) * top_right.dot(alpha_vals_right) ) new_alpha_vals_right = ( dft_kernel(delta_T, sigma) * bottom_left.dot(alpha_vals_left) + (dft_kernel(delta_T, sigma_prime) * bottom_right.dot(alpha_vals_right)) ) new_left_val = (tau_left, sigma_minus, new_alpha_vals_left) new_right_val = (tau_right, sigma_minus, new_alpha_vals_right) return new_left_val, new_right_val return update_func def solve(s, t, data, L, M): min_t, max_t = np.min(t), np.max(t) min_s, max_s = np.min(s), np.max(s) num_bins = 2**L tau = 0.5 * (min_t + max_t) coeff_vals = create_initial_data(s, min_s, max_s, tau, data, num_bins, M) # ell = 0 delta_S = (max_s - min_s) / (2.0 * num_bins) delta_T = (max_t - min_t) * 0.25 num_tau, num_sigma = 1, num_bins A1_minus, A1_plus = A1(M, -delta_T), A1(M, delta_T) A2_minus, A2_plus = A2(M, -delta_S), A2(M, delta_S) for ell in xrange(1, L + 1): update_func = make_update_func(A1_minus, A1_plus, A2_minus, A2_plus, delta_T) coeff_vals = increase_tau_refinement(coeff_vals, num_tau, num_sigma, update_func) num_tau, num_sigma, delta_T = update_loop_vals( num_tau, num_sigma, delta_T, A1_minus, A1_plus, A2_minus, A2_plus) return compute_t_by_bins(t, min_t, max_t, coeff_vals, num_bins, M) def update_loop_vals(num_tau, num_sigma, delta_T, A1_minus, A1_plus, A2_minus, A2_plus): num_tau = num_tau * 2 num_sigma = num_sigma / 2 delta_T *= 0.5 A_update(A1_plus, scale_multiplier=0.5, upper_diags=True) A_update(A1_minus, scale_multiplier=0.5, upper_diags=True) A_update(A2_plus, scale_multiplier=2.0, upper_diags=False) A_update(A2_minus, scale_multiplier=2.0, upper_diags=False) return num_tau, num_sigma, delta_T N = blue_whale_call.size N_vals = np.arange(N, dtype=np.float64) t = 2 * np.pi * N_vals s = N_vals / N L = 10 # relax by 3 levels M = 50 start = time.time() soln = solve(s, t, blue_whale_call, L, M) duration = time.time() - start message = r'\text{Butterfly Duration: } %2.10f' % (duration,) display.display(display.Math(message)) message = r'\text{Recall, Actual Naive Duration: } %2.10f' % (naive_duration,) display.display(display.Math(message)) error = np.linalg.norm(soln - dft_whale_call, ord=2) sig, exponent = str(error).split('e') expr = r'\|e\|_2 = %s \cdot 10^{%s}' % (sig, exponent) display.display(display.Math(expr)) # NOTE: This assumes s, t, blue_whale_call, dft_whale_call are globals at definition time def get_time(L, M, s=s, t=t, data=blue_whale_call, true_soln=dft_whale_call): start = time.time() soln = solve(t, s, data, L, M=M) duration = time.time() - start error = np.linalg.norm(soln - true_soln, ord=2) return error, duration L10_M_choices = (10, 18, 26, 34, 42, 50, 58) L10_time_pairs = [get_time(L=10, M=M) for M in L10_M_choices] L11_M_choices = (5, 10, 15, 20, 25, 30, 35) L11_time_pairs = [get_time(L=11, M=M) for M in L11_M_choices] L12_M_choices = (4, 8, 12, 16, 20, 24, 28) L12_time_pairs = [get_time(L=12, M=M) for M in L12_M_choices] L13_M_choices = (3, 6, 9, 12, 15, 18, 21) L13_time_pairs = [get_time(L=12, M=M) for M in L13_M_choices] from make_partition_plots import make_time_plots make_time_plots(L10_M_choices, L11_M_choices, L12_M_choices, L13_M_choices, L10_time_pairs, L11_time_pairs, L12_time_pairs, L13_time_pairs) # Custom styling, borrowed from https://github.com/ellisonbg/talk-2013-scipy import os MAKE_SLIDES = os.getenv('MAKE_BUTTERFLY_SLIDES') == 'True' # Make plots centered. H/T to http://stackoverflow.com/a/27168595/1068170 STYLE = """\ """ SLIDES_STYLE = """ """ if MAKE_SLIDES: STYLE = STYLE + '\n' + SLIDES_STYLE display.display(display.HTML(STYLE)) %%file make_partition_plots.py from matplotlib import pyplot as plt import numpy as np N_exp = 4 # 8 boxes N = 2**N_exp # import seaborn # seaborn.husl_palette(n_colors=7) SEABORN_COLORS = [ [0.9677975592919913, 0.44127456009157356, 0.5358103155058701], [0.7757319041862729, 0.5784925270759935, 0.19475566538551875], [0.5105309046900421, 0.6614299289084904, 0.1930849118538962], [0.20433460114757862, 0.6863857739476534, 0.5407103379425205], [0.21662978923073606, 0.6676586160122123, 0.7318695594345369], [0.5049017849530067, 0.5909119231215284, 0.9584657252128558], [0.9587050080494409, 0.3662259565791742, 0.9231469575614251], ] def remove_axis_frame(ax): ax.set_frame_on(False) ax.yaxis.set_label_position('right') ax.set_xticklabels([]) # H/T: http://stackoverflow.com/a/20416681/1068170 for tic in ax.xaxis.get_major_ticks(): tic.tick1On = tic.tick2On = False ax.set_yticklabels([]) for tic in ax.yaxis.get_major_ticks(): tic.tick1On = tic.tick2On = False def make_1D_centers(L=N_exp, s_values=None): rows, cols = 1, 1 fig, ax = plt.subplots(rows, cols) N = 2**L all_sigma = np.linspace(0, 1, N + 1) ax.plot(all_sigma, np.zeros(all_sigma.shape), color='black', marker='|', markersize=20) if s_values is not None: ax.plot(s_values, np.zeros(s_values.shape), color='blue', marker='o', linestyle='None') center_pts = all_sigma[:-1] + 0.5 / N ax.plot(center_pts, np.zeros(center_pts.shape), color='red', marker='o', linestyle='None') for i, x_val in enumerate(center_pts): label = r'$\sigma_{%d}$' % (i,) ax.annotate(label, xy=(x_val, -0.05), xytext=(x_val - 0.01, -0.04), fontsize=24) remove_axis_frame(ax) ax.axis('scaled') ax.set_xlim(-0.05, 1.05) ax.set_ylim(-0.05, 0.05) width, height = fig.get_size_inches() fig.set_size_inches(2 * width, 2 * height, forward=True) plt.show() def add_labeled_M(ax, time_only, log_errors, M_choices, loc='upper right'): for x, y, M, c in zip(time_only, log_errors, M_choices, SEABORN_COLORS): ax.plot([x], [y], marker='o', label=r'$M = %d$' % (M,), color=c) ax.legend(loc=loc) def make_time_plots(L10_M_choices, L11_M_choices, L12_M_choices, L13_M_choices, L10_time_pairs, L11_time_pairs, L12_time_pairs, L13_time_pairs): rows = cols = 2 fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(rows, cols) L10_log_errors = [np.log10(pair[0]) for pair in L10_time_pairs] L11_log_errors = [np.log10(pair[0]) for pair in L11_time_pairs] L12_log_errors = [np.log10(pair[0]) for pair in L12_time_pairs] L13_log_errors = [np.log10(pair[0]) for pair in L13_time_pairs] L10_time_only = [pair[1] for pair in L10_time_pairs] L11_time_only = [pair[1] for pair in L11_time_pairs] L12_time_only = [pair[1] for pair in L12_time_pairs] L13_time_only = [pair[1] for pair in L13_time_pairs] true_min = np.min(L10_log_errors + L11_log_errors + L12_log_errors + L13_log_errors) true_min -= 1.0 true_max = np.max(L10_log_errors + L11_log_errors + L12_log_errors + L13_log_errors) true_max += 1.3 true_left = np.min(L10_time_only + L11_time_only + L12_time_only + L13_time_only) true_left -= 0.1 true_right = np.max(L10_time_only + L11_time_only + L12_time_only + L13_time_only) true_right += 0.1 ax1.plot(L10_time_only, L10_log_errors, color='black') ax1.set_ylabel(r'$\log_{10} ||e||_2$', rotation=0, fontsize=20, labelpad=40) ax1.set_title(r'$L = 10$', fontsize=20) ax1.set_xlim((true_left, true_right)) ax1.set_ylim((true_min, true_max)) add_labeled_M(ax1, L10_time_only, L10_log_errors, L10_M_choices) ax2.plot(L11_time_only, L11_log_errors, color='black') ax2.set_title(r'$L = 11$', fontsize=20) ax2.set_xlim((true_left, true_right)) ax2.set_ylim((true_min, true_max)) add_labeled_M(ax2, L11_time_only, L11_log_errors, L11_M_choices) ax3.plot(L12_time_only, L12_log_errors, color='black') ax3.set_ylabel(r'$\log_{10} ||e||_2$', rotation=0, fontsize=20, labelpad=40) ax3.set_xlabel('runtime', fontsize=20) ax3.set_title(r'$L = 12$', fontsize=20) ax3.set_xlim((true_left, true_right)) ax3.set_ylim((true_min, true_max)) add_labeled_M(ax3, L12_time_only, L12_log_errors, L12_M_choices, loc='upper left') ax4.plot(L13_time_only, L13_log_errors, marker='o') ax4.set_title(r'$L = 13$', fontsize=20) ax4.set_xlabel('runtime', fontsize=20) ax4.set_xlim((true_left, true_right)) ax4.set_ylim((true_min, true_max)) add_labeled_M(ax4, L13_time_only, L13_log_errors, L13_M_choices, loc='upper left') width, height = fig.get_size_inches() fig.set_size_inches(2 * width, 2 * height, forward=True) plt.show() def make_butterfly(L=4, level=None): rows, cols = 1, 1 fig, ax = plt.subplots(rows, cols) if level is not None: s_index, t_index = L - level, level ax.set_title(r'$L = %d, \, \ell = %d$' % (L, L - s_index), fontsize=20) heights = np.linspace(0, 1, L + 1) x_vals = np.array([0.0, 1.0]) # Left and right half. for i in xrange(L + 1): j = L - i begin = np.array([0.0, heights[i]]) end = np.array([1.0, heights[j]]) quarter_dir = 0.25 * (end - begin) dx, dy = quarter_dir ax.arrow(dx, heights[i] + dy, 2 * dx, 2 * dy, length_includes_head=True, color='blue') ax.arrow(0.0, heights[i], dx, dy, length_includes_head=False, color='black') ax.plot([3 * dx, 1.0], [heights[i] + 3 * dy, heights[j]], color='black', linewidth=1.5) if level is not None: s_boundaries = np.linspace(-0.55, -0.05, 2**s_index + 1) ax.plot(s_boundaries, heights[s_index] * np.ones(2**s_index + 1), color='black', marker='|', markersize=20) s_width = s_boundaries[1] - s_boundaries[0] ax.plot(s_width * 0.5 + s_boundaries[:-1], heights[s_index] * np.ones(2**s_index), color='red', marker='o', linestyle='None') t_boundaries = np.linspace(1.05, 1.55, 2**t_index + 1) ax.plot(t_boundaries, heights[t_index] * np.ones(2**t_index + 1), color='black', marker='|', markersize=20) t_width = t_boundaries[1] - t_boundaries[0] ax.plot(t_width * 0.5 + t_boundaries[:-1], heights[t_index] * np.ones(2**t_index), color='red', marker='o', linestyle='None') remove_axis_frame(ax) ax.axis('scaled') if level is None: ax.set_xlim(0, 1) ax.set_ylim(0, 1) else: ax.set_xlim(-0.55, 1.55) ax.set_ylim(-0.1, 1.1) width, height = fig.get_size_inches() fig.set_size_inches(2 * width, 2 * height, forward=True) plt.show() def get_random_intervals(N): interval_width = 1.0 / N interval_starts = np.linspace(0, 1 - interval_width, N) L = interval_starts + interval_width * np.random.random(N) R = interval_starts + interval_width * np.random.random(N) return L, R def naive_interaction(L, R, N=16, return_fig=False): rows, cols = 1, 1 fig, ax = plt.subplots(rows, cols) N = N or len(L) ax.set_title(r'$N = %d$' % (N,), fontsize=20) for l_val in L: for r_val in R: ax.plot([0, 1], [l_val, r_val], color='blue') remove_axis_frame(ax) ax.annotate('$S$', xy=(0.0, 0.5), xytext=(-0.05, 0.5), fontsize=20) ax.annotate('$T$', xy=(1.0, 0.5), xytext=(1.025, 0.5), fontsize=20) ax.axis('scaled') width, height = fig.get_size_inches() fig.set_size_inches(2 * width, 2 * height, forward=True) if return_fig: return ax, fig else: plt.show() def binned_interaction(s_values, t_values, L=2): N = len(s_values) bin_width = 1.0 / 2**L sigma_values = 0.5 * bin_width + np.arange(2**L) * bin_width bin_ends = np.arange(2**L + 1) * bin_width ax, fig = naive_interaction(sigma_values, t_values, N=N, return_fig=True) ax.plot(np.zeros(bin_ends.shape), bin_ends, marker='_', color='black', markersize=20) ax.plot(np.zeros(s_values.shape), s_values, color='blue', marker='o', linestyle='None') ax.plot(np.zeros(sigma_values.shape), sigma_values, color='red', marker='o', linestyle='None') ax.set_xlim(-0.1, 1.0) plt.show() def show_tau_refine(): rows, cols = 1, 1 fig, ax = plt.subplots(rows, cols) ax.plot([-1, 1], [0, 0], color='black', marker='|', markersize=20) ax.plot([-1, 0, 1], [-0.2, -0.2, -0.2], color='black', marker='|', markersize=20) ax.plot([0], [0], color='red', marker='o', linestyle='None') ax.plot([-0.5, 0.5], [-0.2, -0.2], color='red', marker='o', linestyle='None') dx, dy = 0.5, -0.2 ax.arrow(-dx * 0.1, dy * 0.1, -dx * 0.8, dy * 0.8, length_includes_head=True, color='blue') ax.arrow(dx * 0.1, dy * 0.1, dx * 0.8, dy * 0.8, length_includes_head=True, color='blue') remove_axis_frame(ax) ax.axis('scaled') ax.set_xlim(-1.1, 1.1) ax.set_ylim(-0.3, 0.1) ax.annotate(r'$\tau$', xy=(0.0, 0.0), xytext=(-0.02, 0.05), fontsize=24) ax.annotate(r"$\tau_{+}'$", xy=(0.0, 0.0), xytext=(-0.5 - 0.025, -0.325), fontsize=24) ax.annotate(r"$\tau_{+}''$", xy=(0.0, 0.0), xytext=(0.5 - 0.025, -0.325), fontsize=24) width, height = fig.get_size_inches() fig.set_size_inches(2 * width, 2 * height, forward=True) plt.show() def show_sigma_coarsen(): rows, cols = 1, 1 fig, ax = plt.subplots(rows, cols) ax.plot([-1, 1], [-0.2, -0.2], color='black', marker='|', markersize=20) ax.plot([-1, 0, 1], [0, 0, 0], color='black', marker='|', markersize=20) ax.plot([0], [-0.2], color='red', marker='o', linestyle='None') ax.plot([-0.5, 0.5], [0, 0], color='red', marker='o', linestyle='None') dx, dy = 0.5, -0.2 ax.arrow(-0.5 + dx * 0.1, dy * 0.1, dx * 0.8, dy * 0.8, length_includes_head=True, color='blue') ax.arrow(0.5 - dx * 0.1, dy * 0.1, -dx * 0.8, dy * 0.8, length_includes_head=True, color='blue') remove_axis_frame(ax) ax.axis('scaled') ax.set_xlim(-1.1, 1.1) ax.set_ylim(-0.3, 0.1) ax.annotate(r'$\sigma_{\minus}$', xy=(0.0, 0.0), xytext=(-0.02, -0.275), fontsize=24) ax.annotate(r'$\sigma$', xy=(0.0, 0.0), xytext=(-0.5 - 0.025, 0.05), fontsize=24) ax.annotate(r"$\sigma'$", xy=(0.0, 0.0), xytext=(0.5 - 0.025, 0.05), fontsize=24) width, height = fig.get_size_inches() fig.set_size_inches(2 * width, 2 * height, forward=True) plt.show()