plotnine.geoms.geom_segment
- class plotnine.geoms.geom_segment(mapping=None, data=None, **kwargs)
Line segments
Usage
geom_segment(mapping=None, data=None, stat='identity', position='identity', na_rm=False, inherit_aes=True, show_legend=None, raster=False, lineend='butt', arrow=None, **kwargs)
Only the data and mapping can be positional; the rest must be keyword arguments. **kwargs can be aesthetics (or parameters) used by the stat.
- Parameters
- mapping
aes, optional
Aesthetic mappings created with aes(). If specified and inherit_aes=True, it is combined with the default mapping for the plot. You must supply mapping if there is no plot mapping.
Aesthetic | Default value |
---|---|
x | |
xend | |
y | |
yend | |
alpha | 1 |
color | 'black' |
group | |
linetype | 'solid' |
size | 0.5 |
The aesthetics x, xend, y and yend are required.
- data
dataframe, optional
The data to be displayed in this layer. If None, the data from the ggplot() call is used. If specified, it overrides the data from the ggplot() call.
- stat
str or stat, optional (default: stat_identity)
The statistical transformation to use on the data for this layer. If it is a string, it must be registered and known to Plotnine.
- position
str or position, optional (default: position_identity)
Position adjustment. If it is a string, it must be registered and known to Plotnine.
- na_rm
bool, optional (default: False)
If False, removes missing values with a warning. If True, silently removes missing values.
- inherit_aes
bool, optional (default: True)
If False, overrides the default aesthetics.
- show_legend
bool or dict, optional (default: None)
Whether this layer should be included in the legends. None, the default, includes any aesthetics that are mapped. If a bool, False never includes and True always includes. A dict can be used to exclude specific aesthetics of the layer from the legend, e.g. show_legend={'color': False}; all other aesthetics are included by default.
- raster
bool, optional (default: False)
If True, render this layer as a raster (bitmap) object even if the final image is in vector format.
- lineend
str (default: 'butt')
Line end style, one of butt, round or projecting. This option is applied to solid linetypes.
- arrow
plotnine.geoms.geom_path.arrow (default: None)
Arrow specification. Default is no arrow.
See also
plotnine.geoms.geom_path.arrow
for adding arrowhead(s) to segments.
Examples
[1]:
# NOTE: This notebook only works in Python 3
# and uses the plydata package for
# data manipulation.
import pandas as pd
import pandas.api.types as pdtypes
import numpy as np
from plotnine import *
from plydata import *
Ranges of Similar Variables
Comparing the point-to-point differences of many similar variables
Read the data.
Source: Pew Research Global Attitudes Spring 2015
[2]:
data = pd.read_csv('data/survey-social-media.csv')
data = (
    data
    >> rename(
        country='COUNTRY',
        gender='Q145',
        age='Q146',
        use_internet='Q70',
        use_social_media='Q74'
    )
    >> select('PSRAID', 'gender', 'use_internet', drop=True)
)
data >> sample_n(10, random_state=123)
[2]:
| | country | age | use_social_media |
---|---|---|---|
11376 | Vietnam | 48 | Yes |
12937 | United States | 51 | Yes |
19440 | Pakistan | 44 | |
30665 | Malaysia | 46 | |
37003 | Israel | 63 | |
19271 | Pakistan | 18 | Yes |
30445 | Malaysia | 40 | Yes |
16782 | Poland | 72 | |
39999 | China | 61 | |
8041 | Kenya | 35 | Yes |
Some helper functions
[3]:
def format_sequence(s, fmt='{}'):
    """
    Format a sequence of items

    Useful for creating labels from numeric data

    Parameters
    ----------
    s : sequence
        List of values
    fmt : str
        Format string applied to each value

    Returns
    -------
    out : list
        List of string values
    """
    return [fmt.format(x) for x in s]
Create age groups for users of social media
[4]:
def toint(s):
    # Convert the ages to int; non-numeric answers (e.g. 'Refused') become -1
    res = []
    for x in s:
        try:
            x = int(x)
        except ValueError:
            x = -1
        res.append(x)
    return res


def percent(gdf):
    return 100 * np.sum(gdf['use_social_media'] == 'Yes') / len(gdf)


no_age = ['Refused', 'Don\'t know']
yes_no = ['Yes', 'No']
ages = ['18-34', '35-49', '50+']


def first(s):
    return s.iloc[0]
rdata = (
    data
    >> define(age_='toint(age)')
    >> define(age_group=case_when([
        ('age_ == -1', '""'),
        ('age_ < 35', '"18-34"'),
        ('age_ < 50', '"35-49"'),
        ('age_ >= 50', '"50+"'),
    ]))
    >> select('age_', drop=True)
    >> group_by('country')
    >> define(country_count='n()')
    >> query('age_group in @ages')
    >> query('use_social_media in @yes_no')
    >> group_by('country', 'age_group')
    >> do(
        # social media use percentage
        sm_use_percent=lambda gdf: 100 * sum(gdf['use_social_media'] == 'Yes') / len(gdf),
        # social media question response rate
        smq_response_rate=lambda gdf: len(gdf) / gdf['country_count'].iloc[0]
    )
)
rdata >> ungroup() >> head(9)
[4]:
| | country | age_group | sm_use_percent | smq_response_rate |
---|---|---|---|---|
0 | Ethiopia | 18-34 | 77.777778 | 0.081000 |
1 | Ethiopia | 35-49 | 61.538462 | 0.013000 |
2 | Ethiopia | 50+ | 66.666667 | 0.003000 |
3 | South Korea | 35-49 | 57.246377 | 0.274627 |
4 | South Korea | 18-34 | 75.147929 | 0.336318 |
5 | South Korea | 50+ | 33.532934 | 0.332338 |
6 | Spain | 50+ | 47.422680 | 0.388000 |
7 | Spain | 18-34 | 86.842105 | 0.228000 |
8 | Spain | 35-49 | 67.777778 | 0.270000 |
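For readers without plydata, the age grouping and the two per-group statistics can be sketched in plain pandas. This is an assumed equivalent, not the notebook's own code: pd.to_numeric(errors='coerce') turns non-numeric answers into NaN in place of toint's -1 sentinel, pd.cut replaces the case_when, and named aggregation replaces do(). One difference: ages below 18 fall outside the bins here, whereas the case_when above would label them 18-34. The names age_groups, social_media_use and AGE_LABELS are hypothetical.

```python
import pandas as pd

AGE_LABELS = ['18-34', '35-49', '50+']


def age_groups(age):
    """Bin raw survey answers into the age groups used above.

    Non-numeric answers (e.g. 'Refused', "Don't know") become NaN instead of
    toint's -1 sentinel, so they fall outside every bin.
    """
    numeric = pd.to_numeric(age, errors='coerce')
    # right=False makes the bins [18, 35), [35, 50) and [50, inf)
    return pd.cut(numeric, bins=[18, 35, 50, float('inf')],
                  labels=AGE_LABELS, right=False)


def social_media_use(df):
    """Per country and age group: % answering 'Yes' and the response rate."""
    # Rows per country before any filtering, like define(country_count='n()')
    counts = (df.groupby('country').size()
              .rename('country_count').reset_index())
    df = df.assign(age_group=age_groups(df['age']))
    # Keep rows with a usable age and a Yes/No answer
    answered = df[df['age_group'].notna()
                  & df['use_social_media'].isin(['Yes', 'No'])]
    out = (answered
           .groupby(['country', 'age_group'], observed=True)['use_social_media']
           .agg(sm_use_percent=lambda s: 100 * (s == 'Yes').mean(),
                n_answered='count')
           .reset_index()
           .merge(counts, on='country'))
    out['smq_response_rate'] = out['n_answered'] / out['country_count']
    return out.drop(columns=['n_answered', 'country_count'])
```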
Top 14 countries by response rate to the social media question.
[5]:
n = 14
top = (
    rdata
    >> group_by('country')
    >> summarize(r='sum(smq_response_rate)')
    >> arrange('-r')
    >> head(n)
)
top_countries = top['country']
point_data = (
    rdata
    >> query('country in @top_countries')
    # Format the floating point data that will be plotted into strings
    >> define(sm_use_percent_str=if_else(
        'country == "France"',
        'format_sequence(sm_use_percent, "{:.0f}%")',
        'format_sequence(sm_use_percent, "{:.0f}")')
    )
    >> ungroup()
)
point_data >> head(6)
[5]:
| | country | age_group | sm_use_percent | smq_response_rate | sm_use_percent_str |
---|---|---|---|---|---|
0 | South Korea | 35-49 | 57.246377 | 0.274627 | 57 |
1 | South Korea | 18-34 | 75.147929 | 0.336318 | 75 |
2 | South Korea | 50+ | 33.532934 | 0.332338 | 34 |
3 | Spain | 50+ | 47.422680 | 0.388000 | 47 |
4 | Spain | 18-34 | 86.842105 | 0.228000 | 87 |
5 | Spain | 35-49 | 67.777778 | 0.270000 | 68 |
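A plain-pandas sketch of this step, assuming rdata has the columns shown in the previous output: sum the response rate per country, keep the n largest with Series.nlargest, and build the label strings. Only the France rows get a percent sign, presumably because France ends up as the top row of the finished chart, so the unit is shown once. The helper names top_countries_by_response and percent_labels are hypothetical.

```python
import pandas as pd


def top_countries_by_response(rdata, n=14):
    """Countries with the largest summed response rate, highest first."""
    totals = rdata.groupby('country')['smq_response_rate'].sum()
    return list(totals.nlargest(n).index)


def percent_labels(point_data):
    """Whole-number labels; only France's labels carry a '%' sign."""
    labels = point_data['sm_use_percent'].map('{:.0f}'.format)
    # Series.where keeps values where the condition holds, else uses `other`
    return labels.where(point_data['country'] != 'France', labels + '%')
```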
[6]:
segment_data = (
    point_data
    >> group_by('country')
    >> summarize(min='min(sm_use_percent)', max='max(sm_use_percent)')
    >> define(gap='max-min')
    >> arrange('-gap')
    # Format the floating point data that will be plotted into strings
    >> define(
        min_str='format_sequence(min, "{:.0f}")',
        max_str='format_sequence(max, "{:.0f}")',
        gap_str='format_sequence(gap, "{:.0f}")',
    )
)
segment_data >> head()
[6]:
| | country | min | max | gap | min_str | max_str | gap_str |
---|---|---|---|---|---|---|---|
0 | France | 30.564784 | 84.824903 | 54.260119 | 31 | 85 | 54 |
1 | Germany | 29.384966 | 79.646018 | 50.261052 | 29 | 80 | 50 |
2 | Japan | 30.794702 | 78.735632 | 47.940930 | 31 | 79 | 48 |
3 | Australia | 48.380952 | 90.862944 | 42.481992 | 48 | 91 | 42 |
4 | South Korea | 33.532934 | 75.147929 | 41.614995 | 34 | 75 | 42 |
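This table is the data a horizontal geom_segment needs: one row per country with the segment's start (min) and end (max). A plain-pandas sketch of the same computation, assuming point_data has country and sm_use_percent columns; range_strips is a hypothetical name.

```python
import pandas as pd


def range_strips(point_data):
    """One row per country: lowest, highest and gap of sm_use_percent."""
    strips = (point_data
              .groupby('country')['sm_use_percent']
              .agg(['min', 'max'])
              .assign(gap=lambda d: d['max'] - d['min'])
              # Largest gap first, like arrange('-gap')
              .sort_values('gap', ascending=False)
              .reset_index())
    # Whole-number label strings, like format_sequence(..., "{:.0f}")
    for col in ['min', 'max', 'gap']:
        strips[col + '_str'] = strips[col].map('{:.0f}'.format)
    return strips
```

The min and max columns then map to x and xend, with country as both y and yend, as in the Range strip layers of the plots that follow.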
Set the order of the countries along the y-axis by setting the country
variable to an ordered categorical.
[7]:
# Create a verb to order the countries. You can apply this verb to any
# dataframe with a country column
categories = list(segment_data['country'])[::-1]
ordered_country_dtype = pdtypes.CategoricalDtype(categories=categories, ordered=True)
order_country = define(country='country.astype(ordered_country_dtype)')
segment_data >>= order_country
point_data >>= order_country
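Why the categories are reversed: assuming plotnine follows the ggplot2 convention, a discrete position scale places the first category at position 1, the bottom of the y-axis. segment_data is sorted by descending gap, so reversing that order puts the country with the largest gap at the top. A small pandas check of the resulting ordering, with three country names used only for illustration:

```python
import pandas as pd

# Countries sorted by descending gap, as in segment_data
by_gap = ['France', 'Germany', 'Japan']

# Reverse so the largest gap becomes the last (topmost) category
dtype = pd.CategoricalDtype(categories=by_gap[::-1], ordered=True)
countries = pd.Series(['Japan', 'France', 'Germany']).astype(dtype)

# Category codes are 0-based; the plotted y positions are assumed to be these plus one
positions = countries.cat.codes + 1
```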
First plot
[8]:
# The right column (youngest-oldest gap) location
xgap = 112
(ggplot()
 # Range strip
 + geom_segment(
     segment_data,
     aes(x='min', xend='max', y='country', yend='country'),
     size=6,
     color='#a7a9ac'
 )
 # Age group markers
 + geom_point(
     point_data,
     aes('sm_use_percent', 'country', color='age_group', fill='age_group'),
     size=5,
     stroke=0.7,
 )
 # Age group percentages
 + geom_text(
     point_data >> query('age_group=="50+"'),
     aes(x='sm_use_percent-2', y='country', label='sm_use_percent_str', color='age_group'),
     size=8,
     ha='right',
 )
 + geom_text(
     point_data >> query('age_group=="35-49"'),
     aes(x='sm_use_percent+2', y='country', label='sm_use_percent_str'),
     size=8,
     ha='left',
     va='center',
     color='white'
 )
 + geom_text(
     point_data >> query('age_group=="18-34"'),
     aes(x='sm_use_percent+2', y='country', label='sm_use_percent_str', color='age_group'),
     size=8,
     ha='left',
 )
 # gap difference
 + geom_text(
     segment_data,
     aes(x=xgap, y='country', label='gap_str'),
     size=9,
     fontweight='bold',
     format_string='+{}'
 )
)

Tweak it
[9]:
# The right column (youngest-oldest gap) location
xgap = 112
# Gallery Plot
(ggplot()
 # Background Strips # new
 + geom_segment(
     segment_data,
     aes(y='country', yend='country'),
     x=0, xend=100,
     size=8.5,
     color='#edece3'
 )
 # vertical grid lines along the strips # new
 + annotate(
     'segment',
     x=list(range(10, 100, 10)) * n,
     xend=list(range(10, 100, 10)) * n,
     y=np.tile(np.arange(1, n+1), 9)-.25,
     yend=np.tile(np.arange(1, n+1), 9) + .25,
     color='#CCCCCC'
 )
 # Range strip
 + geom_segment(
     segment_data,
     aes(x='min', xend='max', y='country', yend='country'),
     size=6,
     color='#a7a9ac'
 )
 # Age group markers
 + geom_point(
     point_data,
     aes('sm_use_percent', 'country', color='age_group', fill='age_group'),
     size=5,
     stroke=0.7,
 )
 # Age group percentages
 + geom_text(
     point_data >> query('age_group=="50+"'),
     aes(x='sm_use_percent-2', y='country', label='sm_use_percent_str', color='age_group'),
     size=8,
     ha='right',
 )
 + geom_text(
     point_data >> query('age_group=="35-49"'),
     aes(x='sm_use_percent+2', y='country', label='sm_use_percent_str'),
     size=8,
     ha='left',
     va='center',
     color='white'
 )
 + geom_text(
     point_data >> query('age_group=="18-34"'),
     aes(x='sm_use_percent+2', y='country', label='sm_use_percent_str', color='age_group'),
     size=8,
     ha='left',
 )
 # country names on the left-hand side (instead of the y-axis) # new
 + geom_text(
     segment_data,
     aes(y='country', label='country'),
     x=-1,
     size=8,
     ha='right',
     fontweight='bold',
     color='#222222'
 )
 # gap difference
 + geom_vline(xintercept=xgap, color='#edece3', size=32) # new
 + geom_text(
     segment_data,
     aes(x=xgap, y='country', label='gap_str'),
     size=9,
     fontweight='bold',
     format_string='+{}'
 )
 # Annotations # new
 + annotate('text', x=31, y=n+1.1, label='50+', size=9, color='#ea9f2f', va='top')
 + annotate('text', x=56, y=n+1.1, label='35-49', size=9, color='#6d6e71', va='top')
 + annotate('text', x=85, y=n+1.1, label='18-34', size=9, color='#939c49', va='top')
 + annotate('text', x=xgap, y=n+.5, label='Youngest-\nOldest Gap', size=9, color='#444444', va='bottom', ha='center')
 + annotate('point', x=[31, 56, 85], y=n+.3, alpha=0.85, stroke=0)
 + annotate('segment', x=[31, 56, 85], xend=[31, 56, 85], y=n+.3, yend=n+.8, alpha=0.85)
 + annotate('hline', yintercept=[x+0.5 for x in range(2, n, 2)], alpha=.5, linetype='dotted', size=0.7)
 # Better spacing and color # new
 + scale_x_continuous(limits=(-18, xgap+2))
 + scale_y_discrete(expand=(0, 0, 0.1, 0))
 + scale_fill_manual(values=['#c3ca8c', '#d1d3d4', '#f2c480'])
 + scale_color_manual(values=['#939c49', '#6d6e71', '#ea9f2f'])
 + guides(color=None, fill=None)
 + theme_void()
 + theme(figure_size=(8, 7.5))
)

Instead of reading this plot as having a country variable on the y-axis and a percentage variable on the x-axis, we can view it as many independent variables stacked vertically, whose values share a similar scale.
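A note on the vertical grid lines in the tweaked plot: annotate('segment') receives x as list(range(10, 100, 10)) * n and y as np.tile(np.arange(1, n+1), 9), two arrays of length 9n that are paired element by element. That pairing visits every (grid line, country) combination only because 9 and n = 14 share no common factor; with n = 12, for example, pairs repeat and some grid lines go missing. The sketch compares that construction with an np.meshgrid version that produces every combination for any n; tiled_pairs, grid_segments and XS are hypothetical names.

```python
import numpy as np

XS = list(range(10, 100, 10))  # 10, 20, ..., 90: nine grid lines per strip


def tiled_pairs(n):
    """Distinct (x, y) pairs produced by the list-repeat / np.tile pairing."""
    x = XS * n
    y = np.tile(np.arange(1, n + 1), len(XS))
    return set(zip(x, y.tolist()))


def grid_segments(n, half_height=0.25):
    """x, xend, y, yend arrays covering every (grid line, strip) pair."""
    gx, gy = np.meshgrid(XS, np.arange(1, n + 1))
    x, y = gx.ravel(), gy.ravel()
    return x, x, y - half_height, y + half_height
```

The four arrays returned by grid_segments could be passed as x, xend, y and yend to annotate('segment', ...).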
Protip: Save a pdf file.
Change in Rank
Comparing a group of ranked items at two different times
Read the data.
Source: World Bank - Infant Mortality Rate (per 1,000 live births)
[10]:
data = pd.read_csv(
    'data/API_SP.DYN.IMRT.IN_DS2_en_csv_v2/API_SP.DYN.IMRT.IN_DS2_en_csv_v2.csv',
    skiprows=[0, 1, 2])
# Columns as valid python variables
year_columns = {'y{}'.format(c): c for c in data.columns if c[:2] in {'19', '20'}}
data = (
    data
    >> rename({'country': 'Country Name', 'code': 'Country Code'})
    >> rename(year_columns)
    >> select('Indicator Name', 'Indicator Code', 'Unnamed: 61', drop=True)
)
data >> head()
[10]:
| | country | code | y1960 | y1961 | y1962 | y1963 | y1964 | y1965 | y1966 | y1967 | ... | y2007 | y2008 | y2009 | y2010 | y2011 | y2012 | y2013 | y2014 | y2015 | y2016 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | Aruba | ABW | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
1 | Afghanistan | AFG | NaN | 240.5 | 236.3 | 232.3 | 228.5 | 224.6 | 220.7 | 217.0 | ... | 80.4 | 78.6 | 76.8 | 75.1 | 73.4 | 71.7 | 69.9 | 68.1 | 66.3 | NaN |
2 | Angola | AGO | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | 117.1 | 114.7 | 112.2 | 109.6 | 106.8 | 104.1 | 101.4 | 98.8 | 96.0 | NaN |
3 | Albania | ALB | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | 16.7 | 16.0 | 15.4 | 14.8 | 14.3 | 13.8 | 13.3 | 12.9 | 12.5 | NaN |
4 | Andorra | AND | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | 2.8 | 2.7 | 2.6 | 2.5 | 2.4 | 2.3 | 2.2 | 2.1 | 2.1 | NaN |
5 rows × 59 columns
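The year columns are renamed because names such as 1960 are not valid Python identifiers, and the string expressions passed to plydata verbs (e.g. 'y1960.rank(...)' in define) need them to be. Note that year_columns maps new name to old name, the direction plydata's rename uses here; pandas' DataFrame.rename(columns=...) takes the opposite direction, old to new. A hedged pandas equivalent, with a hypothetical helper name:

```python
import pandas as pd


def year_columns_as_identifiers(df):
    """Rename '1960', ..., '2016' to 'y1960', ..., 'y2016'."""
    # pandas maps old name -> new name
    mapping = {c: 'y' + c for c in df.columns if c[:2] in {'19', '20'}}
    return df.rename(columns=mapping)
```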
The data includes regional aggregates. To tell them apart from countries we need the metadata: every row in the data table has a corresponding row in the metadata table, and for regional aggregates the Region column of the metadata table is NaN.
[11]:
metadata = pd.read_csv(
    'data/API_SP.DYN.IMRT.IN_DS2_en_csv_v2/Metadata_Country_API_SP.DYN.IMRT.IN_DS2_en_csv_v2.csv'
)
metadata = (
    metadata
    >> rename({'code': 'Country Code',
               'region': 'Region',
               'income_group': 'IncomeGroup'
               })
    >> select('code', 'region', 'income_group')
)
cat_order = ['High income', 'Upper middle income', 'Lower middle income', 'Low income']
metadata['income_group'] = pd.Categorical(metadata['income_group'], categories=cat_order, ordered=True)
metadata >> head(10)
[11]:
| | code | region | income_group |
---|---|---|---|
0 | ABW | Latin America & Caribbean | High income |
1 | AFG | South Asia | Low income |
2 | AGO | Sub-Saharan Africa | Lower middle income |
3 | ALB | Europe & Central Asia | Upper middle income |
4 | AND | Europe & Central Asia | High income |
5 | ARB | NaN | NaN |
6 | ARE | Middle East & North Africa | High income |
7 | ARG | Latin America & Caribbean | Upper middle income |
8 | ARM | Europe & Central Asia | Lower middle income |
9 | ASM | East Asia & Pacific | Upper middle income |
ARB (row 5) is a regional code.
Drop the regional aggregates from the metadata
[12]:
# Drop the regional aggregates from the metadata
country_metadata = metadata >> call('.dropna', subset=['region'])
country_metadata >> head(10)
[12]:
| | code | region | income_group |
---|---|---|---|
0 | ABW | Latin America & Caribbean | High income |
1 | AFG | South Asia | Low income |
2 | AGO | Sub-Saharan Africa | Lower middle income |
3 | ALB | Europe & Central Asia | Upper middle income |
4 | AND | Europe & Central Asia | High income |
6 | ARE | Middle East & North Africa | High income |
7 | ARG | Latin America & Caribbean | Upper middle income |
8 | ARM | Europe & Central Asia | Lower middle income |
9 | ASM | East Asia & Pacific | Upper middle income |
10 | ATG | Latin America & Caribbean | High income |
Join the data with the country metadata; the inner join removes the regional aggregates, leaving a table with only country data.
[13]:
country_data = inner_join(data, country_metadata, on='code')
country_data >> head()
[13]:
| | country | code | y1960 | y1961 | y1962 | y1963 | y1964 | y1965 | y1966 | y1967 | ... | y2009 | y2010 | y2011 | y2012 | y2013 | y2014 | y2015 | y2016 | region | income_group |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | Aruba | ABW | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | Latin America & Caribbean | High income |
1 | Afghanistan | AFG | NaN | 240.5 | 236.3 | 232.3 | 228.5 | 224.6 | 220.7 | 217.0 | ... | 76.8 | 75.1 | 73.4 | 71.7 | 69.9 | 68.1 | 66.3 | NaN | South Asia | Low income |
2 | Angola | AGO | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | 112.2 | 109.6 | 106.8 | 104.1 | 101.4 | 98.8 | 96.0 | NaN | Sub-Saharan Africa | Lower middle income |
3 | Albania | ALB | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | 15.4 | 14.8 | 14.3 | 13.8 | 13.3 | 12.9 | 12.5 | NaN | Europe & Central Asia | Upper middle income |
4 | Andorra | AND | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | 2.6 | 2.5 | 2.4 | 2.3 | 2.2 | 2.1 | 2.1 | NaN | Europe & Central Asia | High income |
5 rows × 61 columns
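The two steps above (dropping metadata rows without a region, then an inner join on code) can be sketched in plain pandas. Because the join is inner, every data row whose code has no remaining metadata row, i.e. every regional aggregate, disappears. country_rows_only is a hypothetical name.

```python
import pandas as pd


def country_rows_only(data, metadata):
    """Keep only rows of `data` whose code belongs to an actual country."""
    # Regional aggregates have no region in the metadata
    country_metadata = metadata.dropna(subset=['region'])
    # An inner join drops data rows without a matching country code
    return data.merge(country_metadata, on='code', how='inner')
```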
We are interested in the changes in rank between 1960 and 2015. To plot a reasonably sized graph, we randomly sample 35 of the countries that have data for both years.
[14]:
sampled_data = (
    country_data
    >> call('.dropna', subset=['y1960', 'y2015'])
    >> sample_n(35, random_state=123)
    >> define(
        y1960_rank='y1960.rank(method="min").astype(int)',
        y2015_rank='y2015.rank(method="min").astype(int)'
    )
)
sampled_data >> head()
[14]:
| | country | code | y1960 | y1961 | y1962 | y1963 | y1964 | y1965 | y1966 | y1967 | ... | y2011 | y2012 | y2013 | y2014 | y2015 | y2016 | region | income_group | y1960_rank | y2015_rank |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
25 | Bolivia | BOL | 173.4 | 170.5 | 167.7 | 165.0 | 162.2 | 159.4 | 156.5 | 153.6 | ... | 35.3 | 34.0 | 32.8 | 31.7 | 30.6 | NaN | Latin America & Caribbean | Lower middle income | 33 | 26 |
182 | Sweden | SWE | 16.3 | 16.0 | 15.6 | 15.0 | 14.4 | 13.7 | 13.0 | 12.6 | ... | 2.4 | 2.4 | 2.4 | 2.4 | 2.4 | NaN | Europe & Central Asia | High income | 1 | 1 |
106 | Kuwait | KWT | 101.6 | 95.2 | 89.0 | 83.3 | 77.8 | 72.6 | 68.0 | 63.7 | ... | 8.9 | 8.5 | 8.1 | 7.7 | 7.3 | NaN | Middle East & North Africa | High income | 15 | 10 |
63 | Fiji | FJI | 54.0 | 52.1 | 50.3 | 48.8 | 47.5 | 46.3 | 45.3 | 44.5 | ... | 20.1 | 20.0 | 19.7 | 19.4 | 19.1 | NaN | East Asia & Pacific | Upper middle income | 8 | 21 |
160 | Paraguay | PRY | 62.6 | 62.0 | 61.4 | 60.9 | 60.5 | 60.0 | 59.6 | 59.3 | ... | 19.8 | 19.2 | 18.6 | 18.1 | 17.5 | NaN | Latin America & Caribbean | Upper middle income | 11 | 17 |
5 rows × 63 columns
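How the ranks behave: Series.rank(method='min') gives rank 1 to the lowest infant mortality rate and assigns tied values the smallest rank they share, so the next distinct value skips ahead. The rates below are made-up numbers for four hypothetical countries.

```python
import pandas as pd

# Hypothetical infant mortality rates (per 1,000 live births)
rates = pd.Series([2.4, 7.3, 2.4, 19.1], index=['A', 'B', 'C', 'D'])

# Ties (A and C) share rank 1; B gets rank 3, not 2
ranks = rates.rank(method='min').astype(int)
```

Because rank 1 is the lowest rate, the gallery plot uses scale_y_reverse() to put it at the top.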
First graph
[15]:
(ggplot(sampled_data)
 + geom_text(aes(1, 'y1960_rank', label='country'), ha='right', size=9)
 + geom_text(aes(2, 'y2015_rank', label='country'), ha='left', size=9)
 + geom_point(aes(1, 'y1960_rank', color='income_group'), size=2.5)
 + geom_point(aes(2, 'y2015_rank', color='income_group'), size=2.5)
 + geom_segment(aes(x=1, y='y1960_rank', xend=2, yend='y2015_rank', color='income_group'))
)

It has the form we want, but we need to tweak it.
[16]:
# Text colors
black1 = '#252525'
black2 = '#222222'
# Gallery Plot
(ggplot(sampled_data)
 # Slight modifications for the original lines:
 # 1. Nudge the text to either side of the points
 # 2. Alter the color and alpha values
 + geom_text(aes(1, 'y1960_rank', label='country'), nudge_x=-0.05, ha='right', size=9, color=black1)
 + geom_text(aes(2, 'y2015_rank', label='country'), nudge_x=0.05, ha='left', size=9, color=black1)
 + geom_point(aes(1, 'y1960_rank', color='income_group'), size=2.5, alpha=.7)
 + geom_point(aes(2, 'y2015_rank', color='income_group'), size=2.5, alpha=.7)
 + geom_segment(aes(x=1, y='y1960_rank', xend=2, yend='y2015_rank', color='income_group'), alpha=.7)
 # Text Annotations
 + annotate('text', x=1, y=0, label='Rank in 1960', fontweight='bold', ha='right', size=10, color=black2)
 + annotate('text', x=2, y=0, label='Rank in 2015', fontweight='bold', ha='left', size=10, color=black2)
 + annotate('text', x=1.5, y=0, label='Lines show change in rank', size=9, color=black1)
 + annotate('label', x=1.5, y=3, label='Lower infant\ndeath rates', size=9, color=black1,
            label_size=0, fontstyle='italic')
 + annotate('label', x=1.5, y=33, label='Higher infant\ndeath rates', size=9, color=black1,
            label_size=0, fontstyle='italic')
 # Prevent country names from being chopped off
 + lims(x=(0.35, 2.65))
 + labs(color='Income Group')
 # Countries with lower rates on top
 + scale_y_reverse()
 # Change colors
 + scale_color_brewer(type='qual', palette=2)
 # Removes all decorations
 + theme_void()
 # Changing the figure size prevents the country names from squishing up
 + theme(figure_size=(8, 11))
)
