from __future__ import division import os.path as op import itertools import numpy as np import scipy as sp import pandas as pd import nibabel as nib from scipy import stats import statsmodels.api as sm import matplotlib as mpl import matplotlib.pyplot as plt from sklearn.linear_model import LogisticRegression from sklearn.metrics import r2_score import lyman from lyman import mvpa, evoked import seaborn as sns import moss %matplotlib inline sns.set(context="paper", style="ticks", font="Arial") mpl.rcParams.update({"xtick.major.width": 1, "ytick.major.width": 1, "savefig.dpi": 150}) pd.set_option('display.precision', 3) from mpl_toolkits import mplot3d from mpl_toolkits.mplot3d import Axes3D np.random.seed(sum(map(ord, reversed("DKsort")))) %load_ext rmagic %R library(lme4) %R library(multcomp) from IPython.parallel import Client, TimeoutError try: dv = Client()[:] dv4 = Client()[:4] except (IOError, TimeoutError): dv = None dv4 = None import warnings warnings.filterwarnings("ignore", category=DeprecationWarning) def save_figure(fig, figname): fig.savefig("figures/%s.pdf" % figname, dpi=300) fig.savefig("figures/%s.tiff" % figname, dpi=300) subjects = pd.Series(lyman.determine_subjects(), name="subj") project = lyman.gather_project_info() anal_dir = project["analysis_dir"] data_dir = project["data_dir"] rule_colors = dict(dim=["#EE6F25", "#4BC75B", "#9370DB"], dec=["#6495EB", "#FF6347"]) roi_colors = dict(IFS="#CC3333", aMFG="#3380CC", pMFG="#33CC80", pSFS="#CC8732", IFG="#6464D8", aIns="#F08080", FPC="#CCCC33", IPS="#2A4380", OTC="#24913C") roi_colors.update({"aIFS": "#863D3D", "pIFS": "#7A5252", "lh-IFS": "#D19494", "rh-IFS": "#C2A3A3"}) %%R lr_test = function(m1, m2, name){ out = anova(m1, m2) chi2 = out$Chisq[2] dof = out$"Chi Df"[2] p = out$"Pr(>Chisq)"[2] test_str = "Likelihood ratio test for %s:\n Chisq(%d) = %.2f; p = %.3g" writeLines(sprintf(test_str, name, dof, chi2, p)) } def dksort_behav_join(): behav_temp = op.join(data_dir, "%s/behav/behav_data.csv") behav_data = [] for subj in subjects: df = pd.read_csv(behav_temp % subj, index_col="trial") df["subj"] = np.repeat(subj, len(df)) behav_data.append(df) behav_full = pd.concat(behav_data) behav_df = behav_full[behav_full["clean"] & behav_full["correct"]] return behav_full, behav_df behav_full, behav_df = dksort_behav_join() %Rpush behav_df %Rpush behav_full pfc_rois = ["IFS", "aMFG", "pMFG", "FPC", "IFG", "aIns", "pSFS"] net_rois = ["IFS", "IPS", "OTC"] all_rois = pd.Series(pfc_rois + net_rois[1:], name="roi") other_rois = ["lh-IFS", "rh-IFS", "aIFS", "pIFS", "AllROIs"] other_masks = ["lh.yeo17_ifs", "rh.yeo17_ifs", "yeo17_aifs", "yeo17_pifs", "dksort_all_pfc"] frames = np.arange(-1, 5) timepoints = frames * 2 + 1 up_timepoints = (np.linspace(0, 12, 25) - 1)[:-1] model = LogisticRegression() n_shuffle = 1000 peak = slice(2, 4) shuffle_seed = sum(map(ord, "DKsort")) def dksort_decode(rois, rule): """Do the all the main decoding steps across sets of ROIs. Parameters ---------- rois: list of strings list of ROI names that can be easily mapped to a mask name rule: string name of rule type corresponding to design file in data dir Returns ------- dictionary with the following entries rois: list of roi names accs: DataFrame with decoding accuracy. Index is hierarchical with (ROI, subject), and columns are timepoint. chance: DataFrame in same shape as `accs` with the mean value from the shuffled null distribution for each test. peak: DataFrame with decoding accuracy from data averaged over 3s and 5s. Index is subject id and column is ROI. null: DataFrame with maximum shuffled accuracy across timepoints. Index is hierarchical with (subj, iteration) and columns are ROIS. ttest: DataFrame with ROIs in the index and (t, p, max_tp) for the group average test against empirical chance in the columns. `t` is the t statistic for the peak accuracy value, and `p` is the corresponding p value corrected for multiple comparisons across region and timepoint. max_tp is in seconds relative to stimulus onset. signif: DataFrame with the null distribution percentiles corresponding to the best observed accuracy for each subject/ROI. Index is hierarchical with (correction, subject) where correction can be `time` or `omni` for the space of tests the percentile is corrected against. columns are ROIs. """ # Set up the DataFrames to hold the persisent outputs roi_index = moss.product_index([rois, subjects], ["ROI", "subj"]) columns = pd.Series(timepoints, name="timepoints") accs = pd.DataFrame(index=roi_index, columns=columns, dtype=float) null_index = moss.product_index([subjects, np.arange(n_shuffle)], ["subj", "iter"]) null = pd.DataFrame(index=null_index, columns=pd.Series(rois, name="ROI"), dtype=float) peak_df = pd.DataFrame(index=subjects, columns=pd.Series(rois, name="ROI")) chance = pd.DataFrame(index=roi_index, columns=timepoints, dtype=float) # For each ROI load the data, decode, and simulate the null distribution for roi in rois: mask = "yeo17_" + roi.lower() # Load the dataset and do the basic time-resolved decoding ds = mvpa.extract_group(rule, roi, mask, frames, confounds="rt", dv=dv4) roi_accs = mvpa.decode_group(ds, model, dv=dv) accs.loc[roi, :] = roi_accs # Now do the shuffling and save the partially transformed null distribution roi_null = mvpa.classifier_permutations(ds, model, n_iter=n_shuffle, random_seed=shuffle_seed, dv=dv) chance.loc[roi, :] = roi_null.mean(axis=1) null[roi] = roi_null.max(axis=-1).ravel() # Finally re-load the data averaging over the peak timepoints and decode peak_ds = mvpa.extract_group(rule, roi, mask, frames, peak, "rt", dv=dv4) peak_df[roi] = mvpa.decode_group(peak_ds, model, dv=dv) # Do the group t tests wide_accs = accs.unstack(level="ROI") wide_chance = chance.unstack(level="ROI") mus = wide_accs.mean(axis=0) sds = wide_accs.std(axis=0) ts, ps = moss.randomize_onesample(wide_accs, h_0=wide_chance) # Build the t test output t_df = pd.DataFrame(dict(t=ts, p=ps, mu=mus, sd=sds), index=wide_accs.columns, columns=["mu", "sd", "t", "p"]) ttest = t_df.groupby(level="ROI").apply(lambda x: x.loc[x.mu.idxmax()]) ttest["tp"] = t_df.groupby(level="ROI").mu.apply(lambda x: x.idxmax()[0]) # From the null distribution, find the percentile corresponding to the observed score signif_index = moss.product_index([["time", "omni"], subjects], ["correction", "subj"]) signif = pd.DataFrame(index=signif_index, columns=rois, dtype=float) for roi, roi_accs in accs.max(axis=1).groupby(level="ROI"): for subj, score in roi_accs.groupby(level="subj"): signif.loc[("time", subj), roi] = stats.percentileofscore(null[roi], score[0]) signif.loc[("omni", subj), roi] = stats.percentileofscore(null.max(axis=1), score[0]) return dict(rois=rois, accs=accs, chance=chance, peak=peak_df, null=null, ttest=ttest, signif=signif) def dksort_timecourse_figure(ax, data, ytick_args, err_style="ci_band", legend=True, err_kws=None, **kwargs): """Plot time-resolved decoding accuracies for multiple ROIs""" # Represent chance empirically based on the null distribution chance = np.array(data["chance"].mean(axis=0)) ax.plot(timepoints, chance, "k--") # Plot the stimulus onset ax.plot([0, 0], ytick_args[:2], ls=":", c="k") # Draw the accuracy timecourse for each ROI colors = [roi_colors[roi] for roi in data["rois"]] interpolate = not err_style == "ci_bars" accs = pd.melt(data["accs"].reset_index(), ["subj", "ROI"], value_name="acc") sns.tsplot(accs, time="timepoints", unit="subj", condition="ROI", value="acc", color=colors, err_style=err_style, interpolate=interpolate, legend=legend, err_kws=err_kws, ax=ax, **kwargs) # Adjust the axis scales and tick labels ylim = ytick_args[:2] yticks = np.linspace(*ytick_args) ax.set_xlim(timepoints.min(), timepoints.max()) ax.set_ylim(ylim) ax.set_yticks(yticks) # Label the axes ax.set_xlabel("Time relative to stimulus onset (s)") ax.set_ylabel("Cross-validated decoding accuracy") # Make a legend with the ROIs ax.legend(frameon=False, loc="upper right") def dksort_point_figure(ax, data, chance, ytick_args, yaxis_label): """Plot a point estimate and CI for the 3-5s accuracy by ROI.""" # Plot the chance line ax.axhline(chance, color="k", ls="--") # Plot each accuracy and confidence interval for i, roi in enumerate(data.columns): color = roi_colors[roi] accs = data[roi].values ci = moss.ci(moss.bootstrap(accs), 68) ax.plot(i, accs.mean(), "o", color=color, ms=5, mew=0, mec=color) ax.plot([i, i], ci, color=color, lw=2, solid_capstyle="round") # Set the axis limits ax.set_xlim(-.5, data.shape[1] - .5) ax.set_ylim(*ytick_args[:2]) yticks = np.linspace(*ytick_args) ax.set_yticks(yticks) ax.set_yticklabels(["%.2f" % t for t in yticks]) ax.xaxis.grid(False) # Set the axis labels ax.set_xticks(np.arange(len(data.columns))) ax.set_xticklabels(data.columns, rotation=60) if yaxis_label: ax.set_ylabel("Cross-validated decoding accuracy") else: ax.set_yticklabels([]) def dksort_significance_figure(ax, data): """Plot the percentile of the null for each ROI by subject.""" width = 2 / len(data) xbase = np.arange(0, 1, width) rois = data.columns # Plot percentile across ROI/time and time correction for kind, kind_df in data.groupby(level="correction"): if kind == "time": continue for i, roi in enumerate(rois): height = kind_df[roi] color = sns.desaturate(roi_colors[roi], .75) edgecolor = sns.desaturate(color, .5) alpha = .25 if kind == "time" else .7 ax.bar(xbase + i, np.sort(height), width=width, color=color, edgecolor=edgecolor, lw=0.1, alpha=alpha) if i: ax.axvline(i, lw=.5, color="#555555") # Set the other plot aesthetics ax.set_ylim(0, 100) ax.axhline(95, c="#444444", ls=":") ax.set_xticks(np.arange(data.shape[1]) + .5) ax.set_xticklabels(rois, rotation=60) ax.xaxis.grid(False) ax.set_ylabel("Percentile in null distribution") def dksort_behav_summary(): accs = behav_full.groupby("subj").correct.mean() print "Accuracy across subjects: Mean %.2g; Std %.2g; Min: %.2g" % (accs.mean(), accs.std(), accs.min()) rts = behav_df.groupby("subj").rt.median() * 1000 print "Median RT across subjects (ms) - Mean %.3g; Std %.3g" % (rts.mean(), rts.std()) iqrs = behav_df.groupby("subj").rt.apply(moss.iqr) * 1000 print "RT IQRs across subjects (ms) - Mean: %.3g; Std %.3g" % (iqrs.mean(), iqrs.std()) dksort_behav_summary() %%R m.rt.int = lmer(rt ~ dim_rule * dec_rule + (dim_rule + dec_rule | subj), behav_df, REML=F) m.rt.add = lmer(rt ~ dim_rule + dec_rule + (dim_rule + dec_rule | subj), behav_df, REML=F) lr_test(m.rt.int, m.rt.add, "interaction") %R print(m.rt.add, corr=FALSE) %%R print(anova(m.rt.add)) m.rt.dim = lmer(rt ~ dec_rule + (dim_rule + dec_rule | subj), behav_df, REML=F) m.rt.dec = lmer(rt ~ dim_rule + (dim_rule + dec_rule | subj), behav_df, REML=F) lr_test(m.rt.add, m.rt.dim, "dimension rule") lr_test(m.rt.add, m.rt.dec, "decision rule") def dksort_rt_anova(): subj_grouped = behav_df.groupby("subj") dim = subj_grouped.apply(moss.df_oneway, "dim_rule", "rt", False) dec = subj_grouped.apply(moss.df_oneway, "dec_rule", "rt", False) f_scores = pd.concat([dim, dec], keys=["dimension", "decision"], names=["rules", "subj"]) print "Median F score: %.2f" % f_scores.loc["dimension"]["F"].median() f_scores["p < 0.05"] = f_scores.p < 0.05 f_scores["p < 0.1"] = f_scores.p < 0.1 print f_scores.groupby(level="rules")[["p < 0.05", "p < 0.1"]].sum() return subj_grouped, f_scores subj_grouped, f_scores = dksort_rt_anova() def dksort_rule_rt_effect(): pivot = pd.pivot_table(behav_df, "rt", "subj", "dec_rule", np.median) cost = np.diff(pivot, axis=1) * 1000 boots = moss.bootstrap(cost) ci_low, ci_high = moss.ci(boots, 95) print "'Different' rule effect (ms): Mean %.3g, CI: [%.3g, %.3g]" % (boots.mean(), ci_low, ci_high) dksort_rule_rt_effect() %%R m.match = lmer(rt ~ attend_match + (attend_match | subj), behav_df, REML=FALSE) m.match.drop = lmer(rt ~ (attend_match | subj), behav_df, REML=FALSE) print(m.match, corr=FALSE) lr_test(m.match, m.match.drop, "attended matching") %%R m.match.int = lmer(rt ~ attend_match * dec_rule + (attend_match + dec_rule | subj), behav_df, REML=FALSE) m.match.add = lmer(rt ~ attend_match + dec_rule + (attend_match + dec_rule | subj), behav_df, REML=FALSE) print(m.match.int, corr=FALSE) lr_test(m.match.int, m.match.add, "interaction") %%R m.acc.int = lmer(correct ~ dim_rule * dec_rule + (dim_rule + dec_rule | subj), behav_full, family=binomial) m.acc.add = lmer(correct ~ dim_rule + dec_rule + (dim_rule + dec_rule | subj), behav_full, family=binomial) lr_test(m.acc.int, m.acc.add, "interaction") %R print(m.acc.add, corr=FALSE) %%R m.acc.dim = lmer(correct ~ dec_rule + (dim_rule + dec_rule | subj), behav_full, family=binomial) m.acc.dec = lmer(correct ~ dim_rule + (dim_rule + dec_rule | subj), behav_full, family=binomial) lr_test(m.acc.add, m.acc.dim, "dimension rules") lr_test(m.acc.add, m.acc.dec, "decision rules") def dksort_rule_acc_effect(): pivot = pd.pivot_table(behav_full, "correct", "subj", "dec_rule") cost = np.diff(pivot, axis=1) * 100 boots = moss.bootstrap(cost) low, high = moss.ci(boots, 95) args = cost.mean(), low, high print "'Different' rule response accuracy cost: %.2g%%; CI: [%.2g%%, %.2g%%]" % args dksort_rule_acc_effect() %%R m.lag.dim = lmer(rt ~ dim_shift_lag + (dim_shift_lag | subj), behav_df, REML=FALSE) m.lag.dec = lmer(rt ~ dec_shift_lag + (dec_shift_lag | subj), behav_df, REML=FALSE) %R print(m.lag.dim, corr=FALSE) %R print(m.lag.dec, corr=FALSE) %%R m.lag.nodim = lmer(rt ~ 1 + (dim_shift_lag | subj), behav_df, REML=FALSE) m.lag.nodec = lmer(rt ~ 1 + (dec_shift_lag | subj), behav_df, REML=FALSE) lr_test(m.lag.dim, m.lag.nodim, "dimension rules") lr_test(m.lag.dec, m.lag.nodec, "decision rules") %%R shift_subset = behav_df$block_pos == 0 m.shift.add = lmer(rt ~ dim_shift + dec_shift + (dim_shift + dec_shift | subj), behav_df, subset=shift_subset, REML=FALSE) m.shift.int = lmer(rt ~ dim_shift * dec_shift + (dim_shift + dec_shift | subj), behav_df, subset=shift_subset, REML=FALSE) lr_test(m.shift.add, m.shift.int, "interaction") %R print(m.shift.add, corr=FALSE) %%R m.shift.nodim = lmer(rt ~ dec_shift + (dim_shift + dec_shift | subj), behav_df, subset=shift_subset, REML=FALSE) m.shift.nodec = lmer(rt ~ dim_shift + (dim_shift + dec_shift | subj), behav_df, subset=shift_subset, REML=FALSE) lr_test(m.shift.add, m.shift.nodim, "dimension rule") lr_test(m.shift.add, m.shift.nodec, "decision rule") def dksort_figure_2(): # Set up variables shared across subplots dfs = [behav_df, behav_full] measures = ["rt", "correct"] agg_funcs = [np.median, np.mean] ci = 68 xlabels = ["Dimension rules", "Attended features", "Trials since rule switch"] ylabels = ["Reaction time (s)", "Response accuracy"] xticks = [range(3), range(2), range(1, 7)] xticklabels = [["Shape", "Color", "Pattern"], ["Mismatch", "Match"], range(1, 7)] xlims = [[-.5, 2.5], [-.5, 1.5], [.25, 6.75]] ylims = [[.67, 1.02], [.8, 1.02]] yticks = [[.7, 1, 4], [.8, 1, 3]] ms, mew, lw = 3.5, 1.2, 1.2 err_kws = dict(linewidth=lw, mew=mew) lag_colors = dict(Dimension="black", Decision="lightslategray") text_offset = (0.01, -0.01) text_size = 11 # Draw each plot f, axes = plt.subplots(2, 3, figsize=(4.48, 3)) for i in range(2): # First analyze RT/accuracy analysis sorted by rule types pivots = [] for subj, df_subj in dfs[i].groupby("subj"): pivot = pd.pivot_table(df_subj, measures[i], "dim_rule", "dec_rule", agg_funcs[i]) pivots.append(np.array(pivot)) pivots = np.array(pivots, float) means = pivots.mean(axis=0) cis = moss.ci(moss.bootstrap(pivots, axis=0), ci, axis=0) # Plot the above results in plot columns ax = axes[i, 0] for dim in range(3): color = rule_colors["dim"][dim] for dec in range(2): x = dim + (dec - .5) / 3.5 y = means[dim, dec] ci_x = cis[:, dim, dec] mfc = ax.get_axis_bgcolor() if dec else color # Plotting calls ax.plot([x, x], ci_x, lw=lw, color=color) ax.plot(x, y, "o", ms=ms, mew=mew, mfc=mfc, mec=color) # Draw two plots out of the axis range to support the legend for j, rule in enumerate(["Same", "Different"]): # Plotting calls ax.plot(-1, -1, "o", color="#444444", mec="#444444", ms=ms, mfc=["none", "#444444"][j], mew=mew, label="'%s' rule" % rule) # Next plot the decision rule by feature matching analysis pivots = [] for subj, df_subj in dfs[i].groupby("subj"): pivot = pd.pivot_table(df_subj, measures[i], "attend_match", "dec_rule", agg_funcs[i]) pivots.append(np.array(pivot)) pivots = np.array(pivots, float) means = pivots.mean(axis=0) cis = moss.ci(moss.bootstrap(pivots, axis=0), ci, axis=0) ax = axes[i, 1] for k, rule in enumerate(["Same", "Different"]): color = rule_colors["dec"][k] x = np.array([0, 1]) # Plotting calls sns.tsplot(pivots[..., 1 - k], time=x, err_style="ci_bars", ci=ci, color=color, marker="o", mec=color, label = "'%s' rule" % rule, ms=ms, mfc=color, mew=mew, lw=lw, err_kws=err_kws, ax=ax) # Now do the analysis as a function of rule switching for j, ruleset in enumerate(["Dimension", "Decision"]): color = lag_colors[ruleset] pivot_cols = ruleset[:3].lower() + "_shift_lag" lag_pivot = pd.pivot_table(dfs[i], measures[i], "subj", pivot_cols, agg_funcs[i]) ts_kws = dict(err_style="ci_bars", color=color, ci=ci, lw=lw, err_kws=err_kws, ax=axes[i, 2], marker="o", ms=3) # Plotting calls sns.tsplot(lag_pivot.loc[:, :2].values, time=[1, 2, 3], **ts_kws) sns.tsplot(lag_pivot.loc[:, (2, 3)].values, time=[3, 4], linestyle=":", **ts_kws) sns.tsplot(lag_pivot.loc[:, 3:5].values, time=[4, 5, 6], label=ruleset + " rules", **ts_kws) # Finally, iterate through the plots and set more supporting variables for j in range(3): ax = axes[i, j] yticks_ = np.linspace(*yticks[i]) ax.set_xticks(xticks[j]) ax.set_yticks(yticks_) ax.set_xlim(*xlims[j]) ax.set_ylim(ylims[i]) if i: ax.set_xlabel(xlabels[j]) ax.set_xticklabels(xticklabels[j]) else: ax.set_xticklabels([]) if not j: ax.set_ylabel(ylabels[i]) ax.set_yticklabels(yticks_) else: ax.set_yticklabels([]) if i: ax.legend(loc="lower left") f.subplots_adjust(wspace=.08, hspace=.08, left=.1, bottom=.13, top=.97, right=.97) # Label the facets for ax, s in zip(f.axes, "ACEBDF"): (x, _), (_, y) = ax.bbox.transformed(f.transFigure.inverted()).get_points() x += text_offset[0] y += text_offset[1] f.text(x, y, s, size=text_size, ha="left", va="top") sns.despine() save_figure(f, "figure_2") dksort_figure_2() def dksort_supplemental_figure_1(): f, axes = plt.subplots(2, 1, figsize=(6.85, 6.3)) positions = dict(dim=[.25, .5, .75], dec=[.33, .66]) widths = dict(dim=.18, dec=.25) rules = dict(dim=["shape", "color", "pattern"], dec=["same", "different"]) text_offset = (.01, -.015) text_size = 12 subj_medians = subj_grouped.rt.median() subj_sorted = subjects[np.argsort(subj_medians)] for ruleset, ax in zip(["dimension", "decision"], axes): rs = ruleset[:3] for i, subj in enumerate(subj_sorted): df_i = behav_df[behav_df["subj"] == subj] for j, rule in enumerate(rules[rs]): df_rule = df_i[df_i[rs + "_rule"] == rule] pos = np.array([i + positions[rs][j]]) rule_rt = np.array(df_rule["rt"]) color = sns.desaturate(rule_colors[rs][j], .7) sns.boxplot([rule_rt], color=color, positions=pos, fliersize=2, whis=1.75, linewidth=1, widths=widths[rs], ax=ax) ax.set_xlim(0, 15) ax.set_xticks(np.arange(15) + 0.5) rule_fs = f_scores["F"][ruleset].reindex(subj_sorted) ax.set_xticklabels(["%.2g" % _f for _f in rule_fs]) ax.set_ylabel("Reaction time (s)") ax.set_xlabel("One-way F score over rules") indiv_rules = rules[ruleset[:3]] n_rules = len(indiv_rules) ax.legend(ax.artists[:n_rules], indiv_rules, loc=4, ncol=n_rules) sns.despine() f.tight_layout() f.text(.02, .96, "A", size=text_size) f.text(.02, .47, "B", size=text_size) save_figure(f, "supplemental_figure_1") dksort_supplemental_figure_1() def dksort_roi_sizes(): roi_sizes = pd.DataFrame(columns=all_rois, index=subjects, dtype=float) mask_template = op.join(data_dir, "%s/masks/yeo17_%s.nii.gz") for roi in roi_sizes: for subj in subjects: vox_count = nib.load(mask_template % (subj, roi.lower())).get_data().sum() roi_sizes.loc[subj, roi] = vox_count return roi_sizes.describe().loc[["mean", "std"]].astype(int) dksort_roi_sizes() def dksort_roi_signal(): roi_signal = pd.DataFrame(columns=all_rois, index=subjects, dtype=float) for roi in all_rois: signal = evoked.extract_group("yeo17_" + roi.lower(), dv=dv4) signal = [np.concatenate(d["data"]).mean() for d in signal] roi_signal[roi] = signal return (roi_signal / 100).describe().loc[["mean", "std"]] dksort_roi_signal() def dksort_artifact_counts(): artifacts = [] for subj in subjects: subj_artifacts = 0 for run in range(1, 5): art = pd.read_csv(op.join(anal_dir, "dksort/%s/preproc/run_%d/artifacts.csv" % (subj, run))) art = art.max(axis=1) subj_artifacts += art.sum() artifacts.append(subj_artifacts) print "Mean number of artifacts: %d" % np.mean(artifacts) print "Percent artifact scans: %.1f%%" % (np.mean(artifacts) / 2328 * 100) dksort_artifact_counts() pfc_dimension = dksort_decode(pfc_rois, "dimension") pfc_dimension["ttest"] pfc_dimension["signif"].groupby(level="correction").apply(lambda x: (x > 95).sum()) def dksort_pfc_paired_ttests(): pfc_paired_t = pd.DataFrame(index=pfc_rois, columns=pfc_rois, dtype=float) pfc_paired_p = pd.DataFrame(index=pfc_rois, columns=pfc_rois, dtype=float) for roi_i, roi_j in itertools.product(pfc_rois, pfc_rois): peak_i = pfc_dimension["peak"][roi_i] peak_j = pfc_dimension["peak"][roi_j] t, p = stats.ttest_rel(peak_i, peak_j) pfc_paired_t.loc[roi_i, roi_j] = t pfc_paired_p.loc[roi_i, roi_j] = p pfc_paired_p *= 15 sns.symmatplot(pfc_paired_t.values, pfc_paired_p.values, names=pfc_rois, cmap="RdPu_r", cmap_range=(-10, 0)) dksort_pfc_paired_ttests() def dksort_peak_decode(problem, rois, masks, comparison): rois = pd.Series(rois, name="ROI") accs = pd.DataFrame(columns=rois, index=subjects, dtype=np.float) ttest = pd.DataFrame(columns=rois, index=["t", "p"], dtype=np.float) for roi, mask in zip(rois, masks): ds = mvpa.extract_group(problem, roi, mask, frames, peak, "rt", dv=dv4) roi_accs = mvpa.decode_group(ds, model, dv=dv).squeeze() accs[roi] = roi_accs ttest[roi] = stats.ttest_rel(roi_accs, comparison) return dict(accs=accs, ttest=ttest, descrip=accs.describe()) ifs_subrois = dksort_peak_decode("dimension", other_rois, other_masks, pfc_dimension["peak"]["IFS"]) def dksort_ifs_mean_decode(): mean_ds = mvpa.extract_group("dimension", "IFS", "yeo17_ifs", frames, peak, "rt") for ds in mean_ds: ds["X"] = ds["X"].mean(axis=1).reshape(-1, 1) ds["roi_name"] = "IFS_mean" accs = mvpa.decode_group(mean_ds, model, dv=dv).squeeze() null = mvpa.classifier_permutations(mean_ds, model, n_shuffle, random_seed=shuffle_seed, dv=dv).squeeze() thresh = moss.percentiles(null, 95, axis=1) ifs_subrois["accs"]["IFS_mean"] = accs ifs_subrois["ttest"]["IFS_mean"] = stats.ttest_rel(accs, pfc_dimension["peak"]["IFS"]) return dict(accs=accs, null=null, thresh=thresh) ifs_mean = dksort_ifs_mean_decode() ifs_subrois["accs"].describe().T ifs_subrois["accs"].apply(lambda d: pd.Series(moss.ci(moss.bootstrap(d), 95), index=[2.5, 97.5]), axis=0) ifs_subrois["ttest"].T def dksort_ifs_lateralization_test(): diff = ifs_subrois["accs"]["lh-IFS"] - ifs_subrois["accs"]["rh-IFS"] low, high = moss.ci(moss.bootstrap(diff), 95) t, p = stats.ttest_1samp(diff, 0) args = (diff.mean(), low, high, t, p) print "Test for laterality: mean difference = %.3f; 95%% CI: %.3f, %.3f, t = %.2f; p = %.3g " % args dksort_ifs_lateralization_test() def dksort_ifs_rostral_test(): diff = ifs_subrois["accs"]["aIFS"] - ifs_subrois["accs"]["pIFS"] low, high = moss.ci(moss.bootstrap(diff), 95) t, p = stats.ttest_1samp(diff, 0) print "Test for rosto-caudal organization: mean difference = %.3f; " % diff.mean(), print "95%% CI: %.3f, %.3f, t = %.2f; p = %.3g " % (low, high, t, p) dksort_ifs_rostral_test() def dksort_ifs_mean_test(): low, high = moss.ci(moss.bootstrap(ifs_mean["accs"]), 95) chance = ifs_mean["null"].mean() delta = ifs_mean["accs"] - chance t, p = moss.randomize_onesample(delta) signif = ifs_mean["accs"] > ifs_mean["thresh"] print "IFS_Mean ROI:" print " mean accuracy = %.3f; 95%% CI: %.3f, %.3f;" % (ifs_mean["accs"].mean(), low, high), print "chance = %.3f; t = %.2f; p = %.3g " % (chance, t, p) print " %d/15 significant IFS_Mean models" % signif.sum() dksort_ifs_mean_test() def dksort_figure_4(): f = plt.figure(figsize=(4.48, 5.5)) ax_ts = f.add_axes([.12, .53, .86, .45]) dksort_timecourse_figure(ax_ts, pfc_dimension, (.32, .48, 5)) ax_sig = f.add_axes([.12, .08, .34, .34]) dksort_significance_figure(ax_sig, pfc_dimension["signif"]) ax_pfc_peak = f.add_axes([.59, .08, .22, .34]) dksort_point_figure(ax_pfc_peak, pfc_dimension["peak"], .33, (.32, .5, 7), True) ax_sub_peak = f.add_axes([.86, .08, .12, .34]) dksort_point_figure(ax_sub_peak, ifs_subrois["accs"].filter(regex="IFS$"), .33, (.32, .5, 7), False) f.text(.01, .97, "A", size=12) f.text(.01, .42, "B", size=12) f.text(.48, .42, "C", size=12) f.text(.825, .42, "D", size=12) sns.despine() save_figure(f, "figure_4") dksort_figure_4() pfc_cue = dksort_decode(["IFS"], "dimension_cue") def dksort_cue_test(): accs = pfc_cue["accs"].mean(axis=0) peak_tp = np.argmax(accs) max_accs = pfc_cue["accs"].values[:, peak_tp] low, high = moss.ci(moss.bootstrap(max_accs), 95) peak_time = pfc_cue["accs"].columns[peak_tp] args = max_accs.mean(), peak_time, low, high print "Max accuracy: %.3f at %ds; 95%% CI: %.3f, %.3f" % args dksort_cue_test() pfc_cue["ttest"] pfc_decision = dksort_decode(pfc_rois, "decision") pfc_decision["ttest"] pfc_decision["signif"].groupby(level="correction").apply(lambda x: (x > 95).sum()) other_decision = dksort_peak_decode("decision", other_rois, other_masks, pfc_decision["peak"]["IFS"]) def dksort_figure_5(): f = plt.figure(figsize=(4.48, 5.5)) ax_ts = f.add_axes([.12, .53, .86, .45]) dksort_timecourse_figure(ax_ts, pfc_decision, (.43, .67, 7)) ax_sig = f.add_axes([.12, .08, .34, .34]) dksort_significance_figure(ax_sig, pfc_decision["signif"]) ax_pfc_peak = f.add_axes([.59, .08, .22, .34]) dksort_point_figure(ax_pfc_peak, pfc_decision["peak"], .5, (.43, .67, 7), True) ax_sub_peak = f.add_axes([.86, .08, .12, .34]) dksort_point_figure(ax_sub_peak, other_decision["accs"].filter(regex="IFS$"), .5, (.43, .67, 7), False) f.text(.01, .97, "A", size=12) f.text(.01, .42, "B", size=12) f.text(.48, .42, "C", size=12) f.text(.825, .42, "D", size=12) sns.despine() save_figure(f, "figure_5") dksort_figure_5() net_dimension = dksort_decode(net_rois, "dimension") net_dimension["signif"].groupby(level="correction").apply(lambda x: (x > 95).sum()) net_dimension["ttest"] def dksort_calculate_ar1(): # Set up the dataframe roi_index = moss.product_index([net_rois, subjects], ["roi", "subj"]) ar1s = pd.Series(index=roi_index, name="ar1", dtype=float) shifters = [(0, 0, 1), (0, 1, 0), (1, 0, 0)] for roi in net_rois: mask_name = "yeo17_" + roi.lower() ds = mvpa.extract_group("dimension", roi, mask_name, frames, peak, "rt") coefs = mvpa.model_coefs(ds, model, flat=False) for subj, coef_s in zip(subjects, coefs): ar1_vals = [] # Average over the three underlying models for coef_c in coef_s.transpose(3, 0, 1, 2): mask = ~np.isnan(coef_c) orig = coef_c[mask] locs = np.argwhere(mask) # Average over shifts in the x, y, and z direction for shifter in shifters: x, y, z = (locs + shifter).T shifted = coef_c[x, y, z] notnan = ~np.isnan(shifted) r, p = stats.pearsonr(orig[notnan], shifted[notnan]) ar1_vals.append(r) ar1s.loc[(roi, subj)] = np.mean(ar1_vals) # Print descriptive statistics about the autocorrelation for roi in net_rois: m, sd = ar1s.loc[roi].describe()[["mean", "std"]] print "%s Spatial AR1: M = %.3g (SD = %.2g)" % (roi, m, sd) # Perform pairwise tests for each ROI combination for roi_a, roi_b in itertools.combinations(net_rois, 2): t, p = stats.ttest_rel(ar1s.loc[roi_a], ar1s.loc[roi_b]) dof = len(ar1s.loc[roi_a]) - 1 print "Test for %s vs. %s: t(%d) = %.3f; p = %.3g" % (roi_a, roi_b, dof, t, p) ar1_df = ar1s.reset_index() return ar1_df ar1_df = dksort_calculate_ar1() %%R -i ar1_df m.ar1 = lmer(ar1 ~ roi + (roi | subj), ar1_df, REML=FALSE) m.ar1.nest = lmer(ar1 ~ 1 + (roi | subj), ar1_df, REML=FALSE) print(anova(m.ar1)) lr_test(m.ar1, m.ar1.nest, "ROI effect") def dksort_upsample(): if dv is not None: dv["up_timepoints"] = up_timepoints roi_index = moss.product_index([net_rois, subjects], ["ROI", "subj"]) columns = pd.Series(up_timepoints, name="timepoints") up_accs = pd.DataFrame(index=roi_index, columns=columns, dtype=float) up_hrf_info = pd.DataFrame(index=roi_index, columns=["peak", "r2"], dtype=float) up_hrfs = {} bounds = dict(shape=(2, 8), coef=(0, None), loc=(-2.5, 2.5), scale=(0, 3)) for roi in net_rois: mask = "yeo17_" + roi.lower() ds = mvpa.extract_group("dimension", roi + "_500ms", mask, frames, upsample=4, confounds="rt", dv=dv4) roi_accs = mvpa.decode_group(ds, model, dv=dv) up_accs.loc[roi, :] = roi_accs hrfs = [moss.GammaHRF(loc=-1.5, bounds=bounds) for s in subjects] mapper = map if dv is None else dv.map_sync hrfs_ = mapper(lambda h, a: h.fit(up_timepoints, a), hrfs, roi_accs) up_hrfs[roi] = hrfs_ for subj, hrf in zip(subjects, hrfs_): up_hrf_info.loc[(roi, subj)] = hrf.peak_time_, hrf.fit_r2_ return up_hrfs, up_hrf_info, up_accs up_hrfs, up_hrf_info, up_accs = dksort_upsample() def dksort_peak_report(): for roi, info in up_hrf_info.groupby(level="ROI"): peak_time = info["peak"].mean() low, high = moss.ci(moss.bootstrap(info["peak"]), 95) print roi + "\n---" print " Mean peak: %.2fs; 95%% CI: (%.2f, %.2f)" % (peak_time, low, high) r2 = info["r2"].median() low, high = moss.ci(moss.bootstrap(info["r2"], func=np.median), 95) print " Median R^2: %.2f; 95%% CI: (%.2f, %.2f)\n" % (r2, low, high) dksort_peak_report() def up_peak_test(info, roi_a, roi_b): peaks = np.array(info["peak"].unstack(level="ROI")[[roi_a, roi_b]]) diff = np.diff(peaks, axis=1) mean_diff = np.mean(diff) diff_ci = moss.ci(moss.bootstrap(diff), 95) t, p = stats.ttest_rel(peaks[:, 0], peaks[:, 1]) print "%s - %s difference:" % (roi_a, roi_b) print " mean = %.2fs;" % mean_diff, print "95%% CI: %.2f, %.2f" % tuple(diff_ci) print " t(%d) = %.3f; p = %.3g" % (len(diff) - 1, t, p) print " %d subjects with positive difference" % (diff > 0).sum() up_peak_test(up_hrf_info, "IFS", "OTC") up_peak_test(up_hrf_info, "IPS", "OTC") up_peak_test(up_hrf_info, "IFS", "IPS") def dksort_figure_6(): f = plt.figure(figsize=(4.48, 5.5)) net_colors = [roi_colors[roi] for roi in net_rois] # Set up the axes ax_ts = f.add_axes([.12, .53, .855, .44]) ax_peak = f.add_axes([.12, .07, .16, .35]) ax_ar1 = f.add_axes([.37, .29, .61, .13]) ax_time = f.add_axes([.37, .07, .61, .13]) # Time-resolved decoding accuracy figure # Plot the original-resolution accuracy as point estimate and CI dksort_timecourse_figure(ax_ts, net_dimension, (.3, .66, 5), "ci_bars", False, {"linewidth": 2}, ms=5) # Plot the Gamma HRF modeled-accuracy as a smooth curve xx = np.linspace(-2, 10, 500) net_rois_r = list(reversed(net_rois)) fits = [[hrf.predict(xx) for hrf in up_hrfs[roi]] for roi in net_rois_r] fits = np.transpose(fits, (1, 2, 0)) sns.tsplot(fits, time=xx, condition=net_rois_r, err_style=None, linewidth=1.75, color=reversed(net_colors), ax=ax_ts) # Remove the old chance line that doesn't span the whole plot ax_ts.lines[0].remove() ax_ts.plot([-2, 10], [1 / 3, 1 / 3], ls="--", c="k") # Peak average decoding accuracy dksort_point_figure(ax_peak, net_dimension["peak"], .33, (.3, .66, 5), True) # Spatial autocorrelation of model parameters box_colors = [sns.set_hls_values(c, l=.5, s=.3) for c in net_colors] ar1_vals = [df.ar1.values for _, df in ar1_df.groupby("roi")] sns.boxplot(ar1_vals, vert=False, color=box_colors, linewidth=1, widths=.6, ax=ax_ar1) ax_ar1.set_yticklabels(net_rois) ax_ar1.set_xlim(.05, .35) ax_ar1.set_xticks([.1, .2, .3]) ax_ar1.xaxis.grid(True) ax_ar1.set_xlabel("Spatial autocorrelation of feature weights (r)") # Pairwise differences in peak decoding time time_vals = [up_hrf_info.loc[roi, "peak"] - up_hrf_info.loc["OTC", "peak"] for roi in ["IFS", "IPS"]] sns.boxplot(time_vals, vert=False, ax=ax_time, linewidth=1, widths=.5, color=box_colors[:2]) ax_time.set_xlim(-2.2, .7) ax_time.xaxis.grid(True) ax_time.set_yticklabels(net_rois[:2]) ax_time.axvline(0, ls=":", c="#444444") ax_time.set_xlabel("Time of peak decoding relative to OTC (s)") # Panel labels text_size = 12 f.text(.01, .97, "A", size=text_size) f.text(.01, .43, "B", size=text_size) f.text(.3, .43, "C", size=text_size) f.text(.3, .20, "D", size=text_size) sns.despine() save_figure(f, "figure_6") dksort_figure_6() def dksort_model_logits(): logits = dict() preds = dict() bins = np.arange(-14.5, 14.5, 1) for roi in net_rois: mask = "yeo17_" + roi.lower() ds = mvpa.extract_group("dimension", roi, mask, frames, peak, "rt", dv=dv4) logits_ = mvpa.decode_group(ds, model, logits=True, trialwise=True, dv=dv) logits[roi] = np.concatenate(logits_) preds_ = mvpa.decode_group(ds, model, trialwise=True, dv=dv) preds[roi] = np.concatenate(preds_) logit_df = behav_df[["subj"]] for roi in net_rois: logit_df[roi] = logits[roi] logit_df[roi + "_acc"] = preds[roi] logit_df[roi + "_bin"] = bins[np.digitize(logits[roi], bins)] - 1 return logit_df.reset_index(drop=True) logit_df = dksort_model_logits() %%R -i logit_df m.logits = lmer(OTC_acc ~ IFS + IPS + (IFS + IPS | subj), logit_df, family=binomial) m.logits.noifs = lmer(OTC_acc ~ IPS + (IFS + IPS | subj), logit_df, family=binomial) m.logits.noips = lmer(OTC_acc ~ IFS + (IFS + IPS | subj), logit_df, family=binomial) %%R print(m.logits, corr=FALSE) lr_test(m.logits, m.logits.noifs, "IFS effect") lr_test(m.logits, m.logits.noips, "IPS effect") %%R C = matrix(c(0, -1, 1), 1, 3) rownames(C) = "IPS-IFS" print(summary(glht(m.logits, C))) def dksort_collect_motion_info(): all_motion = [] motion_template = op.join(anal_dir, "dksort/%s/preproc/run_%d/realignment_params.csv") for subj in subjects: subj_motion = [] for run in range(1, 5): motion = pd.read_csv(motion_template % (subj, run))["displace_rel"] stim_onsets = behav_df.loc[(behav_df.subj == subj) & (behav_df.run == run), "stim_time"] stim_onsets = stim_onsets.values.astype(int) / 2 first_tr = stim_onsets + 2 second_tr = stim_onsets + 3 run_motion = np.mean([motion[first_tr], motion[second_tr]], axis=0) subj_motion.append(run_motion) subj_motion = np.concatenate(subj_motion) all_motion.append(subj_motion) logit_df["motion"] = stats.zscore(np.concatenate(all_motion)) dksort_collect_motion_info() %%R -i logit_df m.logits.motion = lmer(OTC_acc ~ IFS + IPS + motion + (IFS + IPS + motion | subj), logit_df, family=binomial) print(m.logits.motion) def dksort_figure_7(): # Fit a logistic regression for each subject and save predictions xx = np.linspace(-10, 6, 100) x_pred = sm.add_constant(xx, prepend=True) pred_rois = ["IFS", "IPS"] models = {roi: np.empty((subjects.size, xx.size)) for roi in pred_rois} xlim = [-6, 4] models = dict() for roi in pred_rois: x_fit = sm.add_constant(logit_df[roi].values, prepend=True) fit = sm.GLM(logit_df.OTC_acc.values, x_fit, family=sm.families.Binomial()).fit() models[roi] = fit.predict(x_pred) f = plt.figure(figsize=(3.34, 5.5)) axr = f.add_axes([.15, .38, .8, .58]) # Plot the data for each ROI for i, roi in enumerate(pred_rois): color = roi_colors[roi] histcolor = sns.set_hls_values(color, l=.5, s=.3) sns.tsplot(models[roi], time=xx, color=color, label=roi, linewidth=1.25, err_style=None, ax=axr) roi_pivot = pd.pivot_table(logit_df, "OTC_acc", "subj", roi + "_bin").values bins = np.sort(logit_df[roi + "_bin"].unique()) sns.tsplot(roi_pivot, time=bins, err_style="ci_bars", color=color, interpolate=False, estimator=stats.nanmean, err_kws={"linewidth": 2}, markersize=6, ax=axr) axh = f.add_axes([.15, .23 - i * .15, .8, .12]) bins = np.linspace(*xlim, num=xlim[1] - xlim[0] + 1) axh.hist(logit_df[roi], bins, rwidth=.95, color=histcolor, linewidth=0, label=roi) axh.set_yticks(np.linspace(0, 900, 3)) axh.set_ylabel("Trials") if i: axh.set_xlabel("Classifier evidence for target class") else: axh.set_xticklabels([]) #axh.legend(loc="upper left") # Set the regession axis properties axr.set_xlim(*xlim) axr.set_ylim(.25, .85) axr.set_xticklabels([]) axr.axhline(.33, c="k", ls="--") axr.set_ylabel("OTC classifier accuracy") # Label the traces and histograms axr.text(-3, .6, "IFS", color=roi_colors["IFS"], ha="center") axr.text(-1.5, .48, "IPS", color=roi_colors["IPS"], ha="center") f.axes[1].text(-4, 600, "IFS evidence", color=roi_colors["IFS"], ha="center") f.axes[2].text(-4, 600, "IPS evidence", color=roi_colors["IPS"], ha="center") # Label the panels textsize = 11 f.text(.02, .96, "A", size=textsize) f.text(.02, .35, "B", size=textsize) f.text(.02, .20, "C", size=textsize) sns.despine() save_figure(f, "figure_7") dksort_figure_7() def dksort_rule_shift(): shift_dfs = [] for roi in net_rois: mask = "yeo17_" + roi.lower() ds = mvpa.extract_group("dimension", roi, mask, frames, peak, "rt") accs = mvpa.decode_group(ds, LogisticRegression(), trialwise=True) roi_df = behav_df[["subj", "dim_shift", "dim_shift_lag"]] roi_df["roi"] = roi roi_df["acc"] = np.concatenate(accs) shift_dfs.append(roi_df) shift_df = pd.concat(shift_dfs, ignore_index=True) shift_df["IFS_logit"] = np.tile(logit_df.IFS.values, 3) shift_df["IPS_logit"] = np.tile(logit_df.IPS.values, 3) shift_df = shift_df[shift_df.dim_shift_lag < 6] shift_df["log_shift_lag"] = np.log(shift_df.dim_shift_lag + 1) return shift_df shift_df = dksort_rule_shift() %%R -i shift_df m.shift.int = lmer(acc ~ roi * log_shift_lag + (roi + log_shift_lag | subj), shift_df, family=binomial) m.shift.add = lmer(acc ~ roi + log_shift_lag + (roi + log_shift_lag | subj), shift_df, family=binomial) lr_test(m.shift.int, m.shift.add, "ROI X Lag interaction") %%R print(m.shift.add, corr=FALSE) m.shift.nolag = lmer(acc ~ roi + (roi + log_shift_lag | subj), shift_df, family=binomial) lr_test(m.shift.add, m.shift.nolag, "Lag main effect") %%R m.shift.top = lmer(acc ~ log_shift_lag + IFS_logit + IPS_logit + (log_shift_lag + IFS_logit + IPS_logit | subj), shift_df, family=binomial, subset=shift_df$roi == "OTC") print(m.shift.top, corr=FALSE) %%R m.shift.top.noifs = lmer(acc ~ IPS_logit + log_shift_lag + (log_shift_lag + IFS_logit + IPS_logit | subj), shift_df, family=binomial, subset=shift_df$roi == "OTC") m.shift.top.noips = lmer(acc ~ IFS_logit + log_shift_lag + (log_shift_lag + IFS_logit + IPS_logit | subj), shift_df, family=binomial, subset=shift_df$roi == "OTC") lr_test(m.shift.top, m.shift.top.noifs, "IFS main effect") lr_test(m.shift.top, m.shift.top.noips, "IPS main effect") %%R m.shift.top.nolag = lmer(acc ~ IFS_logit + IPS_logit + (log_shift_lag + IFS_logit + IPS_logit | subj), shift_df, family=binomial, subset=shift_df$roi == "OTC") lr_test(m.shift.top, m.shift.top.nolag, "Lag main effect") def dksort_figure_8(): f, axes = plt.subplots(1, 3, sharey=True, figsize=(4.48, 3.5)) text_pos = [.41, .48, .54] data = shift_df.groupby(["roi", "subj", "dim_shift_lag"]).acc.mean().unstack() for i, roi in enumerate(net_rois): ax = axes[i] color = roi_colors[roi] roi_data = data.loc[roi].values x = range(1, 7) xx = np.linspace(.7, 6.3, 100) sns.tsplot(roi_data, time=x, err_style="ci_bars", interpolate=False, color=color, ms=5, mec=color, err_kws=dict(linewidth=1.5), ax=ax) y = roi_data.mean(axis=0) fit = np.polyfit(np.log(x), y, 1) r2 = r2_score(y, np.polyval(fit, np.log(x))) yy = np.polyval(fit, np.log(xx)) a, b = 3.1, 3.9 block_break = (xx >=a) & (xx <= b) lw=1.25 ax.plot(xx[xx < a], yy[xx < a], lw=lw, color=color, label=roi) ax.plot(xx[block_break], yy[block_break], lw=lw, ls=":", color=color) ax.plot(xx[xx > b], yy[xx > b], lw=lw, color=color) ax.text(1.6, text_pos[i], "$R^2 =\/%.2f$" % r2, size=8) ax.text(2.8, yy[30] + .03, roi, size=9, color=color) ax.xaxis.grid(False) ax.set_xlim(.7, 6.3) ax.set_ylim(.3, .7) ax.set_yticks(np.linspace(.3, .7, 5)) ax.axhline(.33, ls="--", c="#444444") if i == 1: ax.set_xlabel("Trials since rule switch") if not i: ax.set_ylabel("Cross-validated decoding accuracy") f.subplots_adjust(wspace=.05, bottom=.12, top=.97, left=.1, right=.98) sns.despine() save_figure(f, "figure_8") dksort_figure_8() def make_dataset(offset): a = np.random.randn(50, 3) + [-offset, 0, offset * .5] b = np.random.randn(50, 3) + [offset, 0, -offset * .5] return a.T, b.T def plot_plane(ax, normal, color="#666666", grid=False): point = np.array([0, 0, 0]) d = -point.dot(normal) steps = np.linspace(-2.5, 2.5, 5) if grid else [-2.5, 2.5] x, y = np.meshgrid(steps, steps) a, b, c = normal z = (-a * x - b * y - d) / c ax.plot_surface(x, y, z, alpha=.25, shade=False, color=color, rstride=20, cstride=20) ax.plot_wireframe(x, y, z, color="#555555", linewidths=.15) return ax def dksort_figure_9(): o, g, p = rule_colors["dim"] f = plt.figure(figsize=(6.85, 2)) for i, offset in enumerate([.5, 1, 2], 1): ax = f.add_subplot(1, 3, i, projection="3d", axisbg="white") a, b = make_dataset(offset) ax.scatter3D(*a, color=g, s=8) ax.scatter3D(*b, color=p, s=8) plot_plane(ax, [-1, 0, .5], "#999989", True) for axis in ["x", "y", "z"]: getattr(ax, "set_%slim" % axis)(-4, 4) getattr(ax, "set_%sticks" % axis)([]) getattr(ax, "set_%sticklabels" % axis)([]) ax.set_title(["Low", "Moderate", "High"][i - 1] + " control demands", size=8) plt.tight_layout() save_figure(f, "figure_9") dksort_figure_9()