# Slice-timing correction, part 1: load a 4D BOLD image, work out when each
# slice was acquired, and compare the time courses of two slices.
#
# No `%pylab inline` here: it is IPython-only syntax (so the file would not be
# valid Python) and it star-imports all of pylab.  Nothing below uses those
# bare names, and current Jupyter shows Matplotlib figures inline by default.
import numpy as np
import matplotlib.pyplot as plt  # `matplotlib.pylab` is a discouraged alias
import nibabel as nib
import scipy.interpolate as spi

fname = 'bold.nii.gz'
img = nib.load(fname)
# Current nibabel no longer provides `img.get_header()` / `img.get_data()`;
# `img.header` / `img.get_fdata()` are the replacements.
hdr = img.header
hdr
print(hdr)
# hdr.get_slice_times()
# NOTE(review): presumably unusable for this file, which is why the slice
# order is written out by hand below.

n_slices = img.shape[2]  # axis 2 of the image is the slice axis
n_slices
# Interleaved acquisition: slices 0, 2, 4, ... first, then 1, 3, 5, ...
# Built with np.arange because `range(...) + range(...)` raises TypeError on
# Python 3; using `n_slices` (not a hard-coded 35) keeps the order consistent
# with the image.
slice_order = np.concatenate((np.arange(0, n_slices, 2),
                              np.arange(1, n_slices, 2)))
slice_order

TR = hdr['pixdim'][4]  # time between scans, from the header's pixdim field
time_one_slice = TR / n_slices
time_one_slice
# space_to_order[s] is the acquisition position (0 = first) of spatial slice s.
space_to_order = np.argsort(slice_order)
space_to_order
slice_order[space_to_order]  # sanity check: 0, 1, 2, ..., n_slices - 1
slice_times = space_to_order * time_one_slice
slice_times
# Use the middle of each slice's acquisition window as its acquisition time.
slice_times = space_to_order * time_one_slice + time_one_slice / 2.
slice_times

data = img.get_fdata()
slice0_vol0 = data[:, :, 0, 0]
plt.figure()  # the notebook cells are merged here, so each plot gets a figure
plt.gray()
plt.imshow(slice0_vol0)

slice0_time_course = data[32, 25, 0, :]  # all points in time
plt.figure()
plt.plot(slice0_time_course)
plt.xlabel('Scan number')

n_scans = img.shape[-1]
scan_starts = np.arange(n_scans) * TR  # times scans began
scan_starts[:10]
slice0_times = scan_starts + slice_times[0]
slice0_times[:10]
plt.figure()
plt.plot(slice0_times, slice0_time_course)
plt.plot(slice0_times[:10], slice0_time_course[:10], 'r:+')
plt.xlabel('Time of acquisition')

space_to_order
slice1_time_course = data[32, 25, 1, :]  # all points in time
plt.figure()
plt.plot(slice1_time_course)
plt.xlabel('Scan number')
slice1_times = scan_starts + slice_times[1]
slice1_times[:10]

plt.figure()
plt.plot(slice0_times[:10], slice0_time_course[:10], 'r:+', label='slice 0')
# No `plt.hold(True)`: it was removed in Matplotlib 3.0, and axes keep their
# existing lines by default.
plt.plot(slice1_times[:10], slice1_time_course[:10], 'b-o', label='slice 1')
plt.xlabel('Time of acquisition')
plt.legend()

# Build an interpolator for slice 1's time course, so we can estimate slice 1
# at the times slice 0 was acquired.
x = slice1_times
y = slice1_time_course
interpolator = spi.interp1d(x, y, 'linear')
interpolator
# What happens if we now evaluate `interpolator` at `slice0_times`?
# Slice-timing correction, part 2: interpolate slice time courses onto new
# times, then correct a whole image and save it.
import os


def _plot_slice1_at_slice0(estimate):
    """ Plot first 10 samples of slices 0 and 1, plus `estimate` at slice-0 times

    `estimate` is slice 1's time course interpolated to `slice0_times`.  Reads
    the notebook-level arrays `slice0_times`, `slice0_time_course`,
    `slice1_times` and `slice1_time_course`, and draws on a new figure.
    """
    plt.figure()
    plt.plot(slice0_times[:10], slice0_time_course[:10], 'r:+', label='slice 0')
    # Axes keep existing lines by default (`plt.hold` was removed in
    # Matplotlib 3.0).
    plt.plot(slice1_times[:10], slice1_time_course[:10], 'b-o', label='slice 1')
    plt.plot(slice0_times[:10], estimate[:10], 'kx', label='1 -> 0')
    plt.xlabel('Time of acquisition')
    plt.legend()


# Evaluating the plain linear interpolator at `slice0_times` raises ValueError.
# Slice 0 is acquired before slice 1, so `slice0_times[0]` lies below the
# range of `x` that `interp1d` was built on:
# slice1_at_slice0 = interpolator(slice0_times)

# Fix 1: return 0 for out-of-range times.
interpolator = spi.interp1d(x, y, 'linear', bounds_error=False, fill_value=0)
slice1_at_slice0 = interpolator(slice0_times)
_plot_slice1_at_slice0(slice1_at_slice0)

# Fix 2: return the mean of the time course for out-of-range times.
interpolator = spi.interp1d(x, y, 'linear', bounds_error=False,
                            fill_value=np.mean(y))
slice1_at_slice0 = interpolator(slice0_times)
_plot_slice1_at_slice0(slice1_at_slice0)

# Fix 3: pad the series with one extra sample at each end, one TR before the
# first time and one TR after the last, repeating the edge values.
n_scans
x_padded = np.zeros((n_scans + 2,))
y_padded = np.zeros((n_scans + 2,))
x_padded[1:-1] = slice1_times
x_padded[0] = x[0] - TR
x_padded[-1] = x[-1] + TR
y_padded[1:-1] = slice1_time_course
y_padded[0] = y[0]
y_padded[-1] = y[-1]
interpolator = spi.interp1d(x_padded, y_padded, 'linear')
slice1_at_slice0 = interpolator(slice0_times)
_plot_slice1_at_slice0(slice1_at_slice0)

# Cubic interpolation of the padded series, evaluated on a fine time grid.
interpolator = spi.interp1d(x_padded, y_padded, 'cubic')
fine_time = np.linspace(x_padded[0], x_padded[9], 100)
predicted_signal = interpolator(fine_time)
plt.figure()
plt.plot(fine_time, predicted_signal, ':', label='predicted')
plt.plot(slice1_times[:10], slice1_time_course[:10], 'x', label='actual')
plt.xlabel('Time of acquisition')
plt.legend()


def pad_ends(first, middle, last):
    """ Pad array `middle` along last axis with `first` value and `last` value

    Parameters
    ----------
    first, last : scalar or array-like
        Values for the new first and last positions of the last axis.  They
        are broadcast against ``middle.shape[:-1]``: scalars for a 1D
        `middle`, or arrays with one dimension fewer than `middle`.
    middle : array-like
        At least 1D; padded along its last axis.

    Returns
    -------
    padded : ndarray
        New array with shape ``middle.shape[:-1] + (middle.shape[-1] + 2,)``.
        Its dtype can hold all three inputs, so a float `first` is not
        truncated when `middle` is an integer array.
    """
    middle = np.asarray(middle)
    n_along = middle.shape[-1]
    padded = np.empty(middle.shape[:-1] + (n_along + 2,),
                      dtype=np.result_type(first, middle, last))
    padded[..., 0] = first
    padded[..., 1:-1] = middle
    padded[..., -1] = last
    return padded


assert np.all(pad_ends(0, [2, 3], 5) == [0, 2, 3, 5])
a = np.zeros((2, 3))
b = np.ones((2, 3, 4)) * 10
c = np.ones((2, 3))
assert np.all(pad_ends(a, b, c) ==
              np.concatenate((a.reshape((2, 3, 1)), b, c.reshape((2, 3, 1))),
                             axis=2))

# Interpolate every voxel of slice 1 at once (time is the last axis).
slice1_all = data[:, :, 1, :]
slice1_all.shape
x = slice1_times  # as before
y = slice1_all  # as before
# interp1d can work along any axis; here padding is used instead of fill_value.
interpolator = spi.interp1d(x, y, 'linear', axis=2, bounds_error=False,
                            fill_value=0)
# Pad with copies of the first and last volume of the slice, matching the
# padded times in `x_padded`.
slice1_all_padded = pad_ends(slice1_all[..., 0], slice1_all, slice1_all[..., -1])
interpolator = spi.interp1d(x_padded, slice1_all_padded, 'linear')
slice1_all_at_slice0 = interpolator(slice0_times)
slice1_at_slice0_again = slice1_all_at_slice0[32, 25, :]
_plot_slice1_at_slice0(slice1_at_slice0_again)


def interp_slice(old_times, slice_nd, new_times, kind='linear', TR=None):
    """ Interpolate a 3D slice `slice_nd` with times changing from `old_times` to `new_times`

    The time series are padded with one extra sample at each end, one `TR`
    beyond the first and last of `old_times`, repeating the edge values.
    This lets `new_times` reach slightly outside the range of `old_times`.

    Parameters
    ----------
    old_times : 1D array-like
        Time of each sample along the last (time) axis of `slice_nd`.
    slice_nd : ndarray
        Data with time as the last axis, e.g. one ``(x, y, time)`` slice.
    new_times : array-like
        Times at which to estimate the data.  They must lie within
        ``[old_times[0] - TR, old_times[-1] + TR]``, otherwise ``interp1d``
        raises ValueError.
    kind : str, optional
        Interpolation kind passed to ``scipy.interpolate.interp1d``.
    TR : float, optional
        Spacing of the padded end samples.  Defaults to
        ``old_times[1] - old_times[0]``, which assumes regularly spaced times.

    Returns
    -------
    ndarray
        `slice_nd` interpolated along its last axis to `new_times`.

    Raises
    ------
    ValueError
        If `old_times` does not have one entry per time point of `slice_nd`,
        or if `TR` is not given and there are fewer than two time points.
    """
    old_times = np.asarray(old_times)
    if slice_nd.shape[-1] != len(old_times):
        raise ValueError('`old_times` needs one entry per time point of '
                         '`slice_nd`')
    if TR is None:
        if len(old_times) < 2:
            raise ValueError('need at least two time points to infer `TR`')
        TR = old_times[1] - old_times[0]
    padded_times = pad_ends(old_times[0] - TR, old_times, old_times[-1] + TR)
    padded_slice_nd = pad_ends(slice_nd[..., 0], slice_nd, slice_nd[..., -1])
    interpolator = spi.interp1d(padded_times, padded_slice_nd, kind)
    return interpolator(new_times)


interpolated = interp_slice(slice1_times, slice1_all, slice0_times)
assert np.allclose(interpolated, slice1_all_at_slice0)

# Correct every slice to the scan start times.
data = img.get_fdata()
slice_times
scan_starts
interp_data = np.empty(data.shape)
desired_times = scan_starts
for slice_no in range(data.shape[-2]):
    these_times = slice_times[slice_no] + scan_starts
    data_slice = data[:, :, slice_no, :]
    interped = interp_slice(these_times, data_slice, desired_times, 'cubic',
                            TR=TR)
    interp_data[:, :, slice_no, :] = interped

# Cross-check one voxel against a direct 1D cubic interpolation.
one_tc = data[32, 32, 17, :]
old_times = slice_times[17] + scan_starts
n_scans = len(one_tc)
one_tc_padded = pad_ends(one_tc[0], one_tc, one_tc[-1])
old_times_padded = pad_ends(old_times[0] - TR, old_times, old_times[-1] + TR)
interpolator = spi.interp1d(old_times_padded, one_tc_padded, 'cubic')
tc = interpolator(desired_times)
assert np.allclose(tc, interp_data[32, 32, 17, :])


def slice_time_image(img, slice_times, TR, kind='cubic'):
    """ Take nibabel image `img` and run slice timing correction using `slice_times`

    Each slice's time course is interpolated to the scan start times,
    ``np.arange(n_scans) * TR``.

    Parameters
    ----------
    img : nibabel image
        4D image; axis 2 indexes slices and axis 3 indexes scans (time).
    slice_times : sequence of float
        Acquisition time of each slice, relative to the start of its scan.
    TR : float
        Time between scans.
    kind : str, optional
        Interpolation kind passed to ``interp_slice``.

    Returns
    -------
    new_img : nibabel.Nifti1Image
        Corrected data with `img`'s affine and header.

    Raises
    ------
    ValueError
        If `slice_times` does not have one entry per slice.
    """
    data = img.get_fdata()
    if len(slice_times) != img.shape[-2]:
        raise ValueError('need one slice time per slice')
    scan_starts = np.arange(img.shape[-1]) * TR
    interp_data = np.empty(data.shape)
    for slice_no, slice_time in enumerate(slice_times):
        interp_data[:, :, slice_no, :] = interp_slice(
            scan_starts + slice_time, data[:, :, slice_no, :], scan_starts,
            kind, TR=TR)
    return nib.Nifti1Image(interp_data, img.affine, img.header)


new_img = slice_time_image(img, slice_times, TR)
new_data = new_img.get_fdata()
assert np.allclose(tc, new_data[32, 32, 17, :])

# Keep `fname` as the full path (don't rebind it to the bare file name), so
# loading below and `slice_time_file` both see the real location.
pth, base = os.path.split(fname)
pth, base
new_fname = os.path.join(pth, 'a' + base)
new_fname
raw_img = nib.load(fname)
interp_img = slice_time_image(raw_img, slice_times, TR)
nib.save(interp_img, new_fname)


def slice_time_file(fname, slice_times, TR, kind='cubic'):
    """ Slice-time correct image file `fname`, saving it with an 'a' prefix

    The corrected image is written to the same directory as `fname`, named
    ``'a' + basename(fname)``.

    Parameters
    ----------
    fname : str
        Path of the 4D image to correct.
    slice_times, TR, kind
        As for ``slice_time_image``.

    Returns
    -------
    new_fname : str
        Path of the saved, corrected image.
    """
    pth, base = os.path.split(fname)
    new_fname = os.path.join(pth, 'a' + base)
    corrected = slice_time_image(nib.load(fname), slice_times, TR, kind)
    nib.save(corrected, new_fname)
    return new_fname


slice_time_file(fname, slice_times, TR, 'cubic')