%matplotlib inline import matplotlib.pyplot as plt import numpy as np import pandas as pd import pymc as pm import imp pm3 = imp.load_module("pymc3", *imp.find_module("pymc", ["/Users/fonnescj/GitHub/pymc"])) pm.__version__ import qgrid qgrid.nbinstall() pm3.__version__ hospitalized = pd.read_csv('data/hospitalized.csv', index_col=0) hospitalized.child_birth_date = pd.to_datetime(hospitalized.child_birth_date) hospitalized.enrollment_date = pd.to_datetime(hospitalized.enrollment_date) hospitalized.admission_date = pd.to_datetime(hospitalized.admission_date) hospitalized.discharge_date = pd.to_datetime(hospitalized.discharge_date) hospitalized["prev_cond"] = hospitalized[[c for c in hospitalized.columns if c.endswith('hx') and not c.startswith('no_')]].sum(1) hospitalized['year_admission'] = hospitalized.admission_date.apply(lambda x: x.year) nationality_lookup = {1: 'Jordanian', 2: 'Egyptian', 3: 'Palestinian', 4: 'Iraqi', 5: 'Syrian', 6: 'Sudanese', 7: 'Russian', 8: 'Asian', 9: 'Other'} hospitalized['nationality'] = hospitalized.mother_nationality.replace(nationality_lookup) hospitalized['Jordanian'] = (hospitalized.nationality=='Jordanian').astype(int) hospitalized['Palestinian'] = (hospitalized.nationality=='Palestinian').astype(int) hospitalized['vitamin D < 20'] = (hospitalized.hospitalized_vitamin_d < 20).astype(int) hospitalized['vitamin D < 20'][hospitalized.hospitalized_vitamin_d.isnull()] = np.nan hospitalized['vitamin D < 11'] = (hospitalized.hospitalized_vitamin_d < 11).astype(int) hospitalized['vitamin D < 11'][hospitalized.hospitalized_vitamin_d.isnull()] = np.nan hospitalized['premature'] = (hospitalized.gest_age < 37).astype(int) hospitalized.shape hospitalized['respiratory_class'] = 0 hospitalized.loc[(hospitalized.respiratory_rate>30) & (hospitalized.respiratory_rate<=45), 'respiratory_class'] = 1 hospitalized.loc[(hospitalized.respiratory_rate>45) & (hospitalized.respiratory_rate<=60), 'respiratory_class'] = 2 
# Rates above 60 complete the four-level respiratory-rate classification.
hospitalized.loc[hospitalized.respiratory_rate > 60, 'respiratory_class'] = 3
hospitalized.respiratory_class.value_counts()

# Oxygen-saturation class: 1 for [90, 95), 2 for [85, 90), 3 below 85, and 0
# for everything else (including missing readings, because comparisons with
# NaN are False).
hospitalized['sats_class'] = 0
_sats = hospitalized.sats_number
for _low, _high, _level in ((90, 95, 1), (85, 90, 2)):
    hospitalized.loc[(_sats >= _low) & (_sats < _high), 'sats_class'] = _level
hospitalized.loc[_sats < 85, 'sats_class'] = 3
hospitalized.sats_class.value_counts()

# The free-text saturation ranges are spelled inconsistently; map each
# observed spelling onto the same 0-3 scale. '>60' is mapped to missing.
hospitalized.sats_range.unique()
hospitalized['sats_score'] = hospitalized.sats_range.replace({
    # score 0
    '>95': 0, '>95%': 0, '95-98': 0,
    # score 1
    '90-94': 1, '(90-94)%': 1, '90-95': 1, '92-95': 1, '93-95': 1,
    # score 2 -- NOTE(review): '85-59' and '85-98' look like data-entry typos
    '85-89': 2, '(85-89)%': 2, '85-59': 2, '85-98': 2,
    # score 3
    '<85': 3, '<85%': 3,
    # not mapped to a score
    '>60': np.nan,
})
pd.crosstab(hospitalized.sats_score, hospitalized.oxygen)

# Calendar parts of the admission date.
for _column, _part in (('admission_month', 'month'),
                       ('admission_week', 'week'),
                       ('admission_year', 'year')):
    hospitalized[_column] = hospitalized.admission_date.apply(lambda ts, part=_part: getattr(ts, part))

hospitalized['male'] = hospitalized.sex_child

# Redefine 'premature' as how far gestational age falls short of 37
# (presumably weeks); 0 at or beyond 37, missing when gest_age is unknown.
_weeks_early = 37 - hospitalized.gest_age
hospitalized['premature'] = np.maximum(0, _weeks_early)
hospitalized.loc[hospitalized.gest_age.isnull(), 'premature'] = np.nan
hospitalized.premature.hist(bins=20, grid=False)

# Create indicator for LRTI: 1 when the summed admission diagnoses/signs below
# exceed 1 (assumes the adm_* fields are 0/1 flags -- confirm).
_lrti_count = (hospitalized.adm_bronchopneumo
               + hospitalized.adm_bronchiolitis
               + hospitalized.adm_pneumo
               + hospitalized.adm_wheezing
               + hospitalized.adm_asthma
               + (hospitalized.wheezing > 0)
               + (hospitalized.flaring > 1))
hospitalized['ltri'] = (_lrti_count > 1).astype(int)

# PCR panel result columns and the pathogen each one reports.
pcr_lookup = {
    'pcr_result___1': 'RSV',
    'pcr_result___2': 'HMPV',
    'pcr_result___3': 'flu A',
    'pcr_result___4': 'flu B',
    'pcr_result___5': 'rhino',
    'pcr_result___6': 'PIV1',
    'pcr_result___7': 'PIV2',
    'pcr_result___8': 'PIV3',
    'pcr_result___13': 'H1N1',
    'pcr_result___14': 'H3N2',
    'pcr_result___15': 'Swine',
    'pcr_result___16': 'Swine H1',
    'pcr_result___17': 'flu C',
    'pcr_result___18': 'Adeno',
}
virus_vars = ['Influenza', 'HMPV', 'Rhino']
# Virus indicators taken from the PCR panel columns.
hospitalized['RSV'] = hospitalized['pcr_result___1']
# Influenza = flu A or flu B.
hospitalized['Influenza'] = (hospitalized['pcr_result___3'] | hospitalized['pcr_result___4']).astype(int)
hospitalized['HMPV'] = hospitalized['pcr_result___2']
hospitalized['Rhino'] = hospitalized['pcr_result___5']

# Vitamin D centred on 20 and scaled by its standard deviation, so negative
# values fall below the 20 cut-off used for the 'vitamin D < 20' indicator.
hospitalized['vitamin_d_norm'] = ((hospitalized.hospitalized_vitamin_d - 20)
                                  / hospitalized.hospitalized_vitamin_d.std())
hospitalized['age_centered'] = hospitalized.age_months - hospitalized.age_months.mean()

# Season of enrollment; months 12, 1 and 2 have all three flags False.
hospitalized['enroll_month'] = [d.month for d in hospitalized.enrollment_date]
hospitalized['enroll_spring'] = hospitalized.enroll_month.isin((3, 4, 5))
hospitalized['enroll_summer'] = hospitalized.enroll_month.isin((6, 7, 8))
hospitalized['enroll_autumn'] = hospitalized.enroll_month.isin((9, 10, 11))

# Interaction terms with (normalized) vitamin D.
hospitalized['age_X_vitamin_d'] = hospitalized.vitamin_d_norm * hospitalized.age_centered
hospitalized['male_X_vitamin_d'] = hospitalized.vitamin_d_norm * hospitalized.male

hospitalized.hosp_vitd.replace({'<1': 0}).astype(float).notnull().sum()

# Analysis set: rows with no 'qns' flag and a recorded vitamin D result.
analysis_subset = hospitalized[hospitalized.qns.isnull() & hospitalized.hosp_vitd.notnull()]
analysis_subset.shape

# RSV-positive children. Take an explicit copy: columns are added to this
# frame below, and assigning into a boolean-indexed slice is ambiguous
# (SettingWithCopyWarning).
rsv_subset = analysis_subset[analysis_subset.pcr_result___1 == 1].copy()
rsv_subset.shape
rsv_subset.heart_hx.mean()
rsv_subset.daycare.mean()
rsv_subset.cigarette_preg.mean()

# Co-infection: positive for any panel pathogen other than RSV
# (pcr_result___1). Derived from pcr_lookup so the two lists cannot drift.
other_virus_cols = [c for c in pcr_lookup if c != 'pcr_result___1']
rsv_subset['viral_coinfection'] = (rsv_subset[other_virus_cols].sum(1) > 0).astype(int)

# Model covariates; commented entries are alternatives that were tried.
covnames = ["vitamin_d_norm",
            "prev_cond",
            "cigarette_smokers",
            "male",
            "z_score",
            # "ltri",
            "premature",
            # "viral_coinfection",
            # "enroll_spring",
            # "enroll_summer",
            # "enroll_autumn",
            # "admission_week",
            "age_X_vitamin_d",
            "male_X_vitamin_d",
            "age_centered"]
covariates = rsv_subset[covnames]
covariates.isnull().sum()

severity = rsv_subset[['oxygen', 'sats_score']]
outcome = 'oxygen'
# Rows with every covariate present and a recorded outcome.
complete = covariates.notnull().all(axis=1) & severity[outcome].notnull()
covariates_complete = covariates[complete]
y_complete = severity[outcome][complete]

# Modelling frame: covariates followed by the outcome as the LAST column.
variables = pd.concat([covariates, severity['oxygen']], axis=1)
variables.isnull().sum()
variables = variables.dropna()
assert not variables.isnull().sum().any()
qgrid.show_grid(variables, remote_js=True)

# Predictor names: every column except the trailing 'oxygen' outcome.
# (A stray `rsv_vars['vent_or_icu'] = ...` line used to follow here; it indexed
# this NumPy string array with a string and referenced rsv_no_hx before it
# exists, so it always raised.)
rsv_vars = variables.columns[:-1].values

glm = pm3.glm.glm

# Bayesian logistic regression of oxygen use on all predictors.
with pm3.Model() as severity_model:
    formula = 'oxygen ~ ' + '+'.join(rsv_vars)
    glm(formula, variables, family=pm3.glm.families.Binomial())
    severity_trace = pm3.sample(2000, pm3.NUTS())

pm3.forestplot(severity_trace, vars=rsv_vars)

from pandas.tools.plotting import scatter_matrix

trace_df = pm3.trace_to_dataframe(severity_trace)
print(trace_df.describe().drop('count').T)
print("\nP(Vitamin D < 0) = {}".format((trace_df['vitamin_d_norm'] < 0).mean()))
# Exponentiated coefficients (odds ratios), excluding the intercept.
trace_df.drop('Intercept', axis=1).apply(np.exp).describe().drop('count').T

# NOTE: the next lines add helper columns ('O2', 'vitd_bins') to `variables`,
# so after this point variables.columns[:-1] is no longer the predictor list;
# use rsv_vars instead.
# Jitter the binary outcome slightly so overlapping points are visible.
variables['O2'] = variables.oxygen + np.random.normal(size=len(variables.oxygen)) * 0.01
variables.plot(x='vitamin_d_norm', y='O2', kind='scatter', alpha=0.2, yticks=[0, 1])

# Proportion on oxygen (left axis) and bin size (right axis) across 12
# vitamin D bins.
variables['vitd_bins'] = pd.cut(variables.vitamin_d_norm, bins=12)
ax = variables.groupby('vitd_bins')['oxygen'].mean().plot(grid=False)
variables.groupby('vitd_bins')['oxygen'].count().plot(secondary_y=True, style='k--')
ax.set_ylabel('Proportion on O2')
ax.right_ax.set_ylabel('Number in bin')

(hospitalized.hospitalized_vitamin_d > 30).mean()
variables.vitamin_d_norm.hist(bins=25)
variables[rsv_vars[1:]].head()
rsv_vars[1:]


def changepoint_model():
    """Build a PyMC 2 logistic model with a change point in vitamin D.

    The logit-probability of oxygen use is
    ``mu + alpha * [vitamin_d_norm < change] + X.dot(beta)``, where ``X`` holds
    the confounders listed below. Returns ``locals()`` so that every node can
    be handed to ``pm.MCMC``.
    """
    X = variables[['prev_cond', 'cigarette_smokers', 'male', 'z_score', 'premature']].values
    change = pm.Uniform('change', variables.vitamin_d_norm.min(), 1.5, value=-0.5)
    # Low vitamin D effect (PyMC 2 Normal takes a precision, so 0.001 is vague).
    alpha = pm.Normal('alpha', 0, 0.001, value=0)
    mu = pm.Normal('mu', 0, 0.001, value=0)
    # Covariates for confounders
    beta = pm.Normal('beta', 0, 0.001, value=[0] * X.shape[1])

    @pm.deterministic
    def theta(b=beta, a=alpha, c=change, m=mu):
        return pm.invlogit(m + a * (variables.vitamin_d_norm.values < c) + X.dot(b))

    y = pm.Bernoulli('y', p=theta, value=variables.oxygen.values, observed=True)
    return locals()


M_change = pm.MCMC(changepoint_model())
M_change.sample(20000, 10000)
pm.Matplot.plot(M_change.change)
M_change.change.summary()
pm.Matplot.plot(M_change.alpha)
M_change.alpha.summary()

from pymc.gp import Covariance, Mean, GPSubmodel
from pymc.gp.cov_funs import matern

# Grid over the observed vitamin D range on which the GP effect is represented.
vitD_mesh = np.linspace(variables.vitamin_d_norm.min(), variables.vitamin_d_norm.max())

gp_covs = ["prev_cond", "cigarette_smokers", "male", "z_score", "premature", "age_centered"]


def gp_model():
    """Build a PyMC 2 logistic model with a Gaussian-process vitamin D effect.

    The logit-probability of oxygen use is ``mu + f(vitamin_d_norm) + X.dot(beta)``
    where ``f`` has a zero-mean GP prior with a Matern covariance and ``X``
    holds the ``gp_covs`` confounders. Returns ``locals()`` for ``pm.MCMC``.
    """
    X = variables[gp_covs].values
    # Mean risk
    mu = pm.Normal('mu', 0, 0.001, value=0)
    # Covariates for confounders
    beta = pm.Normal('beta', 0, 0.001, value=[0] * X.shape[1])
    # GP hyperpriors
    amp = pm.Exponential('amp', 1, value=1)
    scale = pm.Uniform('scale', 0, 10, value=1)
    diff_degree = pm.Uniform('diff_degree', 0, 10, value=1)

    @pm.deterministic
    def C(diff_degree=diff_degree, amp=amp, scale=scale):
        """ The Matern covariance function """
        return Covariance(matern.euclidean, diff_degree=diff_degree, amp=amp, scale=scale)

    @pm.deterministic
    def M():
        """ The mean function is the zero function """
        return Mean(lambda x: np.zeros(x.shape))

    alpha = GPSubmodel('alpha', M, C, mesh=vitD_mesh)

    @pm.deterministic
    def theta(b=beta, a=alpha.f(variables.vitamin_d_norm.values), m=mu):
        return pm.invlogit(m + a + X.dot(b))

    y = pm.Bernoulli('y', p=theta, value=variables.oxygen.values, observed=True)
    return locals()


M_gp = pm.MCMC(gp_model())
M_gp.sample(20000, 10000)

# Overlay 100 posterior draws of the GP effect, chosen at random from the last
# 1000 retained samples.
plt.figure(figsize=(18, 5))
ax = plt.subplot(1, 2, 1)
for i in range(0, 100):
    f = np.random.choice(M_gp.alpha.f.trace()[-1000:])(vitD_mesh)
    plt.plot(vitD_mesh, f, 'k-', linewidth=0.5, alpha=0.5)
plt.xlabel('Normalized Vitamin D')
plt.ylabel('Vitamin D effect on oxygen')

import theano.tensor as TT


def tinvlogit(x):
    """Inverse logit (logistic sigmoid) of a Theano tensor.

    Uses Theano's sigmoid op instead of exp(x) / (1 + exp(x)); the latter
    evaluates to inf/inf = NaN once exp(x) overflows for large x.
    """
    return TT.nnet.sigmoid(x)


# Lasso-style logistic regression: Laplace priors on the coefficients with a
# shared, estimated scale.
with pm3.Model() as severity_lasso_model:
    X = variables[rsv_vars].values
    alpha = pm3.Exponential('alpha', 1)
    beta = pm3.Laplace('beta', 0, alpha, shape=X.shape[1])
    mu = pm3.Normal('mu', 0, 0.01)
    p = tinvlogit(mu + beta.dot(X.T))
    y = pm3.Bernoulli('y', p, observed=variables.oxygen)

with severity_lasso_model:
    trace_lasso = pm3.sample(2000, (pm3.NUTS(vars=[beta], scaling=pm3.find_MAP(vars=[beta])),
                                    pm3.Slice(vars=[alpha, mu])))

pm3.forestplot(trace_lasso[1000:], vars=['beta'], ylabels=list(variables[rsv_vars].columns))
_ = pm3.traceplot(trace_lasso[1000:], vars=['alpha'])

# Same lasso model expressed through the GLM interface. Use rsv_vars for the
# predictor names: `variables` now also carries the outcome and the helper
# columns 'O2'/'vitd_bins', so variables.columns[:-1] would include 'oxygen'
# and 'O2', which are not model coefficients.
with pm3.Model() as severity_lasso_model:
    b = pm3.Exponential('b', 1)
    # Define priors for intercept and regression coefficients.
    priors = {v: pm3.Laplace.dist(mu=0, b=b) for v in rsv_vars}
    priors['Intercept'] = pm3.Normal.dist(mu=0, sd=50)
    formula = 'oxygen ~ ' + ' + '.join(rsv_vars)
    glm(formula, variables, family=pm3.glm.families.Binomial(), priors=priors)
    trace_lasso = pm3.sample(2000, pm3.NUTS())

pm3.forestplot(trace_lasso, vars=rsv_vars)
_ = pm3.traceplot(trace_lasso[1000:], vars=['b'])

trace_df = pm3.trace_to_dataframe(trace_lasso)
# scatter_matrix(trace_df, figsize=(8, 8));
print(trace_df.apply(np.exp).describe().drop('count').T)

# NOTE: this binds the same object as rsv_subset (no copy), so the columns
# added to rsv_no_hx below also appear on rsv_subset.
# Earlier restriction, currently disabled:
#   rsv_subset[rsv_subset.no_hx.astype(bool) & rsv_subset.oxygen.notnull()]
rsv_no_hx = rsv_subset
rsv_no_hx.shape
rsv_vars
rsv_no_hx.icu.mean()
rsv_no_hx.vent.mean()
rsv_no_hx.death.mean()
rsv_no_hx.oxygen.mean()

# Composite severity outcome: ventilated or admitted to ICU.
rsv_no_hx['vent_or_icu'] = ((rsv_no_hx.vent + rsv_no_hx.icu) > 0).astype(int)
rsv_no_hx['premature_ind'] = (rsv_no_hx.premature > 0).astype(int)
rsv_no_hx.adm_pneumo.mean()

no_hx_vars = ['vitamin_d_norm', 'breastfed', 'male', 'age_X_vitamin_d', 'z_score',
              'age_centered', 'premature_ind', 'adm_pneumo']
rsv_no_hx[no_hx_vars].head()
rsv_no_hx.premature.value_counts()
rsv_no_hx = rsv_no_hx.dropna(subset=['oxygen'] + no_hx_vars)
rsv_no_hx[no_hx_vars].isnull().sum(0)
rsv_no_hx.oxygen.isnull().sum()

glm = pm3.glm.glm

# Logistic regression of oxygen use on the no-history covariates.
with pm3.Model() as oxygen_no_hx:
    formula = 'oxygen ~ ' + '+'.join(no_hx_vars)
    glm(formula, rsv_no_hx, family=pm3.glm.families.Binomial())
    oxygen_trace = pm3.sample(2000, pm3.NUTS())

pm3.forestplot(oxygen_trace, vars=no_hx_vars)

# Same covariates, composite ventilator-or-ICU outcome.
with pm3.Model() as severity_no_hx:
    formula = 'vent_or_icu ~ ' + '+'.join(no_hx_vars)
    glm(formula, rsv_no_hx, family=pm3.glm.families.Binomial())
    severity_trace = pm3.sample(2000, pm3.NUTS())

pm3.forestplot(severity_trace, vars=no_hx_vars)

analysis_subset.shape

# Distribution of vitamin D in the analysis set.
import seaborn as sns
plt.figure(figsize=(10, 6))
sns.set(style="white", palette="deep")
sns.distplot(analysis_subset.hospitalized_vitamin_d.dropna(), kde=False, color="steelblue",
             axlabel='Hospitalized VitaminD (ng/ml)')
sns.despine(trim=True)

# A histogram bin count must be an integer; sqrt(1000) ~= 31.6 is truncated
# to 31 bins.
analysis_subset.hospitalized_vitamin_d.hist(bins=int(np.sqrt(1000)), normed=True, grid=False,
                                            color='grey', figsize=(10, 6))
plt.xlabel('Vitamin D (ng/ml)'); plt.ylabel('Relative Frequency')

# Household smoking exposure.
analysis_subset.cigarette_smokers.hist()
(analysis_subset.cigarette_smokers > 0).mean()
(analysis_subset.cigarette_smokers > 0).sum()
analysis_subset.nargila_smokers.hist()
(analysis_subset.nargila_smokers > 0).mean()
(analysis_subset.nargila_smokers > 0).sum()
analysis_subset.cigarette_preg.mean()
analysis_subset.nargila_preg.mean()

# Relabel for display (item assignment rather than attribute assignment).
analysis_subset['breastfed'] = analysis_subset.breastfed.replace({0: 'Not breastfed', 1: 'Breastfed'})
_ = analysis_subset.groupby(['breastfed']).boxplot(column='hospitalized_vitamin_d', grid=False)
_ = analysis_subset.groupby('oxygen').boxplot(column='rsv_count', grid=False)

# Seasonality by calendar month: proportion on oxygen (left) vs mean vitamin D
# (right).
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
groupby_month = analysis_subset.groupby('admission_month')
ax = groupby_month['oxygen'].mean().plot(grid=False, color='r', figsize=(14, 8))
groupby_month['hospitalized_vitamin_d'].mean().plot(secondary_y=True, color='b')
ax.set_ylabel('Proportion on oxygen', color='r')
ax.right_ax.set_ylabel('Mean vitamin D', color='b')
ax.set_xlabel('Admission month')
ax.set_xticks(np.arange(12) + 1)
ax.set_xticklabels(months);

# Monthly ('M') time series despite the "_by_week" name. Months without
# admissions are filled with 0.
# NOTE(review): 0 is not a real mean vitamin D level or a real proportion;
# consider leaving those months as NaN.
vitD_O2 = analysis_subset[['oxygen', 'hospitalized_vitamin_d']].set_index(analysis_subset.admission_date)
vitD_O2_by_week = vitD_O2.resample('M', how='mean').fillna(0)
ax = vitD_O2_by_week['oxygen'].plot(grid=False, color='r')
vitD_O2_by_week['hospitalized_vitamin_d'].plot(secondary_y=True, color='b')
ax.set_ylabel('Proportion on oxygen', color='r')
ax.right_ax.set_ylabel('Mean vitamin D', color='blue')

# Same series plus monthly hospitalization counts on a third, offset y-axis.
from mpl_toolkits.axes_grid1 import host_subplot
import mpl_toolkits.axisartist as AA
import datetime

plt.figure(figsize=(12, 6))
host = host_subplot(111, axes_class=AA.Axes)
plt.subplots_adjust(right=0.75)
par1 = host.twinx()
par2 = host.twinx()
# Shift par2's right spine outward so it does not overlap par1's.
offset = 60
new_fixed_axis = par2.get_grid_helper().new_fixed_axis
par2.axis["right"] = new_fixed_axis(loc="right", axes=par2, offset=(offset, 0))
par2.axis["right"].toggle(all=True)

# Reuses vitD_O2 / vitD_O2_by_week computed above; the index converts directly
# to naive datetime objects for matplotlib.
dates = vitD_O2_by_week.index.to_pydatetime()
host.plot(dates, vitD_O2_by_week['oxygen'], 'k-', label='Oxygen')
par1.plot(dates, vitD_O2_by_week['hospitalized_vitamin_d'], 'k--', label='Vitamin D')
par2.plot(dates, vitD_O2['hospitalized_vitamin_d'].resample('M', how='count'),
          linestyle='dotted', color='grey', label='RSV')
host.set_xlabel("Month")
host.set_ylabel("Proportion on oxygen")
par1.set_ylabel("Mean vitamin D")
par2.set_ylabel("RSV hospitalizations")
host.legend(loc='lower right');
analysis_subset.oxygen.notnull().sum()


def se(p, n):
    """Standard error of a proportion ``p`` estimated from ``n`` observations."""
    return np.sqrt(p * (1. - p) / n)


def odds_ratio(x, y, n_sim=10000, alpha=0.05):
    """Simulation-based odds ratio of binary sample ``x`` versus ``y``.

    Each proportion is drawn ``n_sim`` times from a normal approximation
    (mean = sample proportion, sd = ``se``), and the odds ratio is computed
    per draw.

    Returns ``(median OR, [lower, upper] (1 - alpha) percentile interval,
    (n_y, n_x))``, with the numbers rounded to 2 decimals. Returns
    ``(nan, nan, nan)`` if the simulation raises ValueError.
    """
    try:
        n_x, n_y = len(x.dropna()), len(y.dropna())
        p_x, p_y = x.mean(), y.mean()
        se_x = se(p_x, n_x)
        se_y = se(p_y, n_y)
        p_x_sim = np.random.normal(p_x, se_x, n_sim)
        p_y_sim = np.random.normal(p_y, se_y, n_sim)
        ratio = ((p_x_sim / (1. - p_x_sim)) / (p_y_sim / (1. - p_y_sim)))
        interval = np.percentile(ratio, [100 * (alpha / 2.), 100 * (1. - alpha / 2.)])
        return np.round(np.median(ratio), 2), np.round(interval, 2).tolist(), (n_y, n_x)
    except ValueError:
        return np.nan, np.nan, np.nan


def calc_or(groupby, var):
    """Odds ratio of ``var`` for the second group versus the first group.

    Groups are taken in groupby order (sorted keys by default), so for a 0/1
    grouping this compares group 1 against group 0. Returns ``(nan, nan, nan)``
    when fewer than two groups are present.
    """
    data = list(groupby[var])
    if len(data) < 2:
        return np.nan, np.nan, np.nan
    return odds_ratio(data[1][1], data[0][1])


def make_table(groupby, table_vars, replace_dict=None):
    """Summary table of group means with odds ratios.

    Rows are ``table_vars``. The columns are the per-group means (rounded to 2
    decimals), followed by 'OR', 'Interval' and 'N' from ``calc_or``. Group
    columns are renamed with ``replace_dict`` (default: no renaming).
    """
    if replace_dict is None:
        replace_dict = {}
    table = np.round(groupby[table_vars].mean(), 2).T
    ratios = [calc_or(groupby, v) for v in table.index]
    table['OR'] = [r[0] for r in ratios]
    table['Interval'] = [r[1] for r in ratios]
    table['N'] = [r[2] for r in ratios]
    table.rename(columns=replace_dict, inplace=True)
    table.columns.name = None
    return table


table_vars = ['male', 'under_2_months', 'months_2_11', 'months_12_23', 'Jordanian', 'Palestinian',
              'vitamin D < 20', 'vitamin D < 11', 'prev_cond', 'heart_hx', 'breastfed',
              'premature_ind', 'adm_pneumo', 'adm_bronchopneumo', 'adm_sepsis', 'adm_bronchiolitis']

# Age-group dummies. pd.cut intervals are right-closed and by default exclude
# the lowest edge, which would leave age_months == 0 in no group;
# include_lowest puts 0 in the first ('under_2_months') bin.
age_groups = pd.get_dummies(pd.cut(rsv_subset.age_months, [0, 1, 11, 23], include_lowest=True))
age_groups.index = rsv_subset.index
age_groups.columns = 'under_2_months', 'months_2_11', 'months_12_23'
rsv_subset = rsv_subset.join(age_groups)

# Recompute the derived indicators on rsv_subset (nationality_lookup is defined
# at the top of the notebook).
rsv_subset['nationality'] = rsv_subset.mother_nationality.replace(nationality_lookup)
rsv_subset['Jordanian'] = (rsv_subset.nationality == 'Jordanian').astype(int)
rsv_subset['Palestinian'] = (rsv_subset.nationality == 'Palestinian').astype(int)
rsv_subset['vitamin D < 20'] = (rsv_subset.hospitalized_vitamin_d < 20).astype(int)
rsv_subset.loc[rsv_subset.hospitalized_vitamin_d.isnull(), 'vitamin D < 20'] = np.nan
rsv_subset['vitamin D < 11'] = (rsv_subset.hospitalized_vitamin_d < 11).astype(int)
rsv_subset.loc[rsv_subset.hospitalized_vitamin_d.isnull(), 'vitamin D < 11'] = np.nan
rsv_subset['premature_ind'] = (rsv_subset.premature > 0).astype(int)

groupby_o2 = rsv_subset.groupby('oxygen')
make_table(groupby_o2, table_vars=table_vars + virus_vars,
           replace_dict={0.0: 'No Oxygen', 1.0: 'Oxygen'})

groupby_vent = rsv_subset.groupby('vent')
make_table(groupby_vent, table_vars=table_vars + virus_vars,
           replace_dict={0.0: 'No Ventilator', 1.0: 'Ventilator'})

# This table groups by ICU admission, so label its columns as ICU groups (they
# were previously labelled 'Survived'/'Die', which does not match 'icu').
groupby_icu = rsv_subset.groupby('icu')
groupby_death = groupby_icu  # backward-compatible alias for the old name
make_table(groupby_icu, table_vars=table_vars + virus_vars,
           replace_dict={0.0: 'No ICU', 1.0: 'ICU'})