%matplotlib inline %gui qt import os.path as op import numpy as np from scipy import stats import nibabel as nib import seaborn as sns import matplotlib as mpl import matplotlib.pyplot as plt from surfer import Brain from scipy.ndimage import binary_erosion import lyman subjects = lyman.determine_subjects() analysis_dir = lyman.gather_project_info()["analysis_dir"] from surfer import Brain sns.set(style="ticks", context="paper") mpl.rcParams["savefig.dpi"] = 200 hemis = ["lh", "rh"] python process_searchlight.py -do fit surf group = dict(lh=[], rh=[]) temp = op.join(analysis_dir, "dksort/{subj}/mvpa/searchlight/{hemi}.dimension_dksort_pfc.mgz") for subj in subjects: for hemi in hemis: group[hemi].append(nib.load(temp.format(subj=subj, hemi=hemi)).get_data().squeeze()) group_accs = dict() group_means = dict() group_masks = dict() group_ts = dict() vertex_accs = [] for hemi in hemis: accs = np.vstack(group[hemi]) means = accs.mean(axis=0) masks = accs.all(axis=0) ts = (means - (1. / 3)) / (accs.std(axis=0) / np.sqrt(len(accs) - 1)) ts = np.nan_to_num(ts) vertex_accs.extend(means[masks]) group_accs[hemi] = accs group_means[hemi] = means group_masks[hemi] = masks group_ts[hemi] = ts alpha = 0.005 thresh = stats.t(14).ppf(alpha) * -1 views = dict() for hemi in hemis: b = Brain("fsaverage", hemi, "semi7", config_opts={"background": "white", "width": 500, "height": 420}) data = group_means[hemi] data[group_ts[hemi] < thresh] = 0 b.add_data(group_means[hemi], min=0.3, max=0.5, thresh=0.1, colormap="OrRd_r", colorbar=False) b.add_data(~group_masks[hemi], min=.5, max=1.05, thresh=.5, alpha=.5, colormap="bone_r", colorbar=False) b.add_label("yeo17_ifs", borders=True, color=".3") b.show_view(dict(elevation=80, azimuth=dict(lh=150, rh=30)[hemi], focalpoint=[0, 10, 10]), distance=325) views[hemi] = b.screenshot() b.close() slc = 12 epi = nib.load(op.join(analysis_dir, "dksort/dk11/preproc/run_1/mean_func.nii.gz")).get_data() epi = epi[..., slc].T mask = nib.load(op.join(analysis_dir, "dksort/dk11/preproc/run_1/functional_mask.nii.gz")).get_data() mask = binary_erosion(mask[..., slc].astype(bool).T, iterations=2) roi = nib.load(op.join(analysis_dir, "dksort/dk11/mvpa/searchlight/dksort_pfc_mask.nii.gz")).get_data() roi = roi[..., slc].T epi[~mask] = np.nan roi[roi < .5] = np.nan fig = plt.figure(figsize=(3.34, 2.8)) # Plot the average searchlights on the surface rh_ax = fig.add_axes([-.04, .5, .54, .5]) rh_ax.imshow(views["rh"]) rh_ax.set_axis_off() lh_ax = fig.add_axes([.5, .5, .54, .5]) lh_ax.imshow(views["lh"]) lh_ax.set_axis_off() # Draw a colorbar for the statistical overlay with mpl.rc_context({"axes.linewidth": .4}): cbar_ax = fig.add_axes([.35, .49, .3, .035]) cbar_ax.pcolormesh(np.atleast_2d(np.linspace(0, 1, 100)), cmap="OrRd_r") cbar_ax.set(xticks=[], yticks=[]) fig.text(.34, .505, "0.3", ha="right", va="center", size=7) fig.text(.66, .505, "0.5", ha="left", va="center", size=7) fig.text(.5, .53, "Mean accuracy", ha="center", va="bottom", size=7) # Show an example slice through the mean functional and searchlight mask mask_ax = fig.add_axes([-.04, 0, .45, .45]) mask_ax.imshow(epi, cmap="Greys_r") roi_cmap = mpl.colors.ListedColormap(["steelblue"]) mask_ax.imshow(roi, cmap=roi_cmap, alpha=.7, interpolation="nearest") mask_ax.set_axis_off() # Show the distribution of average searchlight accuracy with mpl.rc_context({"axes.linewidth": .5, "xticks.major.width": .5, "yticks.major.width": .5}): hist_ax = fig.add_axes([.38, .12, .56, .33]) counts, bins = np.histogram(vertex_accs, np.linspace(.3, .4, 25)) hist_pal = sns.color_palette("OrRd_r", 50)[:len(counts)] hist_ax.bar(bins[:-1], counts, width=.0033, color=hist_pal, alpha=.8) hist_ax.set(yticks=[], xlim=(.3, .4)) hist_ax.set_xticks(np.linspace(.3, .4, 6)) hist_ax.set_xticklabels(np.linspace(.3, .4, 6), size=7) hist_ax.set_xlabel("Cross-validated decoding accuracy", labelpad=1.5) hist_ax.axvline(x=.33333, ymax=1, ls=":", c=".2") sns.despine(ax=hist_ax, left=True) fig.text(.02, .94, "A", size=12) fig.text(.02, .41, "B", size=12) fig.text(.36, .41, "C", size=12) fig.savefig("figures/figure_5.tiff", dpi=300) fig.savefig("figures/figure_5.pdf", dpi=300)