The effects of Brexit#

The aim of this notebook is to estimate the causal impact of Brexit on the UK's GDP, using the synthetic control approach. As such, it is similar to the policy brief “What can we know about the cost of Brexit so far?” [Springford, 2022] from the Centre for European Reform, although that analysis did not use Bayesian estimation methods.

However, I did not use the GDP data from that report, as it had been scaled in a way that made it hard to relate to absolute GDP figures. Instead, GDP data was obtained courtesy of Prof. Dooruj Rambaccussing. The raw data are in units of billions of USD.

Warning

This is an experimental and in-progress notebook! While the results are reasonable, there is still room for improvement on the inference side. There are high correlations between countries, and the prior for the Dirichlet distribution over country weightings could do with some attention. That said, the results here represent a reasonable first pass at this dataset.

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import causalpy as cp
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'
seed = 42

Load data#

df = (
    cp.load_data("brexit")
    .assign(Time=lambda x: pd.to_datetime(x["Time"]))
    .set_index("Time")
    .loc[lambda x: x.index >= "2009-01-01"]
    # manual exclusion of some countries
    .drop(["Japan", "Italy", "US", "Spain", "Portugal"], axis=1)
)

# date the result of the Brexit referendum was announced
treatment_time = pd.to_datetime("2016 June 24")
# get useful country lists
target_country = "UK"
all_countries = df.columns
other_countries = all_countries.difference({target_country})
all_countries = list(all_countries)
other_countries = list(other_countries)

Data visualization#

az.style.use("arviz-white")
# Plot the time series normalised so that intervention point (Q3 2016) is equal to 100
gdp_at_intervention = df.loc[pd.to_datetime("2016 July 01"), :]
df_normalised = (df / gdp_at_intervention) * 100.0

# plot
fig, ax = plt.subplots()
for col in other_countries:
    ax.plot(df_normalised.index, df_normalised[col], color="grey", alpha=0.2)

ax.plot(df_normalised.index, df_normalised[target_country], color="red", lw=3)

# formatting
ax.set(title="Normalised GDP")
ax.axvline(x=treatment_time, color="r", ls=":");
[Figure: normalised GDP time series, control countries in grey, the UK in red, with the treatment time marked by a dotted vertical line]
# Examine how correlated the pre-intervention time series are

pre_intervention_data = df.loc[df.index < treatment_time, :]

corr = pre_intervention_data.corr()

f, ax = plt.subplots(figsize=(8, 6))
ax = sns.heatmap(
    corr,
    mask=np.triu(np.ones_like(corr, dtype=bool)),
    cmap=sns.diverging_palette(230, 20, as_cmap=True),
    vmin=-0.2,
    vmax=1.0,
    center=0,
    square=True,
    linewidths=0.5,
    cbar_kws={"shrink": 0.8},
)
ax.set(title="Correlations for pre-intervention GDP data");
[Figure: heatmap of correlations for pre-intervention GDP data]
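Given the warning above about high correlations between countries, it can help to list the most correlated pairs explicitly. This uses only the corr matrix computed above and standard pandas/NumPy operations:

# list the most highly correlated country pairs in the pre-intervention period
corr_pairs = (
    corr.where(np.triu(np.ones_like(corr, dtype=bool), k=1))  # upper triangle only
    .stack()
    .sort_values(ascending=False)
)
corr_pairs.head(10)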

Run the analysis#

Note: the analysis is (and should be) run on the raw GDP data. We do not use the normalised data shown above, which was purely for ease of visualization.

Note

The random_seed keyword argument for the PyMC sampler is not necessary. We use it here so that the results are reproducible.
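For orientation, the model being fit is a weighted sum of the control countries' GDP, with the weights constrained to the unit simplex by a Dirichlet prior. Here is a minimal PyMC sketch of that idea; it is a simplification for illustration, not necessarily CausalPy's exact WeightedSumFitter implementation:

import pymc as pm

# pre-intervention data: control countries as predictors, UK as the target
X = df.loc[df.index < treatment_time, other_countries].to_numpy()
y = df.loc[df.index < treatment_time, target_country].to_numpy()

with pm.Model():
    # country weights on the unit simplex: non-negative and summing to 1
    beta = pm.Dirichlet("beta", a=np.ones(X.shape[1]))
    sigma = pm.HalfNormal("sigma", 1)
    # the synthetic UK is a weighted sum of the control countries
    mu = pm.Deterministic("mu", pm.math.dot(X, beta))
    pm.Normal("y_hat", mu=mu, sigma=sigma, observed=y)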

sample_kwargs = {"tune": 4000, "target_accept": 0.99, "random_seed": seed}

result = cp.SyntheticControl(
    df,
    treatment_time,
    control_units=other_countries,
    treated_units=[target_country],
    model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs),
)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [beta, sigma]

Sampling 4 chains for 4_000 tune and 1_000 draw iterations (16_000 + 4_000 draws total) took 212 seconds.
There were 2 divergences after tuning. Increase `target_accept` or reparameterize.
Chain 0 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 1 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 2 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 3 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
Sampling: [beta, sigma, y_hat]
Sampling: [y_hat]
Sampling: [y_hat]
Sampling: [y_hat]
Sampling: [y_hat]

We currently get a few divergences, but these are mostly dealt with by increasing the tune and target_accept sampling parameters. Nevertheless, sampling for this dataset/model combination feels a little brittle.
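If you want to check the diagnostics programmatically rather than by reading the sampler log, the divergence indicator is stored in the trace under the standard PyMC/ArviZ sample_stats group:

# count post-tuning divergences across all chains
n_divergences = int(result.idata.sample_stats["diverging"].sum())
print(f"Number of divergences: {n_divergences}")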

Check that the MCMC chains have mixed well via the r_hat statistic.

az.summary(result.idata, var_names=["~mu"])
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
beta[Australia] 0.118 0.073 0.001 0.244 0.002 0.001 786.0 950.0 1.00
beta[Austria] 0.043 0.040 0.000 0.115 0.001 0.001 585.0 573.0 1.01
beta[Belgium] 0.050 0.045 0.000 0.133 0.001 0.001 817.0 940.0 1.00
beta[Canada] 0.039 0.022 0.000 0.075 0.001 0.000 417.0 439.0 1.01
beta[Denmark] 0.088 0.064 0.000 0.202 0.002 0.001 586.0 633.0 1.00
beta[Finland] 0.041 0.037 0.000 0.109 0.001 0.001 570.0 661.0 1.01
beta[France] 0.030 0.028 0.000 0.081 0.001 0.001 695.0 575.0 1.00
beta[Germany] 0.025 0.023 0.000 0.067 0.001 0.001 595.0 567.0 1.00
beta[Iceland] 0.154 0.039 0.081 0.224 0.001 0.001 915.0 894.0 1.00
beta[Luxemburg] 0.054 0.047 0.000 0.140 0.001 0.001 768.0 520.0 1.00
beta[Netherlands] 0.049 0.045 0.000 0.132 0.001 0.001 866.0 923.0 1.00
beta[New_Zealand] 0.063 0.054 0.000 0.163 0.002 0.001 581.0 512.0 1.00
beta[Norway] 0.081 0.045 0.000 0.155 0.002 0.001 331.0 250.0 1.01
beta[Sweden] 0.099 0.031 0.042 0.158 0.001 0.001 618.0 686.0 1.01
beta[Switzerland] 0.064 0.055 0.000 0.167 0.001 0.001 3955.0 2386.0 1.00
sigma 0.031 0.005 0.023 0.039 0.000 0.000 1063.0 1600.0 1.01
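The same table can be used to flag problematic parameters programmatically, for example any with r_hat above the conventional 1.01 threshold:

# flag parameters whose r_hat exceeds the conventional 1.01 threshold
summary = az.summary(result.idata, var_names=["~mu"])
summary[summary["r_hat"] > 1.01]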

You can inspect the traces in more detail with:

az.plot_trace(result.idata, var_names="~mu", compact=False);
az.style.use("arviz-darkgrid")

fig, ax = result.plot(plot_predictors=False)

for i in [0, 1, 2]:
    ax[i].set(ylabel="Billion USD")
[Figure: synthetic control results for the UK — observed and counterfactual GDP, causal impact, and cumulative causal impact, in billions of USD]
result.summary()
================================SyntheticControl================================
Control units: ['Australia', 'Austria', 'Belgium', 'Canada', 'Denmark', 'Finland', 'France', 'Germany', 'Iceland', 'Luxemburg', 'Netherlands', 'New_Zealand', 'Norway', 'Sweden', 'Switzerland']
Treated unit: UK
Model coefficients:
    Australia    0.12, 94% HDI [0.0096, 0.27]
    Austria      0.043, 94% HDI [0.0015, 0.14]
    Belgium      0.05, 94% HDI [0.0021, 0.16]
    Canada       0.039, 94% HDI [0.0033, 0.083]
    Denmark      0.088, 94% HDI [0.0043, 0.23]
    Finland      0.041, 94% HDI [0.0013, 0.14]
    France       0.03, 94% HDI [0.00095, 0.097]
    Germany      0.025, 94% HDI [0.0008, 0.083]
    Iceland      0.15, 94% HDI [0.082, 0.23]
    Luxemburg    0.054, 94% HDI [0.0025, 0.17]
    Netherlands  0.049, 94% HDI [0.0024, 0.16]
    New_Zealand  0.063, 94% HDI [0.0021, 0.19]
    Norway       0.081, 94% HDI [0.0047, 0.17]
    Sweden       0.099, 94% HDI [0.039, 0.16]
    Switzerland  0.064, 94% HDI [0.0023, 0.2]
    sigma        0.031, 94% HDI [0.023, 0.041]
ax = az.plot_forest(result.idata, var_names="beta", figsize=(6, 5))
ax[0].set(title="Estimated weighting coefficients");
[Figure: forest plot of the estimated weighting coefficients]
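Finally, it can be useful to reduce the plots to a single number. Assuming the result object exposes the post-intervention impact draws as post_impact (an xarray DataArray with chain and draw dimensions; check the API of your CausalPy version), the average causal impact in billions of USD could be summarised like this:

# hedged sketch: posterior mean of the average post-intervention impact
# `post_impact` is assumed to hold draws of observed minus counterfactual GDP
avg_impact = result.post_impact.mean(["chain", "draw"])  # posterior mean per time point
print(f"Average impact: {float(avg_impact.mean()):.1f} billion USD")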

References#

[1] John Springford. What can we know about the cost of Brexit so far? Centre for European Reform, 2022. URL: https://www.cer.eu/publications/archive/policy-brief/2022/cost-brexit-so-far