%matplotlib inline
from IPython.display import HTML, YouTubeVideo
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import urllib.request
import os.path
import zipfile
from scipy.integrate import odeint
from scipy.optimize import curve_fit
from scipy.stats import linregress
import seaborn as sns
sns.set_style('ticks')
sns.set_context('talk')
sns.set_palette("Set1", 8, .75)
Our focus today will be on modeling population growth. We will start with human populations, using data from the World Bank on the population size of countries around the world from 1960 to 2013.
HTML('<blockquote class="twitter-tweet" lang="he"><p>In Data Science, 80% of time spent prepare data, 20% of time spent complain about need for prepare data.</p>— Big Data Borat (@BigDataBorat) <a href="https://twitter.com/BigDataBorat/status/306596352991830016">פברואר 27, 2013</a></blockquote><script async src="//platform.twitter.com/widgets.js" charset="utf-8"></script>')
In Data Science, 80% of time spent prepare data, 20% of time spent complain about need for prepare data.
— Big Data Borat (@BigDataBorat) February 27, 2013
So let's prepare the data from scratch.
Start by retrieving the data as a zip file (if not already downloaded). Follow the link to the World Bank page on Population, total, click DOWNLOAD DATA, choose CSV, and a zip file will be downloaded.
Even better, right-click the CSV button and choose Copy link address in the context menu. This link can be used to get the file directly from Python using urllib:
fname = 'world_growth.zip'
if not os.path.exists(fname):
    urllib.request.urlretrieve('http://api.worldbank.org/v2/en/indicator/sp.pop.totl?downloadformat=csv', fname)
Open the file using the zipfile module and read it using pandas, skipping the first two rows and dropping the last two columns that have NaN values:
with zipfile.ZipFile('world_growth.zip') as z:
    f = z.open('sp.pop.totl_Indicator_en_csv_v2.csv')
    df = pd.read_csv(f, skiprows=2)
df.drop('2014', axis=1, inplace=True)
df.drop('Unnamed: 59', axis=1, inplace=True)
df.head()
Country Name | Country Code | Indicator Name | Indicator Code | 1960 | 1961 | 1962 | 1963 | 1964 | 1965 | ... | 2004 | 2005 | 2006 | 2007 | 2008 | 2009 | 2010 | 2011 | 2012 | 2013 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | Aruba | ABW | Population, total | SP.POP.TOTL | 54208 | 55435 | 56226 | 56697 | 57029 | 57360 | ... | 98742 | 100031 | 100830 | 101219 | 101344 | 101418 | 101597 | 101932 | 102384 | 102911 |
1 | Andorra | AND | Population, total | SP.POP.TOTL | 13414 | 14376 | 15376 | 16410 | 17470 | 18551 | ... | 79060 | 81223 | 81877 | 81292 | 79969 | 78659 | 77907 | 77865 | 78360 | 79218 |
2 | Afghanistan | AFG | Population, total | SP.POP.TOTL | 8774440 | 8953544 | 9141783 | 9339507 | 9547131 | 9765015 | ... | 24018682 | 24860855 | 25631282 | 26349243 | 27032197 | 27708187 | 28397812 | 29105480 | 29824536 | 30551674 |
3 | Angola | AGO | Population, total | SP.POP.TOTL | 4965988 | 5056688 | 5150076 | 5245015 | 5339893 | 5433841 | ... | 15976715 | 16544376 | 17122409 | 17712824 | 18314441 | 18926650 | 19549124 | 20180490 | 20820525 | 21471618 |
4 | Albania | ALB | Population, total | SP.POP.TOTL | 1608800 | 1659800 | 1711319 | 1762621 | 1814135 | 1864791 | ... | 3026939 | 3011487 | 2992547 | 2970017 | 2947314 | 2927519 | 2913021 | 2904780 | 2900489 | 2897366 |
5 rows × 58 columns
df.tail()
Country Name | Country Code | Indicator Name | Indicator Code | 1960 | 1961 | 1962 | 1963 | 1964 | 1965 | ... | 2004 | 2005 | 2006 | 2007 | 2008 | 2009 | 2010 | 2011 | 2012 | 2013 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
243 | Yemen, Rep. | YEM | Population, total | SP.POP.TOTL | 5099785 | 5184477 | 5276093 | 5372934 | 5472775 | 5573959 | ... | 19612696 | 20139661 | 20661714 | 21182162 | 21703571 | 22229625 | 22763008 | 23304206 | 23852409 | 24407381 |
244 | South Africa | ZAF | Population, total | SP.POP.TOTL | 17396000 | 17949962 | 18459442 | 18936138 | 19390554 | 19832000 | ... | 46727694 | 47349013 | 47991699 | 48656506 | 49344228 | 50055701 | 50791808 | 51553479 | 52341695 | 53157490 |
245 | Congo, Dem. Rep. | COD | Population, total | SP.POP.TOTL | 15248246 | 15637715 | 16041247 | 16461914 | 16903899 | 17369859 | ... | 52487293 | 54028003 | 55590838 | 57187942 | 58819038 | 60486276 | 62191161 | 63931512 | 65705093 | 67513677 |
246 | Zambia | ZMB | Population, total | SP.POP.TOTL | 3082627 | 3173662 | 3269151 | 3368961 | 3472843 | 3580708 | ... | 11174650 | 11470022 | 11781612 | 12109620 | 12456527 | 12825031 | 13216985 | 13633796 | 14075099 | 14538640 |
247 | Zimbabwe | ZWE | Population, total | SP.POP.TOTL | 3752390 | 3876638 | 4006261 | 4140802 | 4279559 | 4422129 | ... | 12693047 | 12710589 | 12724308 | 12740160 | 12784041 | 12888918 | 13076978 | 13358738 | 13724317 | 14149648 |
5 rows × 58 columns
We now want to tidy our data - to transform the dataset into a long format:
Tidy datasets are easy to manipulate, model and visualize, and have a specific structure: each variable is a column, each observation is a row, and each type of observational unit is a table - Hadley Wickham (2014), Tidy Data, J. Stat. Soft.
This means that every row in the table holds exactly one measurement - a population size - together with all the relevant variables: country name and year. Read more on tidy data.
We do this using the pandas melt function:
print(df.shape)
df = pd.melt(df, id_vars=('Country Name','Country Code','Indicator Name', 'Indicator Code'), var_name='Year', value_name='Population')
print(df.shape)
df.head()
(248, 58)
(13392, 6)
Country Name | Country Code | Indicator Name | Indicator Code | Year | Population | |
---|---|---|---|---|---|---|
0 | Aruba | ABW | Population, total | SP.POP.TOTL | 1960 | 54208 |
1 | Andorra | AND | Population, total | SP.POP.TOTL | 1960 | 13414 |
2 | Afghanistan | AFG | Population, total | SP.POP.TOTL | 1960 | 8774440 |
3 | Angola | AGO | Population, total | SP.POP.TOTL | 1960 | 4965988 |
4 | Albania | ALB | Population, total | SP.POP.TOTL | 1960 | 1608800 |
df.tail()
Country Name | Country Code | Indicator Name | Indicator Code | Year | Population | |
---|---|---|---|---|---|---|
13387 | Yemen, Rep. | YEM | Population, total | SP.POP.TOTL | 2013 | 24407381 |
13388 | South Africa | ZAF | Population, total | SP.POP.TOTL | 2013 | 53157490 |
13389 | Congo, Dem. Rep. | COD | Population, total | SP.POP.TOTL | 2013 | 67513677 |
13390 | Zambia | ZMB | Population, total | SP.POP.TOTL | 2013 | 14538640 |
13391 | Zimbabwe | ZWE | Population, total | SP.POP.TOTL | 2013 | 14149648 |
We get rid of rows that have NA values:
df.dropna(axis=0, inplace=True)
df.head()
Country Name | Country Code | Indicator Name | Indicator Code | Year | Population | |
---|---|---|---|---|---|---|
0 | Aruba | ABW | Population, total | SP.POP.TOTL | 1960 | 54208 |
1 | Andorra | AND | Population, total | SP.POP.TOTL | 1960 | 13414 |
2 | Afghanistan | AFG | Population, total | SP.POP.TOTL | 1960 | 8774440 |
3 | Angola | AGO | Population, total | SP.POP.TOTL | 1960 | 4965988 |
4 | Albania | ALB | Population, total | SP.POP.TOTL | 1960 | 1608800 |
Note that the Year variable has dtype O (object - the years are strings, since they came from the CSV column headers) instead of int, so we convert it:
print(df.Year.dtype, df.Population.dtype)
df.Year = df.Year.astype(int)
print(df.Year.dtype, df.Population.dtype)
object float64
int32 float64
Let's start with the three biggest populations in the dataset - the entire world, China, and India. We will plot their population sizes over time:
world = df[df['Country Name'] == 'World']
china = df[df['Country Name'] == 'China']
india = df[df['Country Name'] == 'India']
plt.plot(world.Year, world.Population, label='World')
plt.plot(china.Year, china.Population, label='China')
plt.plot(india.Year, india.Population, label='India')
plt.legend(loc='upper left')
plt.xlabel('Year')
plt.ylabel('Population');
Alternatively, we can plot using the DataFrame.plot method:
ax = world.plot('Year', 'Population')
china.plot('Year', 'Population', ax=ax)
india.plot('Year', 'Population', ax=ax)
ax.legend(['World', 'China', 'India'], loc='upper left')
ax.set_ylabel('Population');
According to Malthus (but that was a long time ago, so don't count on it!), populations grow exponentially (how to solve this):
$$ \frac{dN(t)}{dt} = r N(t) \Rightarrow \\ N(t) = N(0) e^{rt} \Rightarrow \\ \log{N(t)} = \log{N(0)} + rt $$
where $N(t)$ is the population size at time $t$ and $r$ is the growth rate.
That means that the logarithm of the population should be a linear function of time.
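As a quick sanity check, here is a toy simulation (with arbitrary, made-up parameters) showing that regressing the log of an exponentially growing population against time recovers the growth rate:
t_sim = np.linspace(0, 50)
N_sim = 1e6 * np.exp(0.02 * t_sim)  # N(0) = 1e6 and r = 0.02 are arbitrary choices
slope, intercept, r_value, p_value, std_err = linregress(t_sim, np.log(N_sim))
print(slope, r_value)  # the slope recovers r = 0.02 and the correlation is exactly 1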
print("Exponential growth")
YouTubeVideo("https://www.youtube.com/watch?v=c6pcRR5Uy6w")
Exponential growth
Let's check whether the world population growth follows the exponential model. First, define a new column for the log of the population size:
df['LogPopulation'] = np.log(df.Population)
Let's concentrate on the G8 countries (well, not exactly, but good enough for the purpose of this presentation):
G8_countries = ('France', 'Germany', 'Italy', 'Japan', 'United Kingdom', 'United States', 'Canada', 'European Union') # 'Russian Federation' was suspended in 2014
G8 = df[df['Country Name'].isin(G8_countries)]
G8.head()
Country Name | Country Code | Indicator Name | Indicator Code | Year | Population | LogPopulation | |
---|---|---|---|---|---|---|---|
33 | Canada | CAN | Population, total | SP.POP.TOTL | 1960 | 17909009 | 16.700814 |
52 | Germany | DEU | Population, total | SP.POP.TOTL | 1960 | 72814900 | 18.103431 |
69 | European Union | EUU | Population, total | SP.POP.TOTL | 1960 | 409331746 | 19.830037 |
73 | France | FRA | Population, total | SP.POP.TOTL | 1960 | 46647521 | 17.658130 |
77 | United Kingdom | GBR | Population, total | SP.POP.TOTL | 1960 | 52400000 | 17.774417 |
Here we use a small hack to sort the country names by largest population size, so that the legend of the next plot will be ordered by country size.
We start by grouping the data frame by country name and taking the maximum population size of each country, which gives a pandas.Series with the country as index and the maximum population size as the value. We then sort by the value and extract the index (the countries). We will use this sorted list for ordering the colors in the next plot.
G8_countries = G8.groupby('Country Name').Population.max()
G8_countries = G8_countries.sort_values(ascending=False)  # older pandas used the in-place G8_countries.sort(ascending=False)
G8_countries = G8_countries.index
print(G8_countries.tolist())
['European Union', 'United States', 'Japan', 'Germany', 'France', 'United Kingdom', 'Italy', 'Canada']
We now plot the log of the population sizes. Note that only the first three arguments are necessary; the rest are just used to make the plot more aesthetic and professional looking.
sns.lmplot(x='Year', y='LogPopulation', data=G8, hue='Country Name',
size=7, aspect=1.5, line_kws={'alpha':0.5}, hue_order=G8_countries);
This is interesting. While some countries fit the Malthusian model very well (United States, France, United Kingdom), some countries don't fit as well: Japan seems to have stopped growing during the 1980's, and the European Union, Italy and Canada also seem to be decelerating. But this could be because these countries are "developed". What about some "developing" countries?
Let's focus on the four largest countries; only one of them is "developed", but as we saw above it fits the Malthusian model well.
We will use SciPy's linregress function to perform linear regression. This will allow us to extrapolate - to predict what the population of these countries will be if the rate at which they grow remains constant under the Malthusian model.
Note that the linear regression of $y$ as a function of $x$ is $y = a x + b$, where $a$ is the slope of the regression and $b$ is the intercept. We then have to exponentiate both sides of the equation, as we want the actual population size rather than the log population - so we get $e^y = e^{ax+b}$.
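As a side note (a small sketch that is not part of the original analysis): the slope of the log-linear fit is the growth rate $r$, so the population doubles every $\log{2}/r$ years. For example, for the world data loaded above:
slope, intercept, r_value, p_value, std_err = linregress(world.Year, np.log(world.Population))
print('World growth rate: %.4f per year, doubling time: %.1f years' % (slope, np.log(2) / slope))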
Let's extrapolate to the end of the century:
years = np.linspace(df.Year.min(), 2100)
for country_name in ('India', 'China', 'United States', 'Brazil'):
    country_df = df[df['Country Name'] == country_name]
    slope, intercept, r_value, p_value, std_err = linregress(country_df.Year, country_df.LogPopulation)
    print("Coefficient of correlation for %s: %.4f" % (country_name, r_value), end=", ")
    print("P-value for %s: %.2g" % (country_name, p_value))
    population = np.exp(intercept + slope * years)
    plt.plot(years, population, label=country_name)  # prediction
    plt.scatter(country_df.Year, country_df.Population, marker='.')  # data
plt.xlabel('Year')
plt.ylabel('Population')
plt.yscale('log')
plt.ylim(5e7, 1e10)
plt.legend(loc='upper left')
plt.axvline(x=2032, color='gray', linestyle='--')
plt.axvline(x=2056, color='gray', linestyle='--');
Coefficient of correlation for India: 0.9968, P-value for India: 1.1e-58
Coefficient of correlation for China: 0.9764, P-value for China: 2.9e-36
Coefficient of correlation for United States: 0.9993, P-value for United States: 3.2e-76
Coefficient of correlation for Brazil: 0.9889, P-value for Brazil: 9.1e-45
From our prediction, India will pass China around 2032 and Brazil will pass the United States around 2056.
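Instead of eyeballing the dashed lines, the crossing years can be computed: two log-linear fits $y_1 = a_1 x + b_1$ and $y_2 = a_2 x + b_2$ intersect at $x = (b_2 - b_1) / (a_1 - a_2)$. A sketch (crossing_year is a hypothetical helper, not part of the original notebook):
def crossing_year(name1, name2):
    # fit log-population vs. year for each country, then intersect the two lines
    a1, b1 = linregress(df[df['Country Name'] == name1].Year, df[df['Country Name'] == name1].LogPopulation)[:2]
    a2, b2 = linregress(df[df['Country Name'] == name2].Year, df[df['Country Name'] == name2].LogPopulation)[:2]
    return (b2 - b1) / (a1 - a2)
print(crossing_year('India', 'China'))
print(crossing_year('Brazil', 'United States'))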
From the results of linregress we can estimate confidence intervals for the intercept and slope, using the standard error it returns. We can also calculate the coefficient of determination ($R^2$).
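For example, a sketch for one country (the confidence interval uses the standard error returned by linregress and Student's t distribution with $n-2$ degrees of freedom):
from scipy.stats import t as t_dist
country_df = df[df['Country Name'] == 'India']
slope, intercept, r_value, p_value, std_err = linregress(country_df.Year, country_df.LogPopulation)
n = len(country_df)
ci = t_dist.ppf(0.975, n - 2) * std_err  # half-width of a two-sided 95% CI for the slope
print('slope: %.5f +/- %.5f' % (slope, ci))
print('R^2: %.4f' % r_value**2)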
The Malthusian/exponential growth model can't always be correct: often growth decelerates and effectively stops when the population reaches a certain size $K$ (carrying capacity, maximum yield, maximum population density, etc.). This is true for fish body size (Schnute 1981), microbial population size in a constant volume (Zwietering 1990), and natural animal populations. In these cases it is common to use the logistic growth model, in which the size of the population inhibits growth, leading to a maximum population size after which growth stops:
$$ \frac{dN}{dt} = r N \Big( 1 - \frac{N}{K} \Big) $$
print("Logistic growth")
YouTubeVideo("https://www.youtube.com/watch?v=rXlyYFXyfIM")
Logistic growth
Let's start by plotting the solution to this ordinary differential equation (ODE). We need to define the ODE as a function of N, t and whatever other parameters are needed. It is important that N comes before t:
def logistic_ode(N, t, r, K):
    return r * N * (1 - N / K)
Integrating ODEs is really easy with Python. After defining the ODE function, we can define a new function that, given $t$, $r$, $K$ and $N_0$ (initial population size), calculates $N(t)$ by integration. For this we use SciPy's odeint (ODE integration).
def integrate_logistic(t, N0, r, K):
    N = odeint(logistic_ode, N0, t, args=(r, K))
    return N
t = np.linspace(0,10)
N = integrate_logistic(t, N0=1, r=1, K=10)
plt.plot(t, N)
plt.xlabel('Time')
plt.ylabel('Population')
sns.despine();
Incredibly, the logistic equation has an analytic solution that doesn't require numerical integration:
$$ \frac{dN}{dt} = r N \Big(1 - \frac{N}{K}\Big) \Rightarrow \\ N(t) = \frac{K N_0 e^{rt}}{K + N_0(e^{rt}-1)} $$
def logistic(t, N0, r, K):
    return K * N0 * np.exp(r*t) / ( K + N0 * (np.exp(r*t) - 1))
t = np.linspace(0,10)
N_integration = integrate_logistic(t, 1, 1, 10)
N_solution = logistic(t, 1, 1, 10)
plt.plot(t, N_integration, label='ODE')
plt.plot(t, N_solution, 'o', label='Solution')
plt.legend(loc='upper left');
So we have a standard growth model, but how do we find the right model (i.e. parameters N0, r, K) for our data? To do this we need to fit the logistic model to our data.
We will use a SciPy function called curve_fit. It accepts a model (a function of t and some other arguments/parameters that returns N), data (t and N), initial guesses of the model parameters, and some other optional arguments. It evaluates the model with the initial guesses and calculates the fit of the result - in standard fitting this is done by the sum of the squares of the differences between the model and the data. It then makes another guess of the model parameters, evaluates the model, and calculates the distance, again and again. It stops when it reaches a minimal distance, which is why this operation is called least squares. The operation isn't perfect: even with just 3 or 4 model parameters there might be more than one minimum, or the function might not find a minimum in a reasonable time.
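To make the objective concrete, here is the quantity being minimized, written out by hand (a sketch for intuition only; curve_fit uses a proper least-squares optimizer rather than random guessing):
def sum_of_squares(params, t, N):
    # the "distance" between the model and the data for one guess of the parameters
    residuals = logistic(t, *params) - N
    return (residuals ** 2).sum()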
Before we try to work with real data we want to try it with simulated data and see that we know what we are doing. To simulate data we generate data from the logistic function and add some normally distributed noise:
simulated = logistic(t, N0=1.5, r=0.75, K=8.5) + np.random.normal(loc=0, scale=0.7, size=t.shape)
plt.scatter(t, simulated)
sns.despine()
Now we can call curve_fit, give it the simulated data, and get back an estimate of the model parameters. As you can see, although there is a lot of noise, we get a pretty good estimate of the parameters.
params, cov = curve_fit(f=logistic, xdata=t, ydata=simulated, p0=(1,1,10))
N0,r,K = params
plt.scatter(t, simulated)
plt.plot(t, logistic(t, N0, r, K), '-')
sns.despine()
print('N0=%.3f, r=%.3f, K=%.3f' % (N0,r,K))
N0=1.452, r=0.742, K=8.576
We can also get confidence intervals for the parameter values from the cov return value. But if you need to work with confidence intervals, you might prefer the package lmfit, which is easier to use if you need more than the basics.
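A sketch of the basics, assuming the parameter estimates are approximately normally distributed:
perr = np.sqrt(np.diag(cov))  # standard errors of N0, r, K
for name, value, err in zip(('N0', 'r', 'K'), params, perr):
    print('%s = %.3f +/- %.3f (approx. 95%% CI)' % (name, value, 1.96 * err))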
Now let's move from simulated to real data: growth curves measured as optical density (OD) in a 96-well plate over time. First, download the file (if it is not already here):
if not os.path.exists('Yoav_311214c_only_OD.csv'):
    urllib.request.urlretrieve('https://raw.githubusercontent.com/Py4Life/TAU2015/master/Yoav_311214c_only_OD.csv', 'Yoav_311214c_only_OD.csv')
df = pd.read_csv('Yoav_311214c_only_OD.csv')
df.head()
Cycle Nr. | Time [s] | Temp. [°C] | A1 | A2 | A3 | A4 | A5 | A6 | A7 | ... | H4 | H5 | H6 | H7 | H8 | H9 | H10 | H11 | H12 | Time | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 0.0 | 37.2 | 0.1128 | 0.1118 | 0.1186 | 0.1213 | 0.1131 | 0.1111 | 0.1178 | ... | 0.1240 | 0.1352 | 0.1266 | 0.1147 | 0.1157 | 0.1084 | 0.1177 | 0.1217 | 0.1110 | 0.000000 |
1 | 2 | 1136.9 | 37.4 | 0.1164 | 0.1151 | 0.1189 | 0.1209 | 0.1173 | 0.1144 | 0.1221 | ... | 0.1229 | 0.1386 | 0.1289 | 0.1154 | 0.1211 | 0.1103 | 0.1216 | 0.1208 | 0.1130 | 0.315806 |
2 | 3 | 2273.7 | 37.3 | 0.1221 | 0.1260 | 0.1242 | 0.1251 | 0.1247 | 0.1222 | 0.1312 | ... | 0.1350 | 0.1431 | 0.1373 | 0.1219 | 0.1258 | 0.1175 | 0.1276 | 0.1408 | 0.1080 | 0.631583 |
3 | 4 | 3410.5 | 37.3 | 0.1315 | 0.1405 | 0.1335 | 0.1338 | 0.1342 | 0.1321 | 0.1469 | ... | 0.1586 | 0.1493 | 0.1455 | 0.1362 | 0.1349 | 0.1263 | 0.1374 | 0.1777 | 0.1104 | 0.947361 |
4 | 5 | 4547.4 | 36.9 | 0.1451 | 0.1464 | 0.1467 | 0.1470 | 0.1486 | 0.1465 | 0.1584 | ... | 0.1500 | 0.1622 | 0.1593 | 0.1481 | 0.1477 | 0.1385 | 0.1554 | 0.1537 | 0.1089 | 1.263167 |
5 rows × 100 columns
We don't need the Time [s] column (we have Time in hours) and we don't need Temp. [°C], so we drop these columns.
Then we melt the data: we transform it from a matrix format (one row per timepoint, one column per well, plus columns for other variables such as time) into a long format (one row per measurement, with a column for each variable).
df = df.drop(axis=1, labels=['Time [s]', 'Temp. [°C]'])
df = pd.melt(df, id_vars=('Time','Cycle Nr.'), var_name='Well', value_name='OD')
df.head()
Time | Cycle Nr. | Well | OD | |
---|---|---|---|---|
0 | 0.000000 | 1 | A1 | 0.1128 |
1 | 0.315806 | 2 | A1 | 0.1164 |
2 | 0.631583 | 3 | A1 | 0.1221 |
3 | 0.947361 | 4 | A1 | 0.1315 |
4 | 1.263167 | 5 | A1 | 0.1451 |
Now let's add a column for the plate row (A - H) and a column for the plate column (1 - 12). We will use list comprehensions for that:
def well2row(well):
    return well[0]

def well2col(well):
    return int(well[1:])
df['Row'] = [well2row(x) for x in df.Well]
df['Col'] = [well2col(x) for x in df.Well]
df.head()
Time | Cycle Nr. | Well | OD | Row | Col | |
---|---|---|---|---|---|---|
0 | 0.000000 | 1 | A1 | 0.1128 | A | 1 |
1 | 0.315806 | 2 | A1 | 0.1164 | A | 1 |
2 | 0.631583 | 3 | A1 | 0.1221 | A | 1 |
3 | 0.947361 | 4 | A1 | 0.1315 | A | 1 |
4 | 1.263167 | 5 | A1 | 0.1451 | A | 1 |
Finally the data is prepared and we are ready to explore it. Let's start by plotting the OD over time curve of every well - just to make sure we know what we have in our hands.
We use seaborn.FacetGrid, which allows us to make faceted plots (fairly) easily. We need to specify the DataFrame columns on which we want to facet (these are the Col and Row columns we just created), and we can specify a bunch of styling options too. FacetGrid will break the data into smaller datasets according to our instructions. We can also specify which column controls the hue (color) of the plots.
Then we can map a function to each of the facets. We give the function, and then the columns that should be used as the function arguments.
g = sns.FacetGrid(df, col='Col', row='Row', hue='Well', sharex=True, sharey=True, size=0.75,
aspect=12./8, despine=True, margin_titles=True)
g.map(plt.plot, 'Time', 'OD')
plt.locator_params(nbins=4) # 4 ticks is enough
g.set_axis_labels('','') # remove facets axis labels
g.fig.text(0.5, 0, 'Time', size='x-large') # xlabel
g.fig.text(-0.01, 0.5, 'OD', size='x-large', rotation='vertical') # ylabel
g.fig.tight_layout()
You might notice that the curves in G12 and H12 are constant - these are the blank/control wells. Let's get rid of them:
df = df[(df.Well != 'G12') & (df.Well != 'H12')]
In order to fit a logistic curve to the data we need to calculate the average curve. This is easy (once you learn how to do it) in pandas. The idea is: we split the data by one or more columns, so that rows with the same values in these columns are grouped together. We then apply an aggregation to these groups, meaning we make a single row out of each group. We then combine the rows back together into a single data frame.
This strategy is known as split-apply-combine and it is a very powerful way of working with data frames. One of its strengths is that it easily extends to parallel computing, because the apply part can be done in parallel over the different groups.
Using pandas, we split with the groupby method; we aggregate with mean (for general aggregations use aggregate); we then reset the indices (which were lost during the aggregation) with reset_index.
We want the mean OD value for each time point, so we group by Time and aggregate with mean.
grouped = df.groupby('Time')
aggregated = grouped.OD.mean().reset_index()
aggregated.head()
Time | OD | |
---|---|---|
0 | 0.000000 | 0.115260 |
1 | 0.315806 | 0.118860 |
2 | 0.631583 | 0.125469 |
3 | 0.947361 | 0.135600 |
4 | 1.263167 | 0.149310 |
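As an aside, aggregate (or its shorthand agg) can compute several statistics per group at once - a quick illustration that is not needed for the fit below:
df.groupby('Time').OD.agg(['mean', 'std', 'count']).head()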
Let's plot:
t = aggregated.Time
N = aggregated.OD
plt.scatter(t, N)
plt.xlabel('Time')
plt.ylabel('Mean OD')
plt.axvline(x=10, color='gray')
sns.despine()
You might notice that after about 10 hours the population size starts to decrease. This is an artifact of the experiment, caused by the wells starting to dry out. Let's drop this data:
N = N[t<10]
t = t[t<10]
plt.scatter(t, N)
plt.xlabel('Time')
plt.ylabel('Mean OD')
sns.despine()
Now, finally, to the fitting (80% is data preparation!):
popt, cov = curve_fit(logistic, t, N, (N.min(), 1., N.max()))  # initial guesses, in the same (N0, r, K) order as the logistic signature
N0, r, K = popt
print('N0=%.3f, r=%.6f, K=%.3f' % (N0, r, K))
plt.scatter(t, N)
plt.plot(t, logistic(t, N0, r, K))
plt.xlabel('Time')
plt.ylabel('OD')
sns.despine();
N0=0.081, r=0.601459, K=0.601
You can see that the fit is so-so. It's not that bad, but it's not perfect, especially at the beginning. This is because the logistic model probably isn't the best choice (although it's a common, simple choice).
The Gompertz model is a different growth model. For reference, the parameterization implemented below and its closed-form solution are:
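$$ \frac{dN}{dt} = r N \log{\frac{K}{N}} \Rightarrow \\ N(t) = K e^{\log{(N_0/K)} \, e^{-rt}} $$
Let's try it: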
def gompertz(t, N0, r, K):
    return K * np.exp( np.log(N0/K) * np.exp(-r*t) )
popt, cov = curve_fit(gompertz, t, N, (N.min(), r, N.max()))  # initial guesses in (N0, r, K) order
N0gomp, rgomp, Kgomp = popt
print('N0=%.3f, r=%.6f, K=%.3f' % (N0gomp, rgomp,Kgomp))
plt.scatter(t, N, label='data')
plt.plot(t, logistic(t, N0, r, K), '-', label='logistic')
plt.plot(t, gompertz(t, N0gomp, rgomp, Kgomp), '--', label='gompertz')
plt.legend(loc='upper left');
N0=0.065, r=0.371695, K=0.641
Not much better...
Let's try the Richards model, a generalized logistic model:
def richards(t, N0, r, K, nu):
    Q = -1 + (K/N0)**nu
    return K / ( 1 + Q * np.exp(-r * nu * t) )**(1/nu)
The difference between the logistic model and the Richards model is in the deceleration-of-growth term: $1-(N/K)^{\nu}$. When $\nu=1$ the Richards model is just the logistic model; when $\nu>1$ growth decelerates later; when $\nu<1$ it decelerates earlier.
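For reference, the Richards ODE and the closed form implemented in richards above (with $Q = (K/N_0)^{\nu} - 1$) are:
$$ \frac{dN}{dt} = r N \Big( 1 - \Big(\frac{N}{K}\Big)^{\nu} \Big) \Rightarrow \\ N(t) = \frac{K}{\big( 1 + Q e^{-r \nu t} \big)^{1/\nu}} $$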
plt.plot(t, logistic(t, N0, r, K), '--', lw=5, label='logistic')
plt.plot(t, richards(t, N0, r, K, 0.5), label='richards, nu=0.5')
plt.plot(t, richards(t, N0, r, K, 1), label='richards, nu=1')
plt.plot(t, richards(t, N0, r, K, 2), label='richards, nu=2')
plt.legend(loc='upper left')
sns.despine()
Let's try to fit the data with Richards model:
popt, cov = curve_fit(richards, t, N, (N.min(), r, N.max(), 1))  # initial guesses in (N0, r, K, nu) order
N0rich, rrich, Krich, nurich = popt
print('N0=%.3f, r=%.6f, K=%.3f, nu=%.3f' % (N0rich, rrich, Krich, nurich))
plt.scatter(t, N, label='data')
plt.plot(t, logistic(t, N0, r, K), '--', label='logistic')
plt.plot(t, richards(t, N0rich, rrich, Krich, nurich), '--', label='richards')
plt.legend(loc='upper left')
sns.despine()
N0=0.098, r=0.369301, K=0.580, nu=2.682
The Richards model indeed provides a better fit to the data. We can quantify this using the AIC (the lower, the better):
n = len(t)
print("Logistic AIC:", n * np.log(((logistic(t, N0, r, K)-N)**2).sum()/n) + 2*3)
print("Richards AIC:", n * np.log(((richards(t, N0rich, rrich, Krich, nurich)-N)**2).sum()/n) + 2*4)
Logistic AIC: -281.123079061
Richards AIC: -322.617475464
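An AIC difference can be read as a relative likelihood, $e^{(AIC_{min} - AIC)/2}$ - a quick sketch reusing the fits above:
aic_logistic = n * np.log(((logistic(t, N0, r, K) - N)**2).sum()/n) + 2*3
aic_richards = n * np.log(((richards(t, N0rich, rrich, Krich, nurich) - N)**2).sum()/n) + 2*4
print('Relative likelihood of the logistic model: %.2g' % np.exp((aic_richards - aic_logistic) / 2))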
This notebook is part of the Python Programming for Life Sciences Graduate Students course given at Tel-Aviv University, Spring 2015.
The notebook was written using Python 3.4.1 and IPython 3.1.0 (download from PyZo).
The code is available at https://github.com/Py4Life/TAU2015/blob/master/lecture10.ipynb.
The notebook can be viewed online at http://nbviewer.ipython.org/github/Py4Life/TAU2015/blob/master/lecture10.ipynb.
This work is licensed under a Creative Commons Attribution-ShareAlike 3.0 Unported License.