From Wikipedia:
A small multiple (sometimes called trellis chart, lattice chart, grid chart, or panel chart) is a series or grid of small similar graphics or charts, allowing them to be easily compared. The term was popularized by Edward Tufte.
In this notebook, we take gross domestic product (GDP) data from the GapMinder dataset (the same dataset that Hans Rosling used to make his famous bubble charts) and plot a GDP per capita small multiple with one subplot for each country in the dataset.
That's right, Plotly can do that!
First, import the Plotly modules and graph object required for this task:
import plotly.plotly as py # signing in with your credentials file
import plotly.tools as tls
from plotly.graph_objs import Figure, Data, Layout
from plotly.graph_objs import Scatter
from plotly.graph_objs import Marker, Font
from plotly.graph_objs import XAxis, YAxis, Annotation, Annotations
Next, import numpy and pandas, then load the data into a pandas dataframe:
import numpy as np
import pandas as pd
import urllib2
# The dataset's URL. Thanks Jennifer Bryan!
url_csv = 'http://www.stat.ubc.ca/~jenny/notOcto/STAT545A/examples/\
gapminder/data/gapminderDataFiveYear.txt'
file_csv = urllib2.urlopen(url_csv) # import csv file into this session
df = pd.read_csv(file_csv, sep='\t') # load csv file into a dataframe
df.head() # print the first 5 lines of the dataframe
| | country | year | pop | continent | lifeExp | gdpPercap |
|---|---|---|---|---|---|---|
| 0 | Afghanistan | 1952 | 8425333 | Asia | 28.801 | 779.445314 |
| 1 | Afghanistan | 1957 | 9240934 | Asia | 30.332 | 820.853030 |
| 2 | Afghanistan | 1962 | 10267083 | Asia | 31.997 | 853.100710 |
| 3 | Afghanistan | 1967 | 11537966 | Asia | 34.020 | 836.197138 |
| 4 | Afghanistan | 1972 | 13079460 | Asia | 36.088 | 739.981106 |
5 rows × 6 columns
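The loading step above relies on `urllib2`, which exists only in Python 2. As a minimal sketch, assuming Python 3 and only the standard library, the same tab-separated layout can be parsed with `csv.DictReader`. The inline `sample` string is a stand-in for the downloaded file (it reproduces the first two rows printed above, with the column order assumed to match the dataframe), and the name `parse_gapminder` is purely illustrative.

```python
import csv
import io

# Stand-in for the downloaded file: header plus the first two rows shown above.
# The column order is an assumption based on the dataframe printout.
sample = (
    "country\tyear\tpop\tcontinent\tlifeExp\tgdpPercap\n"
    "Afghanistan\t1952\t8425333\tAsia\t28.801\t779.445314\n"
    "Afghanistan\t1957\t9240934\tAsia\t30.332\t820.853030\n"
)

def parse_gapminder(text):
    """Parse tab-separated text into a list of dicts with numeric fields converted."""
    rows = []
    for rec in csv.DictReader(io.StringIO(text), delimiter="\t"):
        rows.append({
            "country": rec["country"],
            "continent": rec["continent"],
            "year": int(rec["year"]),
            "pop": float(rec["pop"]),
            "lifeExp": float(rec["lifeExp"]),
            "gdpPercap": float(rec["gdpPercap"]),
        })
    return rows

rows = parse_gapminder(sample)
print(rows[0]["country"], rows[0]["year"], rows[0]["gdpPercap"])
```

With pandas available, `pd.read_csv` remains the more convenient route; the sketch only shows what the parsing step amounts to.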
To get a feel for the dataset, find the start and end year as well as the number of countries in the dataset:
# Start year, end year and number of years in the dataset
df['year'].min(), df['year'].max(), len(df['year'].unique())
(1952, 2007, 12)
countries = df['country'].unique() # list of countries
N_country = len(countries) # number of countries
N_rowcol = int(np.ceil(np.sqrt(N_country))) # size of the square subplot grid
N_country, N_rowcol # print to screen
(142, 12)
There are 142 countries in the dataset and therefore, choosing to have a square small multiple for aesthetics, we need a 12 by 12 subplot grid to include every country. Consequently, there will be 2 empty subplots.
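The grid-size arithmetic can also be done with integers only. The following is a small sketch, not the notebook's own code: the helper name `square_grid` is hypothetical, and it assumes Python 3.8+ for `math.isqrt`.

```python
import math

def square_grid(n_items):
    """Return (side, n_empty) for the smallest square grid that holds n_items."""
    side = math.isqrt(n_items)      # floor of the square root
    if side * side < n_items:       # round up when n_items is not a perfect square
        side += 1
    return side, side * side - n_items

print(square_grid(142))  # (12, 2)
```

For 142 countries this gives a side of 12 and 2 empty cells, matching the values computed with numpy above.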
Next, generate a 12 by 12 subplot grid using get_subplots():
# Generate Figure object with 144 axes (12 rows x 12 columns)
fig = tls.get_subplots(
    rows=N_rowcol,            # number of rows
    columns=N_rowcol,         # number of columns
    horizontal_spacing=0.02,  # horiz. spacing (norm. coord)
    vertical_spacing=0.02,    # vert. spacing (norm. coord)
    print_grid=True)          # print axis grid ids to screen
This is the format of your plot grid!
[133] [134] [135] [136] [137] [138] [139] [140] [141] [142] [143] [144]
[121] [122] [123] [124] [125] [126] [127] [128] [129] [130] [131] [132]
[109] [110] [111] [112] [113] [114] [115] [116] [117] [118] [119] [120]
[97] [98] [99] [100] [101] [102] [103] [104] [105] [106] [107] [108]
[85] [86] [87] [88] [89] [90] [91] [92] [93] [94] [95] [96]
[73] [74] [75] [76] [77] [78] [79] [80] [81] [82] [83] [84]
[61] [62] [63] [64] [65] [66] [67] [68] [69] [70] [71] [72]
[49] [50] [51] [52] [53] [54] [55] [56] [57] [58] [59] [60]
[37] [38] [39] [40] [41] [42] [43] [44] [45] [46] [47] [48]
[25] [26] [27] [28] [29] [30] [31] [32] [33] [34] [35] [36]
[13] [14] [15] [16] [17] [18] [19] [20] [21] [22] [23] [24]
[1] [2] [3] [4] [5] [6] [7] [8] [9] [10] [11] [12]
The get_subplots() function generates subplot axes row-wise, starting from the bottom left corner of the plot. For example, xaxis1 and yaxis1 are set at the bottom left corner of the subplot grid and the last xaxis/yaxis pair is at the top right corner.
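To make the numbering concrete, here is a minimal sketch that converts an axis id into a grid position counted from the top left. The function `axis_position` is a hypothetical helper, not part of Plotly; it only encodes the row-wise, bottom-left-first numbering described above.

```python
def axis_position(axis_id, n_cols, n_rows):
    """Map a 1-based axis id to (row_from_top, col), both 0-based."""
    row_from_bottom, col = divmod(axis_id - 1, n_cols)
    return n_rows - 1 - row_from_bottom, col

print(axis_position(1, 12, 12))    # (11, 0): bottom left corner
print(axis_position(144, 12, 12))  # (0, 11): top right corner
print(axis_position(133, 12, 12))  # (0, 0): top left corner
```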
For presentation purposes, we will plot all the countries in the dataset in alphabetical order, starting at the top left corner of the subplot grid. To do so, we reorder the axis ids of the subplot grid. Consider:
# Function to make list of subplot indices
def get_splts(N_rowcol, N_country):
    N_splt = N_rowcol**2        # number of subplots
    N_empty = N_splt-N_country  # number of empty subplots
    tmp1d = np.arange(1,N_splt+1)                   # => [1,2,..,N_splt]
    tmp2d = np.resize(tmp1d, (N_rowcol,N_rowcol))   # => [[1,2,..,N_rowcol],..[..,N_splt]]
    tmp2d_flip = tmp2d[::-1,:]                      # => [[..,N_splt],..[1,2,..,N_rowcol]]
    splts_left = tmp2d_flip[:,0]     # indices of the left-hand side subplots
    splts_bottom = tmp2d_flip[-1,:]  # indices of the bottom subplots
    tmp1d_in_order = tmp2d_flip.flatten().tolist()  # => [..,N_splt,..,1,2,..N_rowcol]
    splts_empty = range(N_rowcol-N_empty+1,N_rowcol+1)  # indices of empty subplots
    for splt in splts_empty:
        tmp1d_in_order.remove(splt)  # remove indices of empty subplots
    splts = tmp1d_in_order           # and get the complete list of subplots
    return splts, splts_empty, splts_left, splts_bottom
# Get lists of subplot indices
splts, splts_empty, splts_left, splts_bottom = get_splts(N_rowcol, N_country)
splts # print list
[133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
The first item in splts corresponds to the id of the top left axes. The last item is the third id from the right on the bottom row, which leaves the two bottom right sets of axes empty.
Consequently, given the list of countries in alphabetical order (thanks pandas), each country's data can be easily matched to the appropriate axes using splts.
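The same reading-order list can be built without numpy. The sketch below is an alternative formulation rather than the notebook's function: `reading_order_ids` is a hypothetical name, and it slices the ordered list instead of removing ids one by one, so it also handles grids where the empty cells spill past the bottom row (a case that does not arise with 142 countries).

```python
def reading_order_ids(n_side, n_items):
    """Return (used_ids, unused_ids) for an n_side x n_side grid.

    Axis ids are numbered row-wise from the bottom left; the result lists
    them left to right, top to bottom, keeping the first n_items.
    """
    rows = [list(range(r * n_side + 1, (r + 1) * n_side + 1))
            for r in range(n_side)]                  # rows[0] is the bottom row
    flat = [axis_id for row in reversed(rows) for axis_id in row]
    return flat[:n_items], flat[n_items:]

used, unused = reading_order_ids(12, 142)
print(used[0], used[-1], unused)  # 133 10 [11, 12]
```

Pairing `used` with an alphabetical list of countries then amounts to a single `zip`.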
Next, make a few definitions to help us fill in the data object:
# Function to make Scatter graph object
def make_Scatter_gdp(splt, x, y, color, country, text):
    return Scatter(
        x=x,                           # x coordinates
        y=y,                           # y coordinates
        name=country,                  # label name (on hover)
        text=text,                     # hover text
        mode='lines+markers',          # show marker pts and line between them
        fill='tozeroy',                # fill area down to y=0
        marker=Marker(color=color),    # marker, line and fill color
        xaxis='x{}'.format(splt),      # bind coordinates to given x-axis
        yaxis='y{}'.format(splt))      # bind coordinates to given y-axis
# Colors corresponding to continents
colors = dict(
    Asia='#1f77b4',
    Europe='#ff7f0e',
    Africa='#2ca02c',
    Americas='#d62728',
    Oceania='#9467bd')
# Function to make hover text list (for each data point)
# N.B. the continuation lines inside the string stay flush left so that
# no extra whitespace ends up in the hover text
def make_text(X):
    return 'Continent: %s\
<br>Year: %s\
<br>GDP per capita: %s $\
<br>Life Expectancy: %s years\
<br>Population: %s million'\
    % (X['continent'], X['year'], X['gdpPercap'], X['lifeExp'], X['pop']/1e6)
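As a possible refinement, the hover text can be built with `str.format` and explicit number formatting, which keeps long floats out of the tooltip. This is only a sketch: `make_hover_text` is a hypothetical variant of the function above, the rounding choices are assumptions, and it is shown on a plain dict holding the first row of the dataframe.

```python
def make_hover_text(row):
    """Build the hover string for one data point, rounding numeric fields."""
    return (
        "Continent: {continent}"
        "<br>Year: {year}"
        "<br>GDP per capita: {gdp:,.0f} $"
        "<br>Life Expectancy: {life:.1f} years"
        "<br>Population: {pop:.2f} million"
    ).format(
        continent=row["continent"],
        year=row["year"],
        gdp=row["gdpPercap"],
        life=row["lifeExp"],
        pop=row["pop"] / 1e6,
    )

# First row of the dataframe printed earlier
row = {"continent": "Asia", "year": 1952, "pop": 8425333,
       "lifeExp": 28.801, "gdpPercap": 779.445314}
print(make_hover_text(row))
```

Adjacent string literals are concatenated by Python, so no backslash continuations are needed and indentation cannot leak into the text.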
To try to limit the amount of redundant information on the plot, only subplots on the left-hand side of the subplot grid will have labelled y axes and only subplots on the bottom row of the subplot grid will have labelled x axes.
Additionally, for better comparisons between countries, all subplots will have the same axes range. So, let's get an idea of the range in GDP per capita in the dataset:
df['gdpPercap'].min(), df['gdpPercap'].max()
(241.16587650000002, 113523.1329)
Then, define a few axes style dictionaries that will be used to update each axis' style:
# For all x axes
axis_style_x = dict(
    range=[1950,2010])  # set x-axis range
# For all y axes
axis_style_y = dict(
    type='log',                             # N.B. log y-axis
    range=[np.log10(90),np.log10(5e5)])     # N.B. set y-axis range w.r.t. log scale
# For all axes
axis_style_all = dict(
    ticks='outside',       # draw ticks outside the axis line
    showline=True,         # show axis bounding line
    showgrid=False,        # remove grid
    zeroline=False,        # no thick line at x=y=0
    showticklabels=False)  # remove tick labels
# For y axes on the left-hand side of the subplot grid
axis_style_left = dict(
    showticklabels=True,   # N.B. add back tick labels (overwrite axis_style_all)
    title='GDP per cap.')  # title of the y axes
# For x axes on the bottom of the subplot grid
axis_style_bottom = dict(
    showticklabels=True,   # N.B. add back tick labels (overwrite axis_style_all)
    title='year')          # title of the x axes
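Because the y axes are logarithmic, their range is given in log10 units rather than in dollars. A quick sanity check, sketched here with the standard library only, confirms that the chosen limits (90 and 5e5) bracket the minimum and maximum GDP per capita printed earlier. The helper `fits` is hypothetical.

```python
import math

y_low, y_high = 90, 5e5                          # axis limits in data units
log_range = [math.log10(y_low), math.log10(y_high)]

gdp_min, gdp_max = 241.16587650000002, 113523.1329   # values printed earlier

def fits(value, log_range):
    """True if value lies inside a range expressed in log10 units."""
    return log_range[0] <= math.log10(value) <= log_range[1]

print([round(v, 3) for v in log_range])                    # [1.954, 5.699]
print(fits(gdp_min, log_range), fits(gdp_max, log_range))  # True True
```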
Make an annotation-generating function to label the subplots with each country's name:
# Function to make annotation labelling each subplot with its country's name
def make_splt_anno(splt_in, country):
    if len(country)>14:
        country = country[0:14]+'.'  # truncate country's name if too long
    return Annotation(
        x=1955,                        # x position
        y=np.log10(2.5e5),             # y position (log scale)
        text=country,                  # text
        align='center',                # align text in the center
        font=Font(size=14),            # font size
        showarrow=False,               # no arrow
        xref='x{}'.format(splt_in),    # position in relation to the x
        yref='y{}'.format(splt_in))    # and y axes
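The truncation rule is easy to check in isolation. The sketch below pulls it into a hypothetical helper, `short_label`, with the 14-character limit taken from the function above.

```python
def short_label(name, max_len=14):
    """Truncate long names to max_len characters and mark the cut with a dot."""
    if len(name) > max_len:
        return name[:max_len] + "."
    return name

print(short_label("Chad"))                      # Chad
print(short_label("Central African Republic"))  # Central Africa.
```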
Add a few keys to the layout object:
width = 2000   # plot's width
height = 1800  # and height in pixels
title = "GDP per Capita from 1952 to 2007 in USD of the year 2000 [GapMinder]"
fig['layout'].update(
    title=title,                  # plot's title
    font=Font(
        family='Georgia, serif',  # global font,
        color='#635F5D'),         # same as in 3.1
    titlefont=Font(size=30),      # increase title font size
    showlegend=False,             # remove legend
    autosize=False,               # turn off autosize
    width=width,                  # plot's width
    height=height)                # plot's height
fig['layout']['annotations'] = Annotations([])  # init. 'annotations' key
Now, loop through every country in the dataframe and fill in the data object and update each axis' style:
i = 0  # init. subplot counter
# Group dataframe by country in alphabetical order and loop
for country, X in df.groupby('country'):
    splt = splts[i]                             # N.B. get axes id
    x = X['year'].values                        # get years
    y = X['gdpPercap'].values                   # get GDP values
    color = colors[X['continent'].values[0]]    # get fill color
    text = X.apply(make_text,axis=1).tolist()   # get hover text
    # Append data object
    fig['data'] += [make_Scatter_gdp(splt, x, y, color, country, text)]
    # Make shortcut to xaxis of splt id, update its style
    xaxis_splt = fig['layout']['xaxis{}'.format(splt)]
    xaxis_splt.update(axis_style_x)
    xaxis_splt.update(axis_style_all)
    if splt in splts_bottom:
        xaxis_splt.update(axis_style_bottom)
    # Make shortcut to yaxis of splt id, update its style
    yaxis_splt = fig['layout']['yaxis{}'.format(splt)]
    yaxis_splt.update(axis_style_y)
    yaxis_splt.update(axis_style_all)
    if splt in splts_left:
        yaxis_splt.update(axis_style_left)
    # Append annotations object, label each subplot
    fig['layout']['annotations'] += [make_splt_anno(splt, country)]
    i += 1  # increment counter
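The manual counter in the loop above pairs the i-th country (in alphabetical order) with the i-th axis id. One way to express that pairing without a counter is `zip`, sketched here on toy data with plain Python containers. The records and the 2 by 2 grid are made up for illustration and are not GapMinder values.

```python
from collections import defaultdict

# Toy records (country, year, value); the numbers are placeholders.
records = [("Chad", 1952, 2.0), ("Albania", 1952, 1.0),
           ("Chad", 1957, 2.5), ("Albania", 1957, 1.5)]
axis_ids = [3, 4, 1, 2]  # reading-order ids of a hypothetical 2 x 2 grid

# Group the records by country, as df.groupby('country') does for the dataframe.
by_country = defaultdict(list)
for country, year, value in records:
    by_country[country].append((year, value))

# Pair each country, in alphabetical order, with the next axis id.
assignment = {country: axis_id
              for axis_id, country in zip(axis_ids, sorted(by_country))}
print(assignment)  # {'Albania': 3, 'Chad': 4}
```

In the notebook's loop, the equivalent would be iterating over `zip(splts, df.groupby('country'))`, which is offered here as a suggestion rather than a tested change.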
Before sending the figure object to Plotly, we must also update the axis style of the empty subplots:
for splt in splts_empty:  # loop through list of empty subplot ids
    # Make shortcut to xaxis of splt id, update its style
    xaxis_splt = fig['layout']['xaxis{}'.format(splt)]
    xaxis_splt.update(axis_style_x)
    xaxis_splt.update(axis_style_all)
    xaxis_splt.update(axis_style_bottom)  # empty subplots are on the bottom row
    # Make shortcut to yaxis of splt id, update its style
    yaxis_splt = fig['layout']['yaxis{}'.format(splt)]
    yaxis_splt.update(axis_style_y)
    yaxis_splt.update(axis_style_all)
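The order of the update() calls matters: later dictionaries overwrite keys set by earlier ones, which is how the left-hand and bottom axes get their tick labels back. The sketch below illustrates this layering with plain dicts. The helper `layered_style` is hypothetical and the style dictionaries are trimmed copies of the ones defined earlier.

```python
axis_style_all = {"showgrid": False, "zeroline": False, "showticklabels": False}
axis_style_left = {"showticklabels": True, "title": "GDP per cap."}

def layered_style(*styles):
    """Merge style dicts left to right; later dicts win on shared keys."""
    merged = {}
    for style in styles:
        merged.update(style)
    return merged

inner = layered_style(axis_style_all)
left = layered_style(axis_style_all, axis_style_left)
print(inner["showticklabels"], left["showticklabels"])  # False True
```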
We are now ready to send the figure object to Plotly and see the results:
py.plot(fig, filename='small-multiple_gdp-time', auto_open=False)
u'https://plot.ly/~etpinard/311'
To see the figure, copy the above URL into your browser or run:
tls.embed('etpinard','311',
width=width, height=height)
To convert the Plotly figure to a .png, run:
py.image.save_as(fig, 'small-multiple_gdp-time')
from IPython.display import Image
Image('small-multiple_gdp-time.png')