%pylab inline
import matplotlib.pyplot as plt
# Output: Populating the interactive namespace from numpy and matplotlib
import numpy as np
from scipy import optimize
from astropy.modeling.fitting import Fitter
# TODO: This is currently needed to avoid an error in the constructor.
# Registers ['fixed'] as the entry for the 'CashFitter' key of astropy's
# ``constraintsdef`` mapping (presumably looked up by fitter class name when
# the Fitter base class is constructed -- confirm against astropy).
from astropy.modeling.fitting import constraintsdef
constraintsdef['CashFitter'] = ['fixed']
def cash(D, M):
    """Cash Poisson likelihood statistic.

    Evaluates ``2 * (M - D * log(M))`` for every bin. Bins whose model
    value is not strictly positive contribute ``0``.

    Parameters
    ----------
    D : array-like
        "data", i.e. observed counts per bin
    M : array-like
        "model", i.e. expected counts per bin

    Returns
    -------
    cash : array
        Cash statistic value per bin
    """
    D = np.asanyarray(D, dtype=np.float64)
    M = np.asanyarray(M, dtype=np.float64)
    positive = M > 0
    # Take the logarithm of a sanitised copy (non-positive entries replaced
    # by 1) so that no divide-by-zero / invalid-value warnings are raised for
    # bins that are masked out below anyway.
    safe_M = np.where(positive, M, 1.0)
    # Use ``np.log`` explicitly instead of relying on the ``%pylab``
    # star-import to inject a bare ``log`` into the namespace.
    stat = 2 * (M - D * np.log(safe_M))
    return np.where(positive, stat, 0)
class CashFitter(Fitter):
    """Cash Poisson likelihood fitter.

    Calls the `scipy.optimize.minimize` optimization function to minimize
    the summed Cash statistic between observed counts and the model.
    """

    def __init__(self, model):
        Fitter.__init__(self, model)

    def errorfunc(self, fitpars, *args):
        """The Cash Poisson likelihood fit statistic.

        The statistic is ``sum(2 * (M - D * log(M)))`` over all bins, with
        ``D`` the observed counts and ``M`` the model prediction
        (Cash, W. 1979, ApJ 228, 939).

        Parameters
        ----------
        fitpars : array-like
            Trial parameter values supplied by the optimizer. They are
            assigned to ``self.fitpars`` before the model is evaluated.
        *args : (y, x)
            Tuple with y counts at coordinate x

        Returns
        -------
        stat : float
            Cash Poisson likelihood fit statistic
        """
        # Assigning to ``fitpars`` must happen before evaluating the model
        # (presumably the Fitter base class pushes the values into
        # ``self.model`` -- confirm against the astropy version in use).
        self.fitpars = fitpars
        D = args[0]
        M = self.model(*args[1:])
        stat = cash(D, M)
        return stat.sum()

    def __call__(self, x, y):
        """Execute the likelihood minimization.

        The optimizer starts from the model's current parameters; the
        best-fit values found are assigned to ``self.fitpars``.

        Parameters
        ----------
        x : array-like
            x-coordinate
        y : array-like
            Observed number of counts at ``x``

        Returns
        -------
        result : scipy.optimize.OptimizeResult
            Raw optimizer result. Inspect ``result.success`` and
            ``result.message`` to check whether the fit converged.
        """
        result = optimize.minimize(self.errorfunc,
                                   x0=self.model.parameters[:],
                                   args=(y, x))
        self.fitpars = result.x
        # Hand the optimizer result back instead of discarding it, so that
        # callers can detect a failed or non-converged fit.
        return result
from astropy.modeling.models import Gaussian1DModel
# Demo: fit a 1D Gaussian to simulated Poisson counts with CashFitter.

# "True" model used to generate the data.
model = Gaussian1DModel(amplitude=10, mean=2, stddev=3)
print('True parameters: ', model.parameters)
# Sample grid and Poisson-distributed counts drawn around the true model.
# No random seed is set, so the counts differ from run to run.
x = np.arange(-10, 20, 0.1)
y = np.random.poisson(model(x))
# Rebind ``model`` to a Gaussian with different starting parameters;
# this is the model that gets fitted.
model = Gaussian1DModel(amplitude=7, mean=1, stddev=2)
print('Parameters before fit: ', model.parameters)
fitter = CashFitter(model)
fitter(x, y)
# The fitted values are read back from the model itself.
print('Parameters after fit: ', model.parameters)
# Overlay the fitted model (red line) on the simulated counts (dots).
y_model = model(x)
plt.plot(x, y, 'o')
plt.plot(x, y_model, 'r-');
# Output of one run (the counts are random, so the fitted values vary):
#   ('True parameters: ', [10.0, 2.0, 3.0])
#   ('Parameters before fit: ', [7.0, 1.0, 2.0])
#   ('Parameters after fit: ', [10.574787489518142, 1.8945598462935331, 2.9540072831120914])
#   -c:27: RuntimeWarning: divide by zero encountered in log
#   -c:27: RuntimeWarning: invalid value encountered in multiply