Created 7/9/2014 by KO
Implements propensity-score matching; balance diagnostics will be added later.
%matplotlib inline
import math
import numpy as np
import scipy
from scipy.stats import binom, hypergeom
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
Goal: find the average treatment effect in the treatment group (ATT) on RE78.
Import the data: controls and treated from Lalonde/Dehejia papers. Here's what the site says about the data:
The variables from left to right are: treatment indicator (1 if treated, 0 if not treated), age, education, Black (1 if black, 0 otherwise), Hispanic (1 if Hispanic, 0 otherwise), married (1 if married, 0 otherwise), nodegree (1 if no degree, 0 otherwise), RE74 (earnings in 1974), RE75 (earnings in 1975), and RE78 (earnings in 1978).
# Column names for the NSW data files, in the order the columns appear.
names = ['Treated', 'Age', 'Education', 'Black', 'Hispanic', 'Married',
         'Nodegree', 'RE74', 'RE75', 'RE78']
# Both files are whitespace-delimited with no header row.  The separator is
# a raw string so that `\s` reaches the regex engine untouched instead of
# being parsed as an (invalid) Python string escape.
treated = pd.read_table('nswre74_treated.txt', sep=r'\s+',
                        header=None, names=names)
control = pd.read_table('nswre74_control.txt', sep=r'\s+',
                        header=None, names=names)
# Stack treated rows on top of control rows.  Each frame keeps its own row
# labels, so the same label can occur once in each group.
data = pd.concat([treated, control])
data.head()
Treated | Age | Education | Black | Hispanic | Married | Nodegree | RE74 | RE75 | RE78 | |
---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 37 | 11 | 1 | 0 | 1 | 1 | 0 | 0 | 9930.0460 |
1 | 1 | 22 | 9 | 0 | 1 | 0 | 1 | 0 | 0 | 3595.8940 |
2 | 1 | 30 | 12 | 1 | 0 | 0 | 0 | 0 | 0 | 24909.4500 |
3 | 1 | 27 | 11 | 1 | 0 | 0 | 1 | 0 | 0 | 7506.1460 |
4 | 1 | 33 | 8 | 1 | 0 | 0 | 1 | 0 | 0 | 289.7899 |
Compute propensity scores to start. Then we need to separate the treated and controls again (preserve original indexing) in order to match them.
Note, this section might need some fine-tuning to make it match Dehejia and Wahba (see their appendix for how they computed propensity scores)
# Fit a logistic regression of the treatment indicator on every column
# between the indicator and the outcome (names[1:-1]).
propensity = LogisticRegression()
propensity = propensity.fit(data[names[1:-1]], data.Treated)
# Column 1 of predict_proba is the probability of the model's second class
# (Treated == 1 when the labels are 0/1).
pscore = propensity.predict_proba(data[names[1:-1]])[:, 1]
# Parenthesised form prints the same text on Python 2 and also runs on Python 3.
print(pscore[:5])
data['Propensity'] = pscore
#pscore = pd.Series(data = pscore, index = data.index)
[ 0.42716293 0.25617646 0.54874013 0.37386481 0.40217168]
Implement one-to-one caliper matching without replacement. Variants of this method are compared in the following paper, which is worth exploring further.
# Austin, P. C. (2014), A comparison of 12 algorithms for matching on the propensity score. Statist. Med., 33: 1057-1069. doi: 10.1002/sim.6004
def Match(groups, propensity, caliper=0.05):
    '''
    Greedy one-to-one nearest-neighbour matching on the propensity score,
    within a caliper and without replacement.

    Inputs:
    groups = Treatment assignments. Must contain exactly 2 distinct values.
        The smaller group is taken to be the treatment group; if both groups
        have the same size, the group of the first observation is used.
    propensity = Propensity scores for each observation, strictly between 0
        and 1. Propensity and groups should be in the same order (they are
        paired by position).
    caliper = Maximum difference in matched propensity scores. For now, this
        is a caliper on the raw propensity; Austin recommends using a caliper
        on the logit propensity.
    Output:
    A series with one entry per treated individual, indexed by that
    individual's label in `propensity` (or in `groups` when `propensity` is
    not a Series). Each value is the label of the matched control, or NaN
    when no unused control lies within the caliper.
    Note that with caliper matching, not every treated individual may have
    a match.
    Raises:
    ValueError if a propensity score is outside (0, 1), the caliper is
    outside (0, 1), the inputs differ in length, or `groups` does not have
    exactly two distinct values.
    '''
    scores = np.asarray(propensity, dtype=float)

    # Check inputs
    if (scores <= 0).any() or (scores >= 1).any():
        raise ValueError('Propensity scores must be between 0 and 1')
    if not (0 < caliper < 1):
        raise ValueError('Caliper must be between 0 and 1')
    if len(groups) != len(scores):
        raise ValueError('groups and propensity scores must be same dimension')
    groups = pd.Series(groups)
    levels = groups.unique()
    if len(levels) != 2:
        raise ValueError('wrong number of groups')

    if isinstance(propensity, pd.Series):
        labels = propensity.index
    else:
        labels = groups.index

    # Work with positional boolean masks so that neither index alignment nor
    # repeated / non-default labels can affect which rows are selected.
    in_first = np.asarray(groups == levels[0], dtype=bool)
    n_first = int(in_first.sum())
    # Treatment should be the smaller group; flip the coding if it is not.
    if n_first > len(in_first) - n_first:
        treated_mask = ~in_first
    else:
        treated_mask = in_first

    treated_scores = scores[treated_mask]
    control_scores = scores[~treated_mask]
    treated_labels = labels[treated_mask]
    control_labels = labels[~treated_mask]

    # Controls with a missing score can never be matched.
    available = ~np.isnan(control_scores)
    matched = [np.nan] * len(treated_scores)

    # Visit the treated units in random order; each takes the closest control
    # that is still unused, provided it lies within the caliper.
    for pos in np.random.permutation(len(treated_scores)):
        if not available.any():
            break
        dist = np.abs(control_scores - treated_scores[pos])
        dist[~available] = np.inf
        best = int(dist.argmin())
        if dist[best] <= caliper:
            matched[pos] = control_labels[best]
            available[best] = False  # without replacement

    # Numeric labels are returned as floats so NaN can mark "no match".
    numeric_labels = getattr(control_labels.dtype, 'kind', 'O') in 'iuf'
    return pd.Series(matched, index=treated_labels,
                     dtype=float if numeric_labels else object)
# Match on the estimated propensity score using the default caliper.
stuff = Match(data.Treated, data.Propensity)
# Propensity scores of each arm, kept separately for the side-by-side
# comparison of matched pairs.
g1 = data.Propensity[data.Treated == 1]
g2 = data.Propensity[data.Treated == 0]
# To exercise the ValueError path (more than two groups), uncomment:
#badtreat = data.Treated + data.Hispanic
#Match(badtreat, pscore)
stuff[:5]
0 25 1 213 2 75 3 14 4 210 dtype: float64
Here's the result: if we put the propensity scores of the treatment and matched controls side-by-side, we see that they're matched pretty well.
# Pair each treated score with its matched control's score.  reindex()
# yields NaN for unmatched (NaN) entries instead of raising on a missing
# label, and list() materialises the pairs on Python 3, where zip is lazy.
list(zip(g1, g2.reindex(stuff.values)))
[(0.42716292681386048, 0.42581880946572531), (0.25617646165856756, 0.25484822320097977), (0.54874012613984235, 0.54874012613984235), (0.37386481070884442, 0.37424224379754362), (0.40217167818759586, 0.4025488132226952), (0.37341040912703743, 0.37341040912703743), (0.53301175942990109, 0.53492670949658028), (0.38451503327005687, 0.38489660886065941), (0.50914521709464833, 0.5038713907925424), (0.61794972732004938, 0.58527783340947481), (0.36708091086829236, 0.36693596150764657), (0.52310609372279693, 0.52399253077902697), (0.37001375859203811, 0.37001375859203811), (0.41040938689881862, 0.40263858902000565), (0.37295623086366536, 0.37330571977466881), (0.36206521830641381, 0.36206521830641381), (0.53663014121627661, 0.53887945917796798), (0.37046647062835597, 0.37046647062835597), (0.57103382954029125, 0.5709591033853374), (0.53976237257756132, 0.53976237257756132), (0.36543122767419073, 0.36543122767419073), (0.59343787140217641, 0.58088215756113837), (0.4386969080721812, 0.44247414429559007), (0.36753212799132406, 0.36753212799132406), (0.35997778168476563, 0.35997778168476563), (0.40954977660006303, 0.40954977660006303), (0.36963807095849838, 0.36918577756921234), (0.26069791214312099, 0.26100872287425286), (0.35789562808324343, 0.35789562808324343), (0.36753212799132406, 0.36708091086829236), (0.35789562808324343, 0.35789562808324343), (0.45658813935562875, nan), (0.40082653085458858, 0.40082653085458858), (0.52624903693537084, nan), (0.5375136651601029, 0.5375136651601029), (0.56484900646250547, 0.56420807943645801), (0.40038494333600216, 0.39999799811496028), (0.56561608250077011, 0.56652020912997003), (0.46330566204377877, nan), (0.37257932252559373, 0.37257932252559373), (0.52850447140293089, 0.53075874267386913), (0.3969042231346866, 0.39612686819916004), (0.36790692445253587, 0.36753212799132406), (0.27913948507677838, 0.27881522993582841), (0.35915945500472685, 0.35868111759952492), (0.40082653085458858, 0.40168028755982904), (0.36790692445253587, 
0.36790692445253587), (0.36288641657881326, 0.36288641657881326), (0.40038494333600216, 0.40079387221848178), (0.53301175942990109, 0.57005720877929322), (0.3913433578168215, 0.39168310001168555), (0.41393321094396696, 0.4127054899751052), (0.35500470600275558, 0.35500470600275558), (0.5375136651601029, 0.57053301873387885), (0.4117397433545737, 0.41137419523919089), (0.35789562808324343, 0.35789562808324343), (0.40567343869825451, 0.40603741546027916), (0.47360442597144675, 0.44944907059288414), (0.59691150969843831, 0.58464545042501748), (0.39304342081844873, 0.39301309024977354), (0.36790692445253587, 0.36835856598771649), (0.39350085125328266, 0.39350085125328266), (0.39612686819916004, 0.39775517139652095), (0.37386481070884442, 0.37386481070884442), (0.41434922364455296, 0.41302230654596012), (0.38619843594707792, 0.38619843594707792), (0.37386481070884442, 0.37424224379754362), (0.41608090759002542, 0.41309626451423542), (0.47723165395085809, 0.47774167074970658), (0.54112719501669537, 0.55768626228483698), (0.38405562703559115, 0.38451503327005687), (0.57901039501721874, 0.57317432336471086), (0.38016333753745757, 0.38041199464841013), (0.5375136651601029, 0.55545300341207382), (0.52810275907963666, 0.55097984688635149), (0.39953799040539978, 0.39938108342140327), (0.53075874267386913, 0.56180397045598252), (0.37129563329740667, 0.37046647062835597), (0.40178416325708044, 0.40255931536012274), (0.36333543761891568, 0.36333543761891568), (0.5375136651601029, 0.5375136651601029), (0.5217350446165081, nan), (0.59732989436222506, 0.58907212092192596), (0.53663014121627661, nan), (0.38919010277019828, 0.38919010277019828), (0.39744213629889608, 0.39736320517754298), (0.26207005517968529, 0.26138340190382492), (0.37174898822506325, 0.37174898822506325), (0.55984202089567947, nan), (0.38489660886065941, 0.38514373934059354), (0.41393321094396696, 0.41126954472315269), (0.35789562808324343, 0.35789562808324343), (0.36963807095849838, 0.36963807095849838), 
(0.52850447140293089, 0.52822407944383731), (0.41051397115899052, 0.41051397115899052), (0.5687410368098067, 0.56887204494312937), (0.36498110138737444, 0.36498110138737444), (0.57546114010719607, 0.57669759720271807), (0.37469705231202188, 0.37469705231202188), (0.27200031650600354, 0.27659164918684331), (0.37257932252559373, 0.37257932252559373), (0.37681958989925324, 0.37681958989925324), (0.36963807095849838, 0.36963807095849838), (0.35500470600275558, 0.35500470600275558), (0.35707939391258053, 0.35707939391258053), (0.41882782270999969, 0.4199216409379678), (0.53663014121627661, 0.53663014121627661), (0.36288641657881326, 0.36288641657881326), (0.38024100782850595, 0.38062096495388931), (0.39397088933484964, 0.3917332458962281), (0.32878224824302671, 0.32857587970056046), (0.24604313298890371, 0.24531002119584097), (0.52145634998553481, 0.52341300607717067), (0.35215316763290705, 0.35293545509098728), (0.55291558624596349, 0.55453051526410713), (0.53812498686079091, 0.53760415714737053), (0.53165840610880255, nan), (0.53178741742330415, 0.53526343055701198), (0.36732044941345232, 0.36726529273432967), (0.50016320180970986, 0.54233385721831306), (0.54568333247369738, 0.54649842890599332), (0.37088397590008554, 0.37046647062835597), (0.37807518272603041, 0.37811089138712795), (0.54505617352098923, 0.54778717953003553), (0.37187026316150906, 0.37123230389724349), (0.42907939387160898, 0.42621298956406267), (0.503576849031942, nan), (0.41499892319370452, 0.40530744104465677), (0.34867332496059367, 0.34855987148037076), (0.35124881591719082, 0.34951011622716577), (0.368354966402969, 0.36835856598771649), (0.39388148189718369, 0.3939643292973713), (0.42287750361949428, 0.42323937900753988), (0.27508239938161005, 0.28098466748341883), (0.41651421693564455, 0.42037816699807562), (0.4167397609523012, 0.41919527728474437), (0.46806147084466776, 0.45878647462448441), (0.36978946463533807, 0.36963807095849838), (0.49379330696466373, nan), (0.43357631858871326, 
0.43257515170041527), (0.34692254178035581, 0.34623057969596582), (0.46757669724338247, 0.46341413004896059), (0.37685357018163268, 0.37681958989925324), (0.4455565269264718, 0.44617327640136217), (0.61886657105467635, 0.60079264899057727), (0.37718188114200613, 0.37705709921672609), (0.40934354192118261, 0.40954977660006303), (0.39573595296507968, 0.39432477073157096), (0.44255500877946502, 0.42053379991459011), (0.36811932082281856, 0.36790692445253587), (0.42116861754671925, 0.42103202468255674), (0.47650112195239025, 0.47887757800201047), (0.39476901876612058, 0.39217707695803344), (0.53190717776280616, 0.54227884676845073), (0.49379443849344062, 0.45360505623822056), (0.4872454286920998, 0.48602144945577064), (0.41089388876758753, 0.41048788604880759), (0.53019919491977685, 0.52624903693537084), (0.42206423034004498, 0.42342902381310754), (0.42156230780201276, 0.42182411349268001), (0.41195948127964338, 0.41102028893247478), (0.41575142347870969, 0.41004419660549557), (0.37949264038566599, 0.37894686448575243), (0.44597592218322879, 0.44343522976592487), (0.54700647957502357, nan), (0.46839575084837853, 0.42584389373973836), (0.30118504354309478, 0.30266228063350109), (0.39535751543292041, 0.39479876377674628), (0.46967422150989402, 0.47016706817874493), (0.40711739797786478, 0.40603741546027916), (0.43016221895202345, 0.40603741546027916), (0.59539063185170427, 0.57227359694234925), (0.25033235153545463, 0.2506659230345058), (0.41813865251776078, 0.40263858902000565), (0.2969316019282166, 0.29322046060925033), (0.46046038221301638, 0.45526853386815697), (0.52687531615177041, 0.54771545246673836), (0.50045740817131545, 0.51770231215857909), (0.63404475620095613, 0.58689845316771205), (0.376835082633315, 0.37604736111069387), (0.59996973878777604, 0.57989993766179626), (0.47324609795733941, 0.47791911265297649), (0.53695349937194692, 0.53707042977012709), (0.60108829050619672, nan), (0.6729751518671101, 0.63068079328671744)]