%matplotlib inline
%gui qt
import os.path as op
import numpy as np
from scipy import stats
import nibabel as nib
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
from surfer import Brain
from scipy.ndimage import binary_erosion
import lyman
# Pull the subject list and the analysis root from the lyman project config.
subjects = lyman.determine_subjects()
analysis_dir = lyman.gather_project_info()["analysis_dir"]
# NOTE(review): redundant re-import of Brain (already imported above); harmless.
from surfer import Brain
# Hemispheres iterated over when loading and plotting surface data.
hemis = ["lh", "rh"]
# Figure defaults: seaborn "paper" context with tick style, 200 dpi saves.
sns.set(context="paper", style="ticks")
mpl.rcParams.update({"savefig.dpi": 200})
The computation is handled by an external script (`process_searchlight.py`). You can activate the cell below and run it.
Load in the cached searchlight data
# Collect each subject's cached searchlight map, per hemisphere.
# group[hemi] ends up as a list with one squeezed array per subject, in the
# order given by `subjects`.
group = dict(lh=[], rh=[])
temp = op.join(analysis_dir, "dksort/{subj}/mvpa/searchlight/{hemi}.dimension_dksort_pfc.mgz")
for subj in subjects:
    for hemi in hemis:
        img = nib.load(temp.format(subj=subj, hemi=hemi))
        # np.asanyarray(img.dataobj) replaces the deprecated (and, in
        # nibabel >= 5, removed) img.get_data(); it keeps the on-disk dtype.
        group[hemi].append(np.asanyarray(img.dataobj).squeeze())
Do a one-sample group t-test against chance (1/3) at each vertex
# Vertex-wise one-sample t-test of group accuracy against chance (1/3).
# Per hemisphere this stores:
#   group_accs  - stacked accuracies, one row per subject
#   group_means - mean accuracy across subjects at each vertex
#   group_masks - True where every subject has a nonzero accuracy
#   group_ts    - t statistic vs. 1/3 (NaN/inf replaced by nan_to_num)
# vertex_accs collects the means at masked vertices, lh first, then rh.
group_accs = dict()
group_means = dict()
group_masks = dict()
group_ts = dict()
vertex_accs = []
for hemi in hemis:
    # Presumably (n_subjects, n_vertices): each subject contributed one
    # squeezed per-vertex array -- TODO confirm the maps are 1-D.
    accs = np.vstack(group[hemi])
    n_subj = len(accs)
    means = accs.mean(axis=0)
    masks = accs.all(axis=0)
    # SEM from the unbiased SD; equivalent to std(ddof=0) / sqrt(n - 1).
    # Vertices where every subject is identical (e.g. all zeros outside the
    # searchlight coverage) divide by zero; those results are cleaned by
    # nan_to_num below, so silence the expected warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        sem = accs.std(axis=0, ddof=1) / np.sqrt(n_subj)
        ts = (means - 1. / 3) / sem
    ts = np.nan_to_num(ts)
    vertex_accs.extend(means[masks])
    group_accs[hemi] = accs
    group_means[hemi] = means
    group_masks[hemi] = masks
    group_ts[hemi] = ts
Set a one-tailed threshold on the t statistic to mask the surface plot (the plotted values themselves are still the mean accuracies)
# One-tailed significance level used to mask the surface overlay.
alpha = 0.005
# Critical t for a one-sample test with n - 1 degrees of freedom (one
# observation per subject), derived from the data instead of a hard-coded
# df. For the symmetric t distribution, isf(alpha) == -ppf(alpha).
thresh = stats.t(len(subjects) - 1).isf(alpha)
Use PySurfer to plot the data on the surface
# Render one thresholded surface view per hemisphere with PySurfer and keep
# the screenshots (views[hemi]) for the matplotlib figure below.
views = dict()
for hemi in hemis:
    b = Brain("fsaverage", hemi, "semi7",
              config_opts={"background": "white", "width": 500, "height": 420})
    try:
        # Threshold a copy so group_means keeps the unthresholded means
        # (the original aliased and mutated it in place). Zeroed vertices
        # fall below thresh=0.1 and are presumably left unpainted.
        data = group_means[hemi].copy()
        data[group_ts[hemi] < thresh] = 0
        b.add_data(data, min=0.3, max=0.5, thresh=0.1,
                   colormap="OrRd_r", colorbar=False)
        # Semi-transparent overlay on vertices where not every subject has a
        # nonzero accuracy (presumably outside the searchlight coverage).
        b.add_data(~group_masks[hemi], min=.5, max=1.05, thresh=.5, alpha=.5,
                   colormap="bone_r", colorbar=False)
        b.add_label("yeo17_ifs", borders=True, color=".3")
        b.show_view(dict(elevation=80,
                         azimuth=dict(lh=150, rh=30)[hemi],
                         focalpoint=[0, 10, 10]), distance=325)
        views[hemi] = b.screenshot()
    finally:
        # Always release the rendering window, even if plotting fails.
        b.close()
Grab an example slice through a subject's mean functional and searchlight mask
# Example axial slice (index slc along the last axis) for subject dk11:
# epi  - mean functional image, NaN outside the (eroded) functional mask
# mask - functional brain mask for the slice, eroded twice
# roi  - searchlight ROI mask, NaN where below 0.5
# Slices are transposed for display with imshow.
slc = 12
epi_img = nib.load(op.join(analysis_dir, "dksort/dk11/preproc/run_1/mean_func.nii.gz"))
# np.asanyarray(img.dataobj) replaces the deprecated/removed get_data(); the
# float cast lets NaN be assigned below even if the file stores integers,
# and copies so the loaded volume is not modified.
epi = np.asanyarray(epi_img.dataobj)[..., slc].T.astype(float)
mask_img = nib.load(op.join(analysis_dir, "dksort/dk11/preproc/run_1/functional_mask.nii.gz"))
mask = binary_erosion(np.asanyarray(mask_img.dataobj)[..., slc].astype(bool).T, iterations=2)
roi_img = nib.load(op.join(analysis_dir, "dksort/dk11/mvpa/searchlight/dksort_pfc_mask.nii.gz"))
roi = np.asanyarray(roi_img.dataobj)[..., slc].T.astype(float)
# NaN pixels are left blank by imshow, so only brain / ROI voxels are drawn.
epi[~mask] = np.nan
roi[roi < .5] = np.nan
Plot the data
# Assemble figure 5:
#   A - group-mean searchlight accuracy on both hemispheres (PySurfer shots)
#   B - example slice showing the searchlight mask over the mean functional
#   C - histogram of mean accuracy across vertices covered in every subject
import os

fig = plt.figure(figsize=(3.34, 2.8))

# Plot the average searchlights on the surface
rh_ax = fig.add_axes([-.04, .5, .54, .5])
rh_ax.imshow(views["rh"])
rh_ax.set_axis_off()
lh_ax = fig.add_axes([.5, .5, .54, .5])
lh_ax.imshow(views["lh"])
lh_ax.set_axis_off()

# Draw a colorbar for the statistical overlay (labelled 0.3-0.5 to match the
# min/max given to the surface overlay)
with mpl.rc_context({"axes.linewidth": .4}):
    cbar_ax = fig.add_axes([.35, .49, .3, .035])
    cbar_ax.pcolormesh(np.atleast_2d(np.linspace(0, 1, 100)), cmap="OrRd_r")
    cbar_ax.set(xticks=[], yticks=[])
fig.text(.34, .505, "0.3", ha="right", va="center", size=7)
fig.text(.66, .505, "0.5", ha="left", va="center", size=7)
fig.text(.5, .53, "Mean accuracy", ha="center", va="bottom", size=7)

# Show an example slice through the mean functional and searchlight mask
mask_ax = fig.add_axes([-.04, 0, .45, .45])
mask_ax.imshow(epi, cmap="Greys_r")
roi_cmap = mpl.colors.ListedColormap(["steelblue"])
mask_ax.imshow(roi, cmap=roi_cmap, alpha=.7, interpolation="nearest")
mask_ax.set_axis_off()

# Show the distribution of average searchlight accuracy.
# The valid rc keys are "xtick.*"/"ytick.*"; the plural forms raise KeyError.
with mpl.rc_context({"axes.linewidth": .5,
                     "xtick.major.width": .5,
                     "ytick.major.width": .5}):
    hist_ax = fig.add_axes([.38, .12, .56, .33])
    counts, bins = np.histogram(vertex_accs, np.linspace(.3, .4, 25))
    hist_pal = sns.color_palette("OrRd_r", 50)[:len(counts)]
    # bins[:-1] are left edges, so anchor bars at the edge (the default
    # alignment became "center" in matplotlib 2.0).
    hist_ax.bar(bins[:-1], counts, width=.0033, color=hist_pal, alpha=.8,
                align="edge")
    hist_ax.set(yticks=[], xlim=(.3, .4))
    xticks = np.linspace(.3, .4, 6)
    hist_ax.set_xticks(xticks)
    # Format explicitly; raw floats from linspace can print as 0.36000000000000004.
    hist_ax.set_xticklabels(["{:.2f}".format(t) for t in xticks], size=7)
    hist_ax.set_xlabel("Cross-validated decoding accuracy", labelpad=1.5)
    # Chance level, the same 1/3 used in the group t-test.
    hist_ax.axvline(x=1. / 3, ymax=1, ls=":", c=".2")
    sns.despine(ax=hist_ax, left=True)

fig.text(.02, .94, "A", size=12)
fig.text(.02, .41, "B", size=12)
fig.text(.36, .41, "C", size=12)

# savefig does not create missing directories.
if not os.path.isdir("figures"):
    os.makedirs("figures")
fig.savefig("figures/figure_5.tiff", dpi=300)
fig.savefig("figures/figure_5.pdf", dpi=300)