Overview
Questions
How do you create appealing plots in Python?
How do you compare distributions of data?
How do different plotting libraries work?
Objectives
Examples of seaborn for Python.
Make box plots, violin plots, and histograms in seaborn.
Some examples of plots using bokeh and ggplot.
Visualization is meant to convey information.
The power of a graph is its ability to enable one to take in the quantitative information, organize it, and see patterns and structure not readily revealed by other means of studying the data.
- Cleveland and McGill, 1984
matplotlib
Now we’ll start learning how to create visualizations in Python. We’ll begin with a popular Python package called matplotlib, and later use a second package called seaborn that builds on matplotlib. First we’ll import matplotlib, along with several other Python packages.
import math
import random
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
One of the nice features of Jupyter notebooks is that figures can be plotted inline, which means they appear below the code cell that creates them. This is not the default behavior however, and so we’ll use the below “Magic” statement to tell the Jupyter notebook to plot the figures inline (instead of opening a separate browser window to show them).
%matplotlib inline
To illustrate what matplotlib can do, we’ll need to use a dataset. We’ve decided to use the so-called Gapminder dataset, which was compiled by Jennifer Bryan. This dataset contains several key demographic and economic statistics for many countries, across many years. For more information, see the gapminder repository.
We’ll use the pandas Python package to load the .csv (comma-separated values) file that contains the dataset. The pandas package provides DataFrame objects that organize datasets in tabular form (think Microsoft Excel spreadsheets). We’ve created an alias for the pandas package, called pd, which is the convention. To read in a .csv file we simply use pd.read_csv(). This file happens to be tab-delimited, so we need to specify sep='\t'.
df = pd.read_csv('data/gapminder.tsv', sep='\t')
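As a minimal sketch of what sep='\t' changes, the snippet below parses a small tab-delimited table from an in-memory string, once with and once without the argument. The three rows are made-up stand-ins for the real file, not values from the gapminder data.

```python
import io
import pandas as pd

# Hypothetical tab-delimited text standing in for data/gapminder.tsv
tsv_text = "country\tyear\tpop\nA\t2002\t10\nA\t2007\t12\nB\t2007\t7\n"

# sep='\t' splits each line on tabs into separate columns
sample = pd.read_csv(io.StringIO(tsv_text), sep='\t')

# The default sep=',' finds no commas, so every line lands in one wide column
wrong = pd.read_csv(io.StringIO(tsv_text))

print(sample.shape)
print(wrong.shape)
```

Comparing the two shapes is a quick way to notice that a file was read with the wrong delimiter.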
To look at the first few rows of the dataset we’ll use the .head() method of the DataFrame.
df.head()
population fertility ... child_mortality Region
0 34811059.0 2.73 ... 29.5 Middle East & North Africa
1 19842251.0 6.43 ... 192.0 Sub-Saharan Africa
2 40381860.0 2.24 ... 15.4 America
3 2975029.0 1.40 ... 20.0 Europe & Central Asia
4 21370348.0 1.96 ... 5.2 East Asia & Pacific
[5 rows x 10 columns]
It looks like we have information about life expectancy (lifeExp), population (pop) and per-capita GDP (gdpPercap), across multiple years per country.
To start off, let’s say we want to explore the data from the most recent year in the dataset.
latest_year = df['year'].max()
latest_year
2007
df_latest = df[df['year']==latest_year]
df_latest.shape
(142, 6)
OK, it looks like we have 142 rows (one per country) across our 6 columns (variables). Let’s get an idea of how per-capita GDP was distributed across all of the countries during 2007 by calculating some summary statistics. We’ll do that using the pandas .describe() method.
df_latest['gdpPercap'].describe()
count 142.000000
mean 11680.071820
std 12859.937337
min 277.551859
25% 1624.842248
50% 6124.371109
75% 18008.835640
max 49357.190170
Name: gdpPercap, dtype: float64
Across 142 countries the mean per-capita GDP was ~$11,680, and the standard deviation was ~$12,860! Per-capita GDP varied a lot across countries, but these summary statistics don’t give us the whole picture. To get the whole picture, let’s draw a picture! Or, more accurately, plot a figure.
A histogram plots a discretized distribution of a one-dimensional dataset. It shows how many data points fall into each of a set of bins, each of which covers a pre-defined range of values.
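To make the idea of bins concrete, here is a small sketch using numpy’s np.histogram(), which matplotlib’s hist() uses to bin the data. The ten values are made up so that the counts are easy to check by hand.

```python
import numpy as np

# Made-up values, chosen so the bin counts are easy to verify by hand
values = np.array([1, 2, 2, 3, 3, 3, 8, 9, 10, 10])

# Three equal-width bins spanning the range of the data
counts, edges = np.histogram(values, bins=3)

print(edges)   # the four boundaries that delimit the three bins
print(counts)  # how many values fall into each bin
```

Each bin includes its left edge and excludes its right edge, except for the last bin, which includes both.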
To create a histogram plot in matplotlib we’ll use pyplot, which is a collection of command-style functions that make matplotlib work like MATLAB and save many lines of repeated code. By convention, pyplot is aliased to plt, which we’ve already done in the above import cell.
Let’s use plt.hist() to create a histogram of the per-capita GDP in 2007.
plt.hist(df_latest['gdpPercap']);
Protip: Use a semicolon (;) at the end of the last line in a Jupyter notebook cell to stop the notebook from printing the return value of that line. This was done in the above cell. Try removing the ; to see how the output changes.
This histogram tells us that many of the countries had a low per-capita GDP, less than $5,000. There is also a second “bump” in the histogram around $30,000. This type of distribution is known as bimodal, since there are two modes, or common values.
To make this histogram more interpretable, let’s add a title and labels for the x and y axes. We’ll pass strings to plt.title(), plt.xlabel(), and plt.ylabel() to do so.
plt.hist(df_latest['gdpPercap']);
plt.title('Distribution of Global Per-Capita GDP in 2007')
plt.xlabel('Per-Capita GDP (USD)')
plt.ylabel('# of Countries');
Each bar in the histogram represents a bin. The height of the bar represents the number of items (countries in this case) within the range of values spanned by the bin. In the last plots we used the default number of bins (10); now let’s use more bins by specifying the bins=30 parameter.
plt.hist(df_latest['gdpPercap'], bins=30);
plt.title('Distribution of Global Per-Capita GDP in 2007')
plt.xlabel('Per-Capita GDP (USD)')
plt.ylabel('# of Countries');
We can see this histogram doesn’t look as “smooth” as the last one. There’s no “right” way to display a histogram, but some bin counts definitely are more informative than others. For example, using only 3 bins we cannot see the bimodal nature of the GDP distribution.
plt.hist(df_latest['gdpPercap'], bins=3);
plt.title('Distribution of Global Per-Capita GDP in 2007')
plt.xlabel('Per-Capita GDP (USD)')
plt.ylabel('# of Countries');
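One way to compare bin counts is to draw the same data side by side. The sketch below does this with a synthetic bimodal sample (an assumption made for illustration, not the gapminder values) and three bin counts.

```python
import matplotlib.pyplot as plt
import numpy as np

# Synthetic bimodal sample: many low values plus a smaller cluster of high ones
rng = np.random.default_rng(0)
sample = np.concatenate([rng.normal(3000, 1500, 100),
                         rng.normal(30000, 5000, 40)])

bin_counts = [3, 10, 30]
fig, axes = plt.subplots(1, len(bin_counts), figsize=(15, 4))

# Draw one histogram per panel, changing only the number of bins
for ax, n_bins in zip(axes, bin_counts):
    ax.hist(sample, bins=n_bins)
    ax.set_title(f"bins={n_bins}")
    ax.set_xlabel("Value")
axes[0].set_ylabel("Count")
fig.tight_layout()
```

With too few bins the second cluster merges into the first; with many bins the outline becomes noisier but the two clusters stay visible.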
The styling that you see in the plot above is the matplotlib default. For now, we’ll continue to use the defaults, as the first portion of this workshop is geared toward getting you familiar with the API: how to actually create the plots you’re interested in. We’ll cover how to customize plot styles in a later section.
As you can see, we can call functions in the plt module multiple times within a single cell, and those functions will all work on, and modify, the current figure associated with that cell. This is because pyplot (or plt) keeps an internal reference to the current figure, and each cell that uses plt gets its own figure.
NOTE: Unless the documentation for a function says otherwise, the order of these function calls doesn’t matter. For example, calling plt.title(), plt.xlabel(), and plt.ylabel() before plt.hist() produces the same plot as the one above.
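A reordered version of the labelled histogram cell might look like the sketch below. The list gdp is a made-up stand-in for df_latest['gdpPercap'], so that the snippet runs on its own.

```python
import matplotlib.pyplot as plt

# Stand-in values for df_latest['gdpPercap'] (made up for illustration)
gdp = [500, 900, 1500, 4000, 7000, 12000, 25000, 31000, 36000, 47000]

# Labels first, histogram last: every call acts on the same current figure
plt.figure()
plt.ylabel('# of Countries')
plt.xlabel('Per-Capita GDP (USD)')
plt.title('Distribution of Global Per-Capita GDP in 2007')
plt.hist(gdp)
```

The labels survive because plt.hist() draws into the axes that the earlier calls already created, rather than starting a new figure.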
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import statsmodels.api as sm
# Use the 2007 per-capita GDP values as the data to plot
data = df_latest['gdpPercap']
# Create subplots
fig, axs = plt.subplots(1, 3, figsize=(18, 6)) # Adjusted figsize for better display
# Plot box plot
axs[0].boxplot(data, vert=False)
axs[0].set_title('Box Plot')
# Plot histogram with KDE
sns.histplot(data, kde=True, color='skyblue', ax=axs[1])
axs[1].set_title('Histogram with Density Plot')
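# Plot Q-Q plot against a normal distribution. line='s' draws a reference line
# scaled to the sample's mean and standard deviation; line='45' is only
# meaningful when the data have already been standardized.
sm.qqplot(data, line='s', ax=axs[2])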
axs[2].set_title('Q-Q Plot')
plt.tight_layout()
plt.show()
The code above visualizes the per-capita GDP data with three plots: a box plot, a histogram with a kernel density estimate (KDE), and a Q-Q (quantile-quantile) plot.
Together, these plots reveal the shape of the distribution, the presence of outliers, and how closely the data follow a theoretical distribution (here, the normal distribution). They are valuable tools for exploratory data analysis, helping researchers and analysts understand the underlying structure and patterns within the data.
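To show what a Q-Q plot actually compares, here is a minimal sketch that computes the points by hand: the sorted sample against the quantiles a normal distribution would give at the same positions. The function name qq_points, the plotting positions (i - 0.5) / n, and the two synthetic samples are assumptions made for illustration; library functions may use slightly different conventions.

```python
import numpy as np
from statistics import NormalDist

def qq_points(sample):
    """Return (theoretical, observed) quantile pairs for a normal Q-Q plot."""
    observed = np.sort(np.asarray(sample, dtype=float))
    n = observed.size
    # Plotting positions (i - 0.5) / n, one common convention among several
    probs = (np.arange(1, n + 1) - 0.5) / n
    theoretical = np.array([NormalDist().inv_cdf(p) for p in probs])
    return theoretical, observed

rng = np.random.default_rng(42)
normal_sample = rng.normal(size=500)
skewed_sample = rng.lognormal(sigma=1.0, size=500)

# Points from normal data lie close to a straight line; skewed data bend away
for name, sample in [("normal", normal_sample), ("skewed", skewed_sample)]:
    theo, obs = qq_points(sample)
    r = np.corrcoef(theo, obs)[0, 1]
    print(f"{name}: correlation of Q-Q points = {r:.3f}")
```

The correlation is only a rough numerical summary of straightness; the plot itself shows where the departure from normality occurs, for example in one tail.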
Next, it might be interesting to get a sense of how many countries per continent we have data for. Let’s create a countries DataFrame that includes only the unique combinations of country and continent by dropping all duplicate rows using drop_duplicates().
countries = df[['country', 'continent']]
countries = countries.drop_duplicates()
countries
country continent
0 Afghanistan Asia
12 Albania Europe
24 Algeria Africa
36 Angola Africa
48 Argentina Americas
... ... ...
1644 Vietnam Asia
1656 West Bank and Gaza Asia
1668 Yemen, Rep. Asia
1680 Zambia Africa
1692 Zimbabwe Africa
[142 rows x 2 columns]
To get the number of countries per continent we’ll use the .groupby() method to group by continent, then count the countries in each group (each country appears only once after dropping duplicates). We’ll use the as_index=False argument so that the continent name gets its own column and is not used as the index. This will create a new DataFrame that we’ll call country_counts.
country_counts = countries.groupby('continent', as_index=False)['country'].count()
country_counts
continent country
0 Africa 52
1 Americas 25
2 Asia 33
3 Europe 30
4 Oceania 2
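The effect of as_index=False is easiest to see on a tiny table. The sketch below uses five made-up country/continent pairs and runs the same grouping with and without the argument.

```python
import pandas as pd

# Hypothetical country/continent pairs, already de-duplicated
toy = pd.DataFrame({
    'country': ['Kenya', 'Ghana', 'Peru', 'Japan', 'Chile'],
    'continent': ['Africa', 'Africa', 'Americas', 'Asia', 'Americas'],
})

# as_index=False keeps 'continent' as an ordinary column of a DataFrame
toy_counts = toy.groupby('continent', as_index=False)['country'].count()

# Without it, 'continent' becomes the index of the resulting Series
toy_series = toy.groupby('continent')['country'].count()

print(toy_counts)
print(toy_series)
```

Keeping the group label as a column is convenient for plotting functions that look up columns by name.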
import seaborn as sns
import matplotlib.pyplot as plt
# Assuming you have already calculated country_counts
country_counts = countries.groupby('continent', as_index=False)['country'].count()
# Sort the DataFrame by the count of countries
country_counts_sorted = country_counts.sort_values(by='country', ascending=False)
# Create horizontal bar plot with sorted order
plt.figure(figsize=(10, 6))
sns.barplot(data=country_counts_sorted, y='continent', x='country', color='skyblue')
plt.xlabel('Number of Countries')
plt.ylabel('Continent')
plt.title('Number of Countries in Each Continent')
plt.show()
A quick way to analyze the distribution of the numerical columns in a DataFrame is to calculate summary statistics (skewness and kurtosis), create histograms to visualize each distribution, use box plots to identify outliers, and generate normal probability plots to assess normality. This information can be useful for understanding the characteristics of the data and determining appropriate statistical methods for further analysis.
from scipy.stats import probplot
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# Treat infinite values as missing so they can be dropped before plotting
# (this replaces the 'mode.use_inf_as_na' option, deprecated in recent pandas)
numeric_df = df.select_dtypes(np.number).replace([np.inf, -np.inf], np.nan)
for col in numeric_df.columns:
    values = numeric_df[col].dropna()
    print(f"Skewness of {col}:", values.skew())
    print(f"Kurtosis of {col}:", values.kurtosis())
    plt.figure(figsize=(14, 4))
    plt.subplot(131)  # histogram with KDE
    sns.histplot(values, kde=True)
    plt.subplot(132)  # box plot to show outliers
    sns.boxplot(x=values)
    plt.subplot(133)  # normal probability plot
    probplot(values, dist='norm', rvalue=True, plot=plt)
    plt.suptitle(col)
    plt.show()
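To help interpret the printed skewness and kurtosis values, the sketch below computes them for two tiny made-up Series, one symmetric and one with a long right tail.

```python
import pandas as pd

symmetric = pd.Series([-2.0, -1.0, 0.0, 1.0, 2.0])
right_skewed = pd.Series([1.0, 1.0, 1.0, 2.0, 10.0])

# skew() is near 0 for symmetric data and positive for a long right tail
print("skew:", symmetric.skew(), right_skewed.skew())

# kurtosis() reports excess kurtosis, which is 0 for a normal distribution
print("kurtosis:", symmetric.kurtosis(), right_skewed.kurtosis())
```

Both statistics are noisy for samples this small; the example is only meant to show the sign and rough size of the values.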
A count plot is a type of bar plot that displays the count or frequency of each category in a categorical variable. In this case, the ‘continent’ column is a categorical variable, and the count plot shows the number of rows for each continent in the DataFrame.
import matplotlib.pyplot as plt
import seaborn as sns
import gc
plt.figure(figsize=(10,6))
sns.countplot(data=df, x='continent')
plt.xticks(rotation=90)
plt.show()
plt.close('all')
gc.collect()
import gc
import seaborn as sns
import matplotlib.pyplot as plt
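# df_africa is not defined earlier; we assume it holds the rows for African countries
df_africa = df[df['continent'] == 'Africa']
plt.figure(figsize=(10, 10))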
sns.barplot(data=df_africa, y='country', x=df_africa.index,
estimator=len, orient='h')
plt.xticks(rotation=90)
plt.ylabel('Country')
plt.xlabel('Count')
plt.title('Count of Data Entries per Country')
plt.show()
plt.close('all')
gc.collect()
A heatmap is a graphical representation of data where the individual values contained in a matrix are represented as colors. It is a powerful tool for visualizing and analyzing complex data, especially when dealing with large datasets. The heatmap provides a visual representation of the correlations between numerical columns in the DataFrame, helping to identify relationships and patterns in the data.
import matplotlib.pyplot as plt
import seaborn as sns
import gc
# Filter out only the numerical columns
numerical_df = df.select_dtypes(include=['float64', 'int64'])
# Drop NaN values along columns in the numerical DataFrame
numerical_df = numerical_df.dropna(axis=1)
plt.figure(figsize=(12, 8))
fig = sns.heatmap(numerical_df.corr(), annot=True, cmap='rainbow', vmin=-1.0, vmax=1.0)
plt.xticks(rotation=90)
plt.show()
plt.close('all')
gc.collect()
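The matrix that the heatmap colors comes from the DataFrame’s .corr() method. As a small sketch with made-up columns, where y rises with x and z falls with x:

```python
import pandas as pd

# Made-up columns: y is a multiple of x, z decreases as x increases
toy = pd.DataFrame({
    'x': [1, 2, 3, 4, 5],
    'y': [2, 4, 6, 8, 10],
    'z': [5, 4, 3, 2, 1],
})

# Pearson correlation between every pair of numeric columns
corr = toy.corr()
print(corr.round(2))
```

Values near 1 or -1 indicate a strong linear relationship, which is why the heatmap above fixes its color scale with vmin=-1.0 and vmax=1.0.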
seaborn
Python has powerful plotting capabilities through its packages, and for this exercise we will focus on using the seaborn package, which facilitates the creation of highly informative plots of structured data.
The seaborn library is built on matplotlib and features very nice color palettes. This library makes manipulating the features of a matplotlib plot somewhat easier. For an introduction to this package, see the seaborn documentation.
Create a new Jupyter notebook for this lesson and begin by importing the pandas package and seaborn. We also need to import matplotlib’s pyplot module in order to call any of the matplotlib plotting functions directly.
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
Here, we have given seaborn the alias sns.
Now load the DataFrame:
surveys_complete = pd.read_csv('surveys_complete.csv', index_col=0)
For this exercise, we will use a different version of the ecological dataset from our previous lessons. This version has only the complete observations, without missing entries for any of the columns. There are also a few additional columns providing information about the taxonomy of the species collected and the type of plot/enclosure.
Let’s start with a basic scatterplot. We’ll simply plot the weight on the horizontal axis and the hind foot length on the vertical axis. This uses the seaborn function .lmplot(). This function can take a pandas DataFrame directly. It also fits a regression line by default. Since we may not want to visualize these data with a regression line, we will use the fit_reg=False argument.
sns.lmplot(x="weight", y="hindfoot_length", data=surveys_complete, fit_reg=False)
Out of the box, seaborn will plot with a given set of aesthetics. We may want to change the label font size and make the background gray. For this, we can call the .set() function. If this function is called without any arguments, as sns.set(), it switches on the seaborn defaults, which change the background color to gray and overlay white grid lines (i.e., a similar background to ggplot2 in R). We can add an argument to also change the font size for all of our downstream figures:
sns.set(font_scale=1.5)
This will increase the font size of the labels by 50% in all of our subsequent plots.
sns.lmplot(x="weight", y="hindfoot_length", data=surveys_complete, fit_reg=False)
The plot size is small, by default, and with seaborn, there isn’t a way to change the size of every plot with a single function. Further, this is done differently depending on the plot function we use because some functions return matplotlib objects and others are drawn using a grid object.
For .lmplot(), we can set the plot size directly in the function call using the height and aspect arguments (height was called size in seaborn versions before 0.9). height sets the height of the plot in inches, and aspect is the ratio of width to height, so the width is height × aspect.
sns.lmplot(x="weight", y="hindfoot_length", data=surveys_complete, fit_reg=False, height=8, aspect=1.5)
One issue with this plot is that because we have a large dataset, it is difficult to adequately visualize the points on our graph. This is called “overplotting”. One way to avoid this is to change the size of the marker. Here we use the argument scatter_kws that takes a dictionary of keywords and values that are passed to the matplotlib scatter function.
sns.lmplot(x="weight", y="hindfoot_length", data=surveys_complete, fit_reg=False, height=8, aspect=1.5, scatter_kws={"s": 5})
Alternatively, we can change the marker type using the markers keyword argument. By default, lmplot() sets markers='o' (circles). If we want to change the marker, we must use the correct matplotlib marker code. Let’s change the markers to X-es:
sns.lmplot(x="weight", y="hindfoot_length", data=surveys_complete, fit_reg=False, height=8, aspect=1.5, markers='x')
Another way to avoid overplotting is to use transparency so that regions of the plot with many points are darker. This is achieved using the scatter_kws argument and setting the alpha value. (This is using alpha compositing.)
sns.lmplot(x="weight", y="hindfoot_length", data=surveys_complete, fit_reg=False, height=8, aspect=1.5, scatter_kws={'alpha':0.3})
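The darkening effect can be worked out directly. Assuming markers of the same color are drawn one over another with standard “over” compositing, each new marker covers a fraction alpha of whatever background is still showing through. The helper name stacked_opacity is hypothetical.

```python
def stacked_opacity(alpha, n_points):
    """Opacity after drawing n_points overlapping markers, each with the given alpha."""
    # Each marker lets (1 - alpha) of what is underneath show through
    return 1 - (1 - alpha) ** n_points

# How quickly does a region fill in as points pile up at alpha=0.3?
for n in (1, 2, 5, 10):
    print(n, round(stacked_opacity(0.3, n), 3))
```

A smaller alpha leaves more room to distinguish dense regions from very dense ones, at the cost of making isolated points faint.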
We can also specify that the species_id labels indicate categories that determine a point’s color:
sns.lmplot(x="weight", y="hindfoot_length", data=surveys_complete,
fit_reg=False, height=8, aspect=1.5, scatter_kws={'alpha':0.3,"s": 50},
hue='species_id', markers='D')
There are different ways to set figure properties using seaborn. For .lmplot(), we can create a figure variable and use a member method of that variable to set the axis labels.
my_fig = sns.lmplot(x="weight", y="hindfoot_length", data=surveys_complete,
fit_reg=False, height=8, aspect=1.5,
scatter_kws={'alpha':0.3,"s": 50},
hue='species_id', markers='D')
my_fig.set_axis_labels('Weight (g)', 'Hindfoot Length (mm)')
Scatter plot for a single plot
How would you plot the hind foot length and weight for just a single plot? Remember how to access subsets of a DataFrame based on conditional criteria? Plot the scatter plot above for only plot number 12 and color by sex. (Make the markers larger circles.)
Solution
my_fig = sns.lmplot(x="weight", y="hindfoot_length", data=surveys_complete[surveys_complete.plot_id == 12], fit_reg=False, height=8, aspect=1.5, scatter_kws={'alpha':0.3,"s": 200}, hue='sex', markers='o')
my_fig.set_axis_labels('Weight (g)', 'Hindfoot Length (mm)')
We often like to compare the distributions of values across different categorical variables. Box plots and violin plots are often used to do this in a simple way.
In seaborn, both the .boxplot() and .violinplot() functions return matplotlib Axes objects. Thus, these plot functions do not have arguments for height and aspect like the scatter plot function above.
In order to change the size of these plots, we must create a matplotlib figure and axes and set the dimensions of the figure. First, let’s decide on the figure size we will use for the rest of our seaborn plots.
We can put these dimensions into a tuple variable called plot_dims.
plot_dims = (14, 9)
Now, every time we create a new plot that returns a matplotlib Axes object, we can first call the plt.subplots() function to create a figure and axes of the right size. Because of the iterative way in which figures are built using matplotlib, we need to always execute our creation of the figure and the plot function in a single notebook cell.
fig, ax = plt.subplots(figsize=plot_dims)
sns.boxplot(x='species_id', y='hindfoot_length', data=surveys_complete)
ax.set(xlabel='Species ID', ylabel='Hindfoot Length (mm)')
Note that for the .boxplot() function, we can simply set the x and y values by giving the column names from our DataFrame. Then we must use the data=surveys_complete argument to indicate our DataFrame object.
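The same figure-first pattern can be sketched with plain matplotlib, which makes it easy to confirm that the figure really has the requested size. The two groups of measurements are made up for illustration.

```python
import matplotlib.pyplot as plt

plot_dims = (14, 9)  # (width, height) in inches

# Made-up measurements for two groups
groups = [[20, 21, 22, 25, 30], [31, 33, 34, 36, 40]]

# Create the figure and axes first, then draw into those axes
fig, ax = plt.subplots(figsize=plot_dims)
ax.boxplot(groups)
ax.set(xlabel='Group', ylabel='Measurement')

print(fig.get_size_inches())
```

Because the drawing and labelling calls all go through ax, it is clear which axes they affect, even when several figures exist.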
Now let’s use a violin plot to visualize the weight data. For these data, we would also like the weight to be on the log10 scale. We will make this plot horizontal, so we’ll set the horizontal axis to be on the log scale.
An easy way to do this is to assign the Axes returned by .violinplot() to a new variable and then call its .set_xscale() method.
fig, ax = plt.subplots(figsize=plot_dims)
g = sns.violinplot(x='weight', y='species_id', data=surveys_complete,
linewidth=0.2, orient="h", cut=0)
g.set_xscale('log', base=10)
ax.set(ylabel='Species ID', xlabel='Weight (g)')
Violin plot for a single species
Use seaborn to make a violin plot comparing the relative distributions of weight measurements for different sexed animals from a single species, Onychomys leucogaster (OL), one of the coolest rodent species:
Solution
fig, ax = plt.subplots(figsize=plot_dims)
sns.violinplot(x='sex', y='weight', data=surveys_complete[surveys_complete.species_id == 'OL'])
A histogram is often a better way to visualize a single distribution. This is relatively simple using seaborn’s .displot() function. Here we pass it a pandas Series (i.e., a column of our DataFrame); it can also take a DataFrame through its data argument together with a column name for x. This function can also take some other arguments like color, which takes values that specify the color of the histogram based on matplotlib’s colors. We will make our plot cyan using the 'c' color code. Additionally, we will specify bins=50 so that our histogram is not plotted using too many or too few bins.
sns.displot(surveys_complete['weight'], color='c',bins=50,
height=8, aspect=1.5)
By default, the .displot() function draws a histogram (kind="hist"). Alternatively, you can plot a kernel density estimate by setting the kind argument to kind="kde" and removing the bins argument.
sns.displot(surveys_complete['weight'], color='k', kind="kde",
height=8, aspect=1.5)
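A kernel density estimate can be sketched in a few lines of numpy: place a small Gaussian bump on every observation and average the bumps. This is a simplified illustration with a fixed, hand-picked bandwidth and synthetic data; seaborn chooses the bandwidth automatically, and the function name gaussian_kde_1d is hypothetical.

```python
import numpy as np

def gaussian_kde_1d(sample, grid, bandwidth):
    """Average of Gaussian bumps centred on each sample point."""
    sample = np.asarray(sample, dtype=float)
    # Distance of every grid point from every sample point, in bandwidth units
    z = (grid[:, None] - sample[None, :]) / bandwidth
    kernels = np.exp(-0.5 * z ** 2) / (bandwidth * np.sqrt(2 * np.pi))
    return kernels.mean(axis=1)

rng = np.random.default_rng(1)
weights = rng.normal(40, 10, size=200)  # synthetic stand-in for the weight column
grid = np.linspace(weights.min() - 40, weights.max() + 40, 400)
density = gaussian_kde_1d(weights, grid, bandwidth=4.0)

# A density should integrate to roughly 1 over the grid
area = (density * (grid[1] - grid[0])).sum()
print(round(area, 3))
```

A larger bandwidth gives a smoother curve that can hide real features, while a smaller one follows the noise in the sample, much like the choice of bin count for a histogram.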
bokeh
Another library called bokeh can create amazing, interactive graphics that are rendered in the browser by its own JavaScript library, BokehJS. This package is also easy to install with conda:
$ conda install bokeh
Return to your Jupyter notebook and import the necessary bokeh plotting tools:
from bokeh.plotting import figure
from bokeh.io import output_notebook, show
Here we are importing functions to create a figure and to output that figure to our Jupyter notebook.
We can now execute the output_notebook() function that will ensure that our JavaScript-based figures are displayed in our HTML notebook.
output_notebook()
This figure also uses the histogram function from numpy. So we have to first import that package:
import numpy as np
We can reproduce the histogram of the weights of all our observations.
hist, edges = np.histogram(surveys_complete['weight'], density=True, bins=100)
my_fig = figure(title="Weight (g)",background_fill_color="#EBC8EB")
my_fig.quad(top=hist, bottom=0, left=edges[:-1], right=edges[1:],fill_color="#036564", line_color="#033649")
show(my_fig)
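Two details of the np.histogram() call are worth checking on their own: density=True rescales the bar heights so that the bar areas sum to 1, and edges has one more entry than hist, which is why the quad() call uses edges[:-1] and edges[1:] for the left and right sides. The sketch below uses synthetic values in place of the weight column.

```python
import numpy as np

rng = np.random.default_rng(0)
weights = rng.gamma(shape=2.0, scale=20.0, size=1000)  # synthetic stand-in data

hist, edges = np.histogram(weights, density=True, bins=100)
widths = np.diff(edges)  # width of each bin: edges[1:] - edges[:-1]

# With density=True the bar areas, not the bar heights, sum to 1
print(len(hist), len(edges))
print((hist * widths).sum())
```

To plot raw counts instead, drop density=True and the bar heights become the number of observations per bin.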
We can also save this file as an html file that we can share with others or embed in files on the web. To do this, we need to include some other bokeh functions:
from bokeh.resources import CDN
from bokeh.embed import file_html
And create an html object and save it to a file using standard Python file i/o:
html = file_html(my_fig, CDN, "Weight Histogram")
# The with statement closes the file even if writing fails
with open('weight_hist_bokeh.html', 'w') as out:
    out.write(html)
If you open the file you created in your web browser, you now have an interactive version of your figure. You can then use this to embed in a website.
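from bokeh.plotting import figure, show
# Create Bokeh plot; match_aspect keeps the hexagons regular
p = figure(width=800, height=400, match_aspect=True)
# Bin the points into hexagonal tiles colored by count. hexbin() takes data
# coordinates, whereas hex_tile() expects axial hex-grid coordinates (q, r).
# size=5 (the hexagon size in data units) is an illustrative choice.
p.hexbin(surveys_complete['weight'].to_numpy(),
         surveys_complete['hindfoot_length'].to_numpy(),
         size=5, line_color='white')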
# Label axes
p.xaxis.axis_label = 'Weight'
p.yaxis.axis_label = 'Hindfoot Length'
show(p)
Take-Home Challenge: More Visualizations in Python
Continue to use seaborn and bokeh to try out different visualizations.
Produce a histogram of the hindfoot length for only the species observed in the genus Peromyscus. Note that this particular dataset includes a column called "genus", so you can subset out only the entries for Peromyscus individuals. Try making the histogram using both seaborn and bokeh. Consider altering the bin numbers to help visualize the distribution.
This dataset also has a column called "plot_type" that says in what kind of plot the animal was captured. The different plot types are specific experimental set-ups for this kind of study. Create a visualization that helps us better understand how the species ("species_id") might be influenced by the plot type. For example, are some species prevented from entering certain plot types?
Solutions
The solutions will be posted in 4 days. Feel free to use the #scripting_help channel in Slack to discuss these exercises.
Key Points
There are many tools for plotting in Python.