"""Scrape Wikipedia population tables into a DataFrame and plot them.

Notebook-style script: it downloads the "List of countries by past and
future population" page, extracts the sortable wiki tables, collects the
per-country figures into a pandas DataFrame (rows = countries, columns =
years), prints a few indexing examples and draws population plots.

NOTE(review): the script targets the Python 2 / legacy-pandas stack it was
written for.  ``print(x)`` with a single argument is used because it prints
exactly the same thing as the Python 2 statement ``print x``.
"""
import cs109style
cs109style.customize_mpl()
cs109style.customize_css()

# Callable form of the IPython line magic ``%matplotlib inline`` (prepares the
# notebook for matplotlib).  ``get_ipython`` only exists inside IPython.
get_ipython().run_line_magic('matplotlib', 'inline')

from collections import defaultdict

import pandas as pd
import matplotlib.pyplot as plt
import requests
from pattern import web


url = 'http://en.wikipedia.org/wiki/List_of_countries_by_past_and_future_population'
website_html = requests.get(url).text


def get_population_html_tables(html):
    """Parse html and return html tables of wikipedia population data.

    Parameters:
        html: markup of the page, as text.

    Returns:
        A list of the ``table`` elements whose ``class`` attribute is exactly
        ``"wikitable sortable"`` (the tables that carry the data).
    """
    dom = web.Element(html)
    # .get() instead of ['class']: a table without a class attribute must be
    # skipped, not abort the whole parse with a KeyError.
    return [t for t in dom.by_tag('table')
            if t.attributes.get('class') == "wikitable sortable"]


tables = get_population_html_tables(website_html)
print("table length: %d" % len(tables))
for t in tables:
    print(t.attributes)


def table_type(tbl):
    """Return the content of the table's second ``th`` cell.

    That header text is used as the grouping key ("type") of the table.

    Raises:
        IndexError: if the table has fewer than two ``th`` elements.
    """
    headers = [th.content for th in tbl.by_tag('th')]
    return headers[1]


# group the tables by type
# (a defaultdict inserts its default value when a new key is accessed)
tables_by_type = defaultdict(list)
for tbl in tables:
    tables_by_type[table_type(tbl)].append(tbl)
print(tables_by_type)


def get_countries_population(tables):
    """Extract population data for countries from all tables.

    Parameters:
        tables: iterable of table elements as returned by
            ``get_population_html_tables``.

    Returns:
        A ``defaultdict`` mapping country name to ``{year: value}``, where
        ``value`` is the cell's integer divided by 1000.0.  Data for the same
        country found in several tables is merged into one dict.

    Raises:
        ValueError: if a data cell does not hold an integer.
        IndexError: if a data row has no link in its second cell.
    """
    result = defaultdict(dict)

    for tbl in tables:
        # Each table looks a little different, so locate the data columns by
        # their header: a column stores data when its header is a year.
        tbl_headers = [th.content for th in tbl.by_tag('th')]
        year_columns = [(idx, int(header))
                        for idx, header in enumerate(tbl_headers)
                        if header.isnumeric()]
        if not year_columns:
            # Nothing to extract (previously this crashed while unpacking
            # an empty zip()).
            continue

        for row in tbl.by_tag('tr'):
            # Look the cells up once per row; rows without td elements
            # (e.g. the header row) carry no data.
            cells = row.by_tag('td')
            if not cells:
                continue

            # country name: 2nd td, first link, unicode converted to string
            countryname = cells[1].by_tag('a')[0].content.encode('ascii', 'ignore')

            # {1955: 10000, 1960: 14000, ...}: strip the thousands separators
            # from the numbers and scale them by 1/1000
            # (NOTE(review): presumably thousands -> millions, confirm).
            countrydata = {
                year: int(cells[idx].content.replace(',', '')) / 1000.0
                for idx, year in year_columns
            }

            result[countryname].update(countrydata)

    return result


result = get_countries_population(tables_by_type['Country or territory'])
print(result)

# create dataframe
df = pd.DataFrame.from_dict(result, orient='index')
# sort the columns by year
df.sort_index(axis=1, inplace=True)
print(df)

subtable = df.iloc[0:2, 0:2]
print("subtable")
print(subtable)
print("")

column = df[1955]
print("column")
print(column)
print("")

row = df.iloc[0]  # row 0 (by position)
print("row")
print(row)
print("")

rows = df.iloc[:2]  # rows 0,1 (by position)
print("rows")
print(rows)
print("")

element = df.loc[df.index[0], 1955]  # element: first row, column labelled 1955
print("element")
print(element)
print("")

# max along column
print("max")
print(df[1950].max())
print("")

# axes
print("axes")
print(df.axes)
print("")

row = df.iloc[0]
print("row info")
print(row.name)
print(row.index)
print("")

countries = df.index
print("countries")
print(countries)
print("")

print("Austria")
print(df.loc['Austria'])


def _decorate_population_plot(title):
    """Apply the shared axis limits, ticks, legend and labels to the current plot."""
    plt.ylim(ymin=0)  # start y axis at 0
    plt.xticks(rotation=70)
    plt.legend(loc='best')
    plt.xlabel("Year")
    plt.ylabel("# people (million)")
    plt.title(title)


def _sorted_by_column(frame, column, ascending):
    """Return ``frame`` sorted by ``column`` on both legacy and current pandas."""
    if hasattr(frame, 'sort_values'):
        return frame.sort_values(column, ascending=ascending)
    # Legacy pandas: DataFrame.sort(column) is the sort-by-values call.
    return frame.sort(column, ascending=ascending)


plotCountries = ['Austria', 'Germany', 'United States', 'France']

for country in plotCountries:
    row = df.loc[country]
    plt.plot(row.index, row, label=row.name)

_decorate_population_plot("Population of countries")


def plot_populous(df, year, count=5):
    """Plot the population curves of the most populous countries in ``year``.

    Parameters:
        df: DataFrame with one row per country and one column per year.
        year: column label to rank the countries by.
        count: how many of the top-ranked countries to plot (default 5).
            Frames with fewer rows simply plot all of them.
    """
    # sort table depending on data value in year column
    df_by_year = _sorted_by_column(df, year, ascending=False)

    plt.figure()
    for name, row in df_by_year.iloc[:count].iterrows():
        plt.plot(row.index, row, label=name)

    _decorate_population_plot("Most populous countries in %d" % year)


plot_populous(df, 2010)
plot_populous(df, 2050)