This is a notebook from the live coding session by Dr Hugo Bowne-Anderson on April 10, 2020 via DataCamp.
Imports and data¶
Let's import the necessary packages from the SciPy stack and get the data.
# Import packages
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# Set style & figures inline
sns.set()
%matplotlib inline
# Data urls
base_url = 'https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/'
confirmed_cases_data_url = base_url + 'time_series_covid19_confirmed_global.csv'
death_cases_data_url = base_url + 'time_series_covid19_deaths_global.csv'
recovery_cases_data_url = base_url+ 'time_series_covid19_recovered_global.csv'
# Import datasets as pandas dataframes
raw_data_confirmed = pd.read_csv(confirmed_cases_data_url)
raw_data_deaths = pd.read_csv(death_cases_data_url)
raw_data_recovered = pd.read_csv(recovery_cases_data_url)
Confirmed cases of COVID-19¶
We'll first check out the confirmed cases data by looking at the head of the dataframe:
raw_data_confirmed.head(n=10)
Discuss: What do you see here?
We can also see a lot about the data by using the .info()
and .describe()
dataframe methods:
raw_data_confirmed.info()
raw_data_confirmed.describe()
Number of confirmed cases by country¶
Look at the head (or tail) of our dataframe again and notice that each row is the data for a particular province or state of a given country:
raw_data_confirmed.head()
We want the numbers for each country, though. So the way to think about this is, for each country, we want to take all the rows (regions/provinces) that correspond to that country and add up the numbers for each. To put this in data-analytic-speak, we want to group by the country column and sum up all the values for the other columns.
This is a common pattern in data analysis that we humans have been using for centuries. Interestingly, it was only formalized in 2011 by Hadley Wickham in his seminal paper The Split-Apply-Combine Strategy for Data Analysis. The pattern we're discussing is now called Split-Apply-Combine and, in the case at hand, we
- Split the data into new datasets for each country,
- Apply the function of "sum" for each new dataset (that is, we add/sum up the values for each column) to sum over territories/provinces/states for each country, and
- Combine these datasets into a new dataframe.
The pandas
API has the groupby
method, which allows us to do this.
Side note: For more on split-apply-combine and pandas
check out my post here.
# Group by region (also drop 'Lat', 'Long' as it doesn't make sense to sum them here)
confirmed_country = raw_data_confirmed.groupby(by=['Country/Region']).sum().drop(['Lat', 'Long'], axis=1)
confirmed_country.head()
So each row of our new dataframe confirmed_country
is a time series of the number of confirmed cases for each country. Cool!
Now a dataframe has an associated object called an Index, which is essentially a set of unique indentifiers for each row. Let's check out the index of confirmed_country
:
confirmed_country.index
It's indexed by Country/Region
. That's all good but if we index by date instead, it will allow us to produce some visualizations almost immediately. This is a nice aspect of the pandas
API: you can make basic visualizations with it and, if your index consists of DateTimes, it knows that you're plotting time series and plays nicely with them.
To make the index the set of dates, notice that the column names are the dates. To turn column names into the index, we essentially want to make the columns the rows (and the rows the columns). This corresponds to taking the transpose of the dataframe:
confirmed_country = confirmed_country.transpose()
confirmed_country.head()
Let's have a look at our index to see whether it actually consists of DateTimes:
confirmed_country.index
Note that dtype='object'
which means that these are strings, not DateTimes. We can use pandas
to turn it into a DateTimeIndex:
# Set index as DateTimeIndex
datetime_index = pd.DatetimeIndex(confirmed_country.index)
confirmed_country.set_index(datetime_index)
# Check out index
confirmed_country.index
Now we have a DateTimeIndex and Countries for columns, we can use the dataframe plotting method to visualize the time series of confirmed number of cases by country. As there are so many coutries, we'll plot a subset of them:
Plotting confirmed cases by country¶
# Plot time series of several countries of interest
poi = ['China', 'US', 'Italy', 'France', 'Spain', 'Australia']
confirmed_country[poi].plot(figsize=(20, 10), linewidth=5, colormap='brg', fontsize=20)
Let's label our axes and give the figure a title. We'll also thin the line and add points for the data so that the sampling is evident in our plots:
# Plot time series of several countries of interest
confirmed_country[poi].plot(figsize=(20,10), linewidth=2, marker='.', colormap='brg', fontsize=20)
plt.xlabel('Date', fontsize=20);
plt.ylabel('Reported Confirmed cases count', fontsize=20);
plt.title('Reported Confirmed Cases Time Series', fontsize=20);
Let's do this again but make the y-axis logarithmic:
# Plot time series of several countries of interest
confirmed_country[poi].plot(figsize=(20,10), linewidth=2, marker='.', fontsize=20, logy=True)
plt.xlabel('Date', fontsize=20);
plt.ylabel('Reported Confirmed cases count', fontsize=20);
plt.title('Reported Confirmed Cases Time Series', fontsize=20);
Discuss: Why do we plot with a log y-axis? How do we interpret the log plot? Key points:
- If a variable takes on values over several orders of magnitude (e.g. in the 10s, 100s, and 1000s), we use a log axis so that the data is not all crammed into a small region of the visualization.
- If a curve is approximately linear on a log axis, then its approximately exponential growth and the gradient/slope of the line tells us about the exponent.
ESSENTIAL POINT: A logarithm scale is good for visualization BUT remember, in the thoughtful words of Justin Bois, "on the ground, in the hospitals, we live with the linear scale. The flattening of the US curve, for example is more evident on the log scale, but the growth is still rapid on a linear scale, which is what we feel."
Summary: We've
- looked at the JHU data repository and imported the data,
- looked at the dataset containing the number of reported confirmed cases for each region,
- wrangled the data to look at the number of reported confirmed cases by country,
- plotted the number of reported confirmed cases by country (both log and semi-log),
- discussed why log plots are important for visualization and that we need to remember that we, as humans, families, communities, and society, experience COVID-19 linearly.
Number of reported deaths¶
As we did above for raw_data_confirmed
, let's check out the head and the info of the raw_data_deaths
dataframe:
raw_data_deaths.head()
raw_data_deaths.info()
It seems to be structured similarly to raw_data_confirmed
. I have checked it out in detail and can confirm that it is! This is good data design as it means that users like can explore, munge, and visualize it in a fashion analogous to the above. Can you remember what we did? We
- Split-Apply-Combined it (and dropped 'Lat'/'Long'),
- Transposed it,
- Made the index a DateTimeIndex, and
- Visualized it (linear and semi-log).
Let's now do the first three steps here for raw_data_deaths
and see how we go:
Number of reported deaths by country¶
# Split-Apply-Combine
deaths_country = raw_data_deaths.groupby(by=['Country/Region']).sum().drop(['Lat', 'Long'], axis=1)
# Transpose
deaths_country = deaths_country.transpose()
# Set index as DateTimeIndex
datetime_index = pd.DatetimeIndex(deaths_country.index)
deaths_country.set_index(datetime_index)
# Check out head
deaths_country.head()
# Check out the index
deaths_country.index
Plotting number of reported deaths by country¶
Let's now visualize the number of reported deaths:
# Plot time series of several countries of interest
deaths_country[poi].plot(figsize=(20, 10), linewidth=2, marker='.', colormap='brg', fontsize=20)
plt.xlabel('Date', fontsize=20);
plt.ylabel('Number of Reported Deaths', fontsize=20);
plt.title('Reported Deaths Time Series', fontsize=20);
Now on a semi-log plot:
# Plot time series of several countries of interest
deaths_country[poi].plot(figsize=(20,10), linewidth=2, marker='.', fontsize=20, colormap='brg', logy=True)
plt.xlabel('Date', fontsize=20);
plt.ylabel('Number of Reported Deaths', fontsize=20);
plt.title('Reported Deaths Time Series', fontsize=20);
Aligning growth curves to start with day of number of known deaths ≥ 25¶
To compare what's happening in different countries, we can align each country's growth curves to all start on the day when the number of known deaths ≥ 25, such as reported in the first figure here. To achieve this, first off, let's set set all values less than 25 to NaN so that the associated data points don't get plotted at all when we visualize the data:
# Loop over columns & set values < 25 to None
for col in deaths_country.columns:
deaths_country.loc[(deaths_country[col] < 25), col] = None
# Check out tail
deaths_country.tail()
Now let's plot as above to make sure we see what we think we should see:
# Plot time series of several countries of interest
poi = ['China', 'US', 'Italy', 'France', 'Australia']
deaths_country[poi].plot(figsize=(20,10), linewidth=2, marker='.', colormap='brg', fontsize=20)
plt.xlabel('Date', fontsize=20);
plt.ylabel('Number of Reported Deaths', fontsize=20);
plt.title('Reported Deaths Time Series', fontsize=20);
The countries that have seen less than 25 total deaths will have columns of all NaNs now so let's drop these and then see how many columns we have left:
# Drop columns that are all NaNs (i.e. countries that haven't yet reached 25 deaths)
deaths_country.dropna(axis=1, how='all', inplace=True)
deaths_country.info()
As we're going to align the countries from the day they first had at least 25 deaths, we won't need the DateTimeIndex. In fact, we won't need the date at all. So we can
- Reset the Index, which will give us an ordinal index (which turns the date into a regular column) and
- Drop the date column (which will be called 'index) after the reset.
# drop index, sort date column
deaths_country_drop = deaths_country.reset_index().drop(['index'], axis=1)
deaths_country_drop.head()
Now it's time to shift each column so that the first entry is the first NaN value that it contains! To do this, we can use the shift()
method on each column. How much do we shift each column, though? The magnitude of the shift is given by how many NaNs there are at the start of the column, which we can retrieve using the first_valid_index()
method on the column but we want to shift up, which is negative in direction (by convention and perhaps intuition). SO let's do it.
# shift
for col in deaths_country_drop.columns:
deaths_country_drop[col] = deaths_country_drop[col].shift(-deaths_country_drop[col].first_valid_index())
# check out head
deaths_country_drop.head()
Side note: instead of looping over columns, we could have applied a lambda function to the columns of the dataframe, as follows:
# shift using lambda function
#deaths_country = deaths_country.apply(lambda x: x.shift(-x.first_valid_index()))
Now we get to plot our time series, first with linear axes, then semi-log:
# Plot time series
ax = deaths_country_drop.plot(figsize=(20,10), linewidth=2, marker='.', fontsize=20)
ax.legend(ncol=3, loc='upper right')
plt.xlabel('Days', fontsize=20);
plt.ylabel('Number of Reported Deaths', fontsize=20);
plt.title('Total reported coronavirus deaths for places with at least 25 deaths', fontsize=20);
# Plot semi log time series
ax = deaths_country_drop.plot(figsize=(20,10), linewidth=2, marker='.', fontsize=20, logy=True)
ax.legend(ncol=3, loc='upper right')
plt.xlabel('Days', fontsize=20);
plt.ylabel('Deaths Patients count', fontsize=20);
plt.title('Total reported coronavirus deaths for places with at least 25 deaths', fontsize=20);
Note: although we have managed to plot what we wanted, the above plots are challenging to retrieve any meaningful information from. There are too many growth curves so that it's very crowded and too many colours look the same so it's difficult to tell which country is which from the legend. Below, we'll plot less curves and further down in the notebook we'll use the python package Altair to introduce interactivity into the plot in order to deal with this challenge.
# Plot semi log time series
ax = deaths_country_drop.plot(figsize=(20,10), linewidth=2, marker='.', fontsize=20, logy=True)
ax.legend(ncol=3, loc='upper right')
plt.xlabel('Days', fontsize=20);
plt.ylabel('Deaths Patients count', fontsize=20);
plt.title('Total reported coronavirus deaths for places with at least 25 deaths', fontsize=20);
Summary: We've
- looked at the dataset containing the number of reported deaths for each region,
- wrangled the data to look at the number of reported deaths by country,
- plotted the number of reported deaths by country (both log and semi-log),
- aligned growth curves to start with day of number of known deaths ≥ 25.
Plotting number of recovered people¶
The third dataset in the Hopkins repository is the number of recovered. We want to do similar data wrangling as in the two cases above so we could copy and paste our code again but, if you're writing the same code three times, it's likely time to write a function.
# Function for grouping countries by region
def group_by_country(raw_data):
"""Returns data for countries indexed by date"""
# Group by
data = raw_data.groupby(by='Country/Region').sum().drop(['Lat', 'Long'], axis=1)
# Transpose
data = data.transpose()
# Set index as DateTimeIndex
datetime_index = pd.DatetimeIndex(data.index)
data.set_index(datetime_index, inplace=True)
return data
# Function to align growth curves
def align_curves(data, min_val):
"""Align growth curves to start on the day when the number of known deaths = min_val"""
# Loop over columns & set values < min_val to None
for col in data.columns:
data.loc[(data[col] < min_val), col] = None
# Drop columns with all NaNs
data.dropna(axis=1, how='all', inplace=True)
# Reset index, drop date
data = data.reset_index().drop(['index'], axis=1)
# Shift each column to begin with first valid index
for col in data.columns:
data[col] = data[col].shift(-data[col].first_valid_index())
return data
# Function to plot time series
def plot_time_series(df, plot_title, x_label, y_label, logy=False):
"""Plot time series and make looks a bit nice"""
ax = df.plot(figsize=(20,10), linewidth=2, marker='.', fontsize=20, logy=logy)
ax.legend(ncol=3, loc='lower right')
plt.xlabel(x_label, fontsize=20);
plt.ylabel(y_label, fontsize=20);
plt.title(plot_title, fontsize=20);
For a sanity check, let's see these functions at work on the 'number of deaths' data:
deaths_country_drop = group_by_country(raw_data_deaths)
deaths_country_drop = align_curves(deaths_country_drop, min_val=25)
plot_time_series(deaths_country_drop, 'Number of Reported Deaths', 'Days', 'Reported Deaths by Country', logy=True)
Now let's check use our functions to group, wrangle, and plot the recovered patients data:
# group by country and check out tail
recovered_country = group_by_country(raw_data_recovered)
recovered_country.tail()
# align curves and check out head
recovered_country_drop = align_curves(recovered_country, min_val=25)
recovered_country_drop.head()
Plot time series:
plot_time_series(recovered_country_drop, 'Recovered Patients Time Series', 'Days', 'Recovered Patients count')
plot_time_series(recovered_country_drop, 'Recovered Patients Time Series', 'Days', 'Recovered Patients count', True)
Note: once again, the above plots are challenging to retrieve any meaningful information from. There are too many growth curves so that it's very crowded and too many colours look the same so it's difficult to tell which country is which from the legend. Let's plot less curves and in the next section we'll use the python package Altair to introduce interactivity into such a plot in order to deal with this challenge.
plot_time_series(recovered_country_drop[poi], 'Recovered Patients Time Series', 'Days', 'Recovered Patients count', True)
Summary: We've
- looked at the dataset containing the number of reported recoveries for each region,
- written function for grouping, wrangling, and plotting the data,
- grouped, wrangled, and plotted the data for the number of reported recoveries.
Interactive plots with altair¶
We're now going to build some interactive data visualizations. I was recently inspired by this one in the NYTimes, a chart of confirmed number of deaths by country for places with at least 25 deaths, similar to ours above, but with informative hover tools. This one is also interesting.
We're going to use a tool called Altair. I like Altair for several reasons, including precisely what they state on their website:
With Altair, you can spend more time understanding your data and its meaning. Altair’s API is simple, friendly and consistent and built on top of the powerful Vega-Lite visualization grammar. This elegant simplicity produces beautiful and effective visualizations with a minimal amount of code.
Before jumping into Altair, let's reshape our deaths_country
dataset. Notice that it's currently in wide data format, with a column for each country and a row for each "day" (where day 1 is the first day with over 25 confirmed deaths). This worked with the pandas
plotting API for reasons discussed above.
# Look at head
deaths_country_drop.head()
For Altair, we'll want to convert the data into long data format. What this will do essentially have a row for each country/day pair so our columns will be 'Day', 'Country', and number of 'Deaths'. We do this using the dataframe method .melt()
as follows:
# create long data for deaths
deaths_long = deaths_country_drop.reset_index().melt(id_vars='index', value_name='Deaths').rename(columns={ 'index': 'Day' })
deaths_long.head()
We'll see the power of having long data when using Altair. Such transformations have been performed for a long time, however it wasn't until 2014 that Hadley Wickham formalized the language in his paper Tidy Data. Note that Wickham prefers to avoid the terms long and wide because, in his words, 'they are imprecise'. I generally agree but for our purposes here of giving the flavour, they suffice.
Now having transformed our data, let's import Altair and get a sense of its API.
import altair as alt
# altair plot
alt.Chart(deaths_long).mark_line().encode(
x='Day',
y='Deaths',
color='Country/Region'
)
It is nice to be able to build such an informative and elegant chart in four lines of code (which is also elegant). And, looking at the simplicity of the code we just wrote, we can see why it was great to have long data: a column for each variable allowed us to explicitly and easily tell Altair what we wanted on each axis and what we wanted for the colour.
As the Altair documentation (which is great, by the way!) states,
The key idea is that you are declaring links between data columns and visual encoding channels, such as the x-axis, y-axis, color, etc. The rest of the plot details are handled automatically. Building on this declarative plotting idea, a surprising range of simple to sophisticated plots and visualizations can be created using a relatively concise grammar.
We can now customize the code to thicken the line width, to alter the opacity, and to make the chart larger:
# altair plot
alt.Chart(deaths_long).mark_line(strokeWidth=4, opacity=0.7).encode(
x='Day',
y='Deaths',
color='Country/Region'
).properties(
width=800,
height=650
)
We can also add a log y-axis. To do this, The long-form, we express the types using the long-form alt.X('Day',...)
, which is, in the words of the Altair documentation
useful when doing more fine-tuned adjustments to the encoding, such as binning, axis and scale properties, or more.
We'll also now add a hover tooltip so that, when we hover our cursor over any point on any of the lines, it will tell us the 'Country', the 'Day', and the number of 'Deaths'.
# altair plot
alt.Chart(deaths_long).mark_line(strokeWidth=4, opacity=0.7).encode(
x=alt.X('Day'),
y=alt.Y('Deaths', scale=alt.Scale(type='log')),
color='Country/Region',
tooltip=['Country/Region', 'Day', 'Deaths']
).properties(
width=800,
height=650
)
It's great that we could add that useful hover tooltip with one line of code tooltip=['Country/Region', 'Day','Deaths']
, particularly as it adds such information rich interaction to the chart.
One useful aspect of the NYTimes chart was that, when you hovered over a particular curve, it made it stand out against the other. We're going to do something similar here: in the resulting chart, when you click on a curve, the others turn grey.
Note: When first attempting to build this chart, I discovered here that "multiple conditional values in one encoding are not allowed by the Vega-Lite spec," which is what Altair uses. For this reason, we build the chart, then an overlay, and then combine them.
# Selection tool
selection = alt.selection_single(fields=['Country/Region'])
# Color change when clicked
color = alt.condition(selection,
alt.Color('Country/Region:N'),
alt.value('lightgray'))
# Base altair plot
base = alt.Chart(deaths_long).mark_line(strokeWidth=4, opacity=0.7).encode(
x=alt.X('Day'),
y=alt.Y('Deaths', scale=alt.Scale(type='log')),
color='Country/Region',
tooltip=['Country/Region', 'Day','Deaths']
).properties(
width=800,
height=650
)
# Chart
chart = base.encode(
color=alt.condition(selection, 'Country/Region:N', alt.value('lightgray'))
).add_selection(
selection
)
# Overlay
overlay = base.encode(
color='Country/Region',
opacity=alt.value(0.5),
tooltip=['Country/Region:N', 'Name:N']
).transform_filter(
selection
)
# Sum em up!
chart + overlay
It's not super easy to line up the legend with the curves on the chart so let's put the labels on the chart itself. Thanks to Jake Vanderplas for this suggestion, and for the code.
# drop NaNs
deaths_long = deaths_long.dropna()
# Selection tool
selection = alt.selection_single(fields=['Country/Region'])
# Color change when clicked
color = alt.condition(selection,
alt.Color('Country/Region:N'),
alt.value('lightgray'))
# Base altair plot
base = alt.Chart(deaths_long).mark_line(strokeWidth=4, opacity=0.7).encode(
x=alt.X('Day'),
y=alt.Y('Deaths', scale=alt.Scale(type='log')),
color=alt.Color('Country/Region', legend=None),
).properties(
width=800,
height=650
)
# Chart
chart = base.encode(
color=alt.condition(selection, 'Country/Region:N', alt.value('lightgray'))
).add_selection(
selection
)
# Overlay
overlay = base.encode(
color='Country/Region',
opacity=alt.value(0.5),
tooltip=['Country/Region:N', 'Name:N']
).transform_filter(
selection
)
# Text labels
text = base.mark_text(
align='left',
dx=5,
size=10
).encode(
x=alt.X('Day', aggregate='max', axis=alt.Axis(title='Day')),
y=alt.Y('Deaths', aggregate={'argmax': 'Day'}, axis=alt.Axis(title='Reported Deaths')),
text='Country/Region',
).transform_filter(
selection
)
# Sum em up!
chart + overlay + text
Summary: We've
- melted the data into long format,
- used Altair to make interactive plots of increasing richness,
- admired the elegance & simplicity of the Altair API and the visualizations produced.
That's all for the time being. I'd be interested to see how you all can make these charts more information rich and comprehensible. I encourage you to raise ideas in issues on the issue tracker in this github repository and then to make pull requests. A couple of ideas are
- Adding lines to the above chart that show curves for deaths doubling each X days, as in the first chart here,
- Figuring out a way to make the chart less crowded with names by perhaps only showing 10 of them.