Multi-cell geolift analysis#

In other examples, we’ve seen how we can use Synthetic Control methods to estimate the causal impact of a treatment in one geographic area (geo). In this example, we’ll extend the analysis to multiple geographic areas (geos).

This is a particularly common use case in marketing, where a company may want to understand the impact of a campaign that ran in multiple regions. But the methods shown here are general and not restricted to marketing - another concrete use case is public health, where an intervention may be rolled out in multiple regions.

This notebook focusses on the situation where the treatment has already taken place and we now want to understand the causal effects of the treatments that were executed. Much work likely preceded this analysis - for example, deciding which geos to run the treatment in and what the treatment should be - but these pre-treatment questions are not the focus of this notebook.

We can imagine two scenarios (there may be more), and show how we can tailor our analysis to each:

  1. The treatments were similar in kind and/or magnitude in each region. An example of this may be where a company ran the same marketing campaign in multiple test geos. In cases like this, we can imagine that the causal impact of the treatment may be similar in each region. So we will show an example of how to analyse geo lift data like this. We can think of this as a fully pooled analysis approach.

  2. The treatments were of different kinds and/or magnitudes in each region. An example of this may be where different marketing campaigns were run in different regions, perhaps with different budgets in each. In cases like this, we can imagine that the causal impact of the treatment may differ from region to region, so we will show an example of how to analyse geo lift data like this. We can think of this as an unpooled analysis approach.

Let’s start with some notebook setup:

import arviz as az
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import pandas as pd
import xarray as xr

import causalpy as cp
seed = 42

Load the dataset#

Developer notes

The synthetic dataset can be generated with the following code:

from causalpy.data.simulate_data import generate_multicell_geolift_data


df = generate_multicell_geolift_data()
df.to_csv("../../../causalpy/data/geolift_multi_cell.csv", index=True)

df = (
    cp.load_data("geolift_multi_cell")
    .assign(time=lambda x: pd.to_datetime(x["time"]))
    .set_index("time")
)

treatment_time = pd.to_datetime("2022-01-01")

# Define the treatment and control geos (the column names)
untreated = [
    "u1",
    "u2",
    "u3",
    "u4",
    "u5",
    "u6",
    "u7",
    "u8",
    "u9",
    "u10",
    "u11",
    "u12",
]

treated = ["t1", "t2", "t3", "t4"]

df.head()
u1 u2 u3 u4 u5 u6 u7 u8 u9 u10 u11 u12 t1 t2 t3 t4
time
2019-01-06 5.06 2.97 2.96 2.37 1.00 2.87 1.90 1.03 4.16 2.06 3.85 2.80 3.02 2.65 3.01 2.36
2019-01-13 5.14 3.06 2.89 2.40 0.92 3.16 1.85 0.83 4.12 1.93 3.83 2.89 2.91 2.44 3.15 2.14
2019-01-20 5.09 3.20 2.84 2.43 0.97 3.18 1.80 1.15 4.08 2.14 3.82 2.92 3.00 2.50 3.09 2.32
2019-01-27 5.21 3.18 2.90 2.14 0.75 3.14 1.97 1.09 4.10 2.11 3.87 2.81 3.02 2.50 3.12 2.20
2019-02-03 4.86 3.14 2.81 2.31 0.61 3.36 2.00 1.13 4.21 2.03 3.87 2.97 2.98 2.41 3.07 2.25

Always visualise the data before starting the analysis. Our rather uncreative naming scheme uses u to represent untreated geos, and t to represent treated geos. The number after the u or t represents the geo number.

ax = df.plot(colormap="tab20")
ax.axvline(treatment_time, color="k", linestyle="--")
ax.set(title="Observed data from all geos", ylabel="Sales volume (thousands)")
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5));

We can see that each geo has some seasonality component as well as some noise. The treatment (vertical dashed line) is the same in each geo. The question is: to what extent did we get uplift in our KPI in each treated geo?

Pooled analysis approach#

The first analysis approach is to aggregate the treated geos and analyze them as a group. In the code cell below we aggregate with the median, but the mean would also work. While these are likely the most commonly used aggregation functions, the user is free to use any other function appropriate for a given dataset.

df["treated_agg"] = df[treated].median(axis=1)
df.head()
u1 u2 u3 u4 u5 u6 u7 u8 u9 u10 u11 u12 t1 t2 t3 t4 treated_agg
time
2019-01-06 5.06 2.97 2.96 2.37 1.00 2.87 1.90 1.03 4.16 2.06 3.85 2.80 3.02 2.65 3.01 2.36 2.83
2019-01-13 5.14 3.06 2.89 2.40 0.92 3.16 1.85 0.83 4.12 1.93 3.83 2.89 2.91 2.44 3.15 2.14 2.68
2019-01-20 5.09 3.20 2.84 2.43 0.97 3.18 1.80 1.15 4.08 2.14 3.82 2.92 3.00 2.50 3.09 2.32 2.75
2019-01-27 5.21 3.18 2.90 2.14 0.75 3.14 1.97 1.09 4.10 2.11 3.87 2.81 3.02 2.50 3.12 2.20 2.76
2019-02-03 4.86 3.14 2.81 2.31 0.61 3.36 2.00 1.13 4.21 2.03 3.87 2.97 2.98 2.41 3.07 2.25 2.69
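
As noted above, the median is not the only option. Here is a minimal sketch of two alternatives, written to new columns so that the median-based treated_agg column above remains the one used for the rest of this notebook (the weights in the second example are purely hypothetical):

# An alternative aggregation: the mean across treated geos
df["treated_agg_mean"] = df[treated].mean(axis=1)

# Or a weighted aggregation, with purely hypothetical weights
df["treated_agg_weighted"] = df[treated].mul([0.4, 0.3, 0.2, 0.1]).sum(axis=1)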

Let’s visualise this aggregated geo and compare it to the individual treated geos.

ax = df[treated].plot(colormap="tab20")
df["treated_agg"].plot(color="k", lw=4, ax=ax, label="Aggregate geo")
ax.axvline(treatment_time, color="k", linestyle="--")
ax.set(title="Treated geos and the aggregate", ylabel="Sales volume (thousands)")
plt.legend(loc="center left", bbox_to_anchor=(1, 0.5));

Now we just proceed as we would with a regular single-geo analysis.

So first we’ll define the model formula - namely that we model the treated_agg geo as a function of the untreated geos, with the 0 specifying that we are not fitting an intercept.

formula = f"treated_agg ~ 0 + {' + '.join(untreated)}"
print(formula)
treated_agg ~ 0 + u1 + u2 + u3 + u4 + u5 + u6 + u7 + u8 + u9 + u10 + u11 + u12

Then we’ll fit the model and print the summary.

pooled = cp.SyntheticControl(
    df,
    treatment_time,
    formula=formula,
    model=cp.pymc_models.WeightedSumFitter(
        sample_kwargs={"target_accept": 0.95, "random_seed": seed}
    ),
)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [beta, sigma]


Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 20 seconds.
Sampling: [beta, sigma, y_hat]
Sampling: [y_hat]
Sampling: [y_hat]
Sampling: [y_hat]
Sampling: [y_hat]
pooled.summary()
==================================Pre-Post Fit==================================
Formula: treated_agg ~ 0 + u1 + u2 + u3 + u4 + u5 + u6 + u7 + u8 + u9 + u10 + u11 + u12
Model coefficients:
    u1     0.13, 94% HDI [0.071, 0.18]
    u2     0.099, 94% HDI [0.057, 0.14]
    u3     0.093, 94% HDI [0.015, 0.17]
    u4     0.14, 94% HDI [0.097, 0.19]
    u5     0.069, 94% HDI [0.017, 0.13]
    u6     0.049, 94% HDI [0.0039, 0.11]
    u7     0.12, 94% HDI [0.031, 0.21]
    u8     0.11, 94% HDI [0.049, 0.17]
    u9     0.038, 94% HDI [0.0019, 0.099]
    u10    0.039, 94% HDI [0.0022, 0.089]
    u11    0.071, 94% HDI [0.0074, 0.15]
    u12    0.039, 94% HDI [0.0021, 0.099]
    sigma  0.067, 94% HDI [0.06, 0.075]
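
If we prefer the weightings as a table rather than the printed summary above, we can summarise the posterior directly with ArviZ. This is a minimal sketch, assuming the weight and observation-noise parameters are named beta and sigma, as in the sampler log above:

# Tabulate posterior means and 94% HDIs for the model weights and noise
az.summary(pooled.idata, var_names=["beta", "sigma"], hdi_prob=0.94)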

We can see the model weightings visually like this:

az.plot_forest(pooled.idata, var_names=["~mu"], figsize=(8, 4), combined=True);

And of course we can see the causal impact plot using the plot method.

fig, ax = pooled.plot(plot_predictors=False)

# formatting
ax[2].tick_params(axis="x", labelrotation=-90)
ax[2].xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
ax[2].xaxis.set_major_locator(mdates.YearLocator())
for i in [0, 1, 2]:
    ax[i].set(ylabel="Sales (thousands)")
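
The plots above give a visual read on the causal impact. If we also want a numerical estimate of the total cumulative impact over the post-treatment period, we can query the posterior directly. This is a minimal sketch, assuming the result object exposes post_impact_cumulative with an obs_ind dimension (as used in the comparison section below); the 94% interval here is an equal-tailed quantile interval rather than an HDI:

# Posterior for the cumulative impact at the final post-treatment time point
pooled_total = pooled.post_impact_cumulative.isel(obs_ind=-1)
print("Posterior mean:", float(pooled_total.mean()))
print("94% interval:", pooled_total.quantile([0.03, 0.97]).values.round(2))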

Unpooled analysis approach#

The second analysis approach is to analyze each treated geo individually.

unpooled_results = []

for i, target_geo in enumerate(treated):
    print(f"Analyzing test geo: {target_geo} ({i+1} of {len(treated)})")
    formula = f"{target_geo} ~ 0 + {' + '.join(untreated)}"
    print(formula)

    result = cp.SyntheticControl(
        df,
        treatment_time,
        formula=formula,
        model=cp.pymc_models.WeightedSumFitter(
            sample_kwargs={"target_accept": 0.95, "random_seed": seed}
        ),
    )
    unpooled_results.append(result)
Analyzing test geo: t1 (1 of 4)
t1 ~ 0 + u1 + u2 + u3 + u4 + u5 + u6 + u7 + u8 + u9 + u10 + u11 + u12
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [beta, sigma]


Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 19 seconds.
Sampling: [beta, sigma, y_hat]
Sampling: [y_hat]
Sampling: [y_hat]
Sampling: [y_hat]
Sampling: [y_hat]
Analyzing test geo: t2 (2 of 4)
t2 ~ 0 + u1 + u2 + u3 + u4 + u5 + u6 + u7 + u8 + u9 + u10 + u11 + u12
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [beta, sigma]


Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 17 seconds.
Sampling: [beta, sigma, y_hat]
Sampling: [y_hat]
Sampling: [y_hat]
Sampling: [y_hat]
Sampling: [y_hat]
Analyzing test geo: t3 (3 of 4)
t3 ~ 0 + u1 + u2 + u3 + u4 + u5 + u6 + u7 + u8 + u9 + u10 + u11 + u12
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [beta, sigma]


Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 19 seconds.
Sampling: [beta, sigma, y_hat]
Sampling: [y_hat]
Sampling: [y_hat]
Sampling: [y_hat]
Sampling: [y_hat]
Analyzing test geo: t4 (4 of 4)
t4 ~ 0 + u1 + u2 + u3 + u4 + u5 + u6 + u7 + u8 + u9 + u10 + u11 + u12
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [beta, sigma]


Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 22 seconds.
Sampling: [beta, sigma, y_hat]
Sampling: [y_hat]
Sampling: [y_hat]
Sampling: [y_hat]
Sampling: [y_hat]

Now let’s plot the weightings of the untreated geos for each treated geo. Note that sigma is the model’s estimate of the standard deviation of the observation noise.

If we wanted to produce separate plots for each target geo, we could do so like this:

fig, axs = plt.subplots(len(treated), 1, figsize=(8, 4 * len(treated)), sharex=True)

for target_geo, ax, result in zip(treated, axs, unpooled_results):
    az.plot_forest(result.idata, var_names=["~mu"], combined=True, ax=ax)
    ax.set(title=f"target geo: {target_geo}")

But instead we will use a nice feature of ArviZ to plot all the weightings on the same plot, but with different colors for each treated geo.

az.plot_forest(
    [result.idata for result in unpooled_results],
    model_names=treated,
    var_names=["~mu"],
    combined=True,
    figsize=(8, 12),
);

And let’s also produce the standard causal impact plot for each treated geo.

for treated_geo, result in zip(treated, unpooled_results):
    fig, ax = result.plot(plot_predictors=False)

    # formatting
    ax[2].tick_params(axis="x", labelrotation=-90)
    ax[2].xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
    ax[2].xaxis.set_major_locator(mdates.YearLocator())
    for i in [0, 1, 2]:
        ax[i].set(ylabel="Sales (thousands)")
    plt.suptitle(f"Causal impact for {treated_geo}")

We’ve seen in this section that it is not just possible, but straightforward to analyse geo lift data when there are multiple treated geos. This approach simply iterates through the treated geos and analyses each one individually.

This does of course mean that with a large number of treated geos we will have a large number of results and plots to inspect, but that is often exactly what we want when the goal is to understand each geo separately.
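
If the number of treated geos were much larger, one option would be to condense the per-geo results into a single summary table rather than inspecting every plot. This is a minimal sketch, again assuming each result exposes post_impact_cumulative, and using equal-tailed quantile intervals for simplicity:

# Summarise the total cumulative impact for each treated geo in one table
rows = []
for geo, result in zip(treated, unpooled_results):
    total = result.post_impact_cumulative.isel(obs_ind=-1)
    q = total.quantile([0.03, 0.5, 0.97])
    rows.append(
        {
            "geo": geo,
            "median": float(q.sel(quantile=0.5)),
            "3%": float(q.sel(quantile=0.03)),
            "97%": float(q.sel(quantile=0.97)),
        }
    )
pd.DataFrame(rows).set_index("geo")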

Comparing the two approaches#

Let’s compare the two approaches by plotting the posterior distribution of the total cumulative causal impact over the whole post-treatment period.

The top plot shows the estimate for the pooled model, and the bottom shows the estimates for each of the models applied to the 4 treated geos.

def get_final_cumulative_impact(result):
    return result.post_impact_cumulative.sel(
        {"obs_ind": result.post_impact_cumulative.obs_ind.max()}
    )


pooled_cumulative = get_final_cumulative_impact(pooled)

unpooled_cumulative = xr.concat(
    [get_final_cumulative_impact(result) for result in unpooled_results],
    dim="treated_region",
)

axes = az.plot_forest(
    [pooled_cumulative, unpooled_cumulative],
    model_names=["Pooled", "Unpooled"],
    combined=True,
)
axes[0].set(title="Estimated total cumulative impact", xlabel="Sales (thousands)");
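
Beyond the visual comparison, we can read simple numerical summaries off these posteriors, for example the posterior probability that the total cumulative impact is positive. A minimal sketch, reusing the variables defined in the cell above:

# Posterior probability of a positive total cumulative impact
print("Pooled:", float((pooled_cumulative > 0).mean()))
for i, geo in enumerate(treated):
    prob = float((unpooled_cumulative.isel(treated_region=i) > 0).mean())
    print(f"{geo}: {prob:.3f}")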

For this particular (simulated) dataset, the pooled and unpooled approaches give similar results. This is expected because the treatment was the same in each geo, and the causal impact of the treatment was similar in each geo. We’d likely see more variation in the estimates of the unpooled models if the real causal impacts were more heterogeneous across the geos.

Summary#

We’ve shown two methods that we can use to analyse geo lift data with multiple treated geos. To do this, we used a simulated dataset with seasonality and observation noise.

The first method is to aggregate the treated geos and analyze them as a single aggregated geo. This is useful when we expect the effects of the intervention to be similar in each treated region - for example if we deployed the same kind and magnitude of intervention in each treated region. This method is also useful when we have a large number of treated geos and we want to reduce the number of models we need to fit and create a single ‘story’ for the causal effects of the treatment across all treated geos.

The second method is to analyze each treated geo individually. This is useful when we want to understand the impact of each geo separately. This may make most sense if the treatments were different in kind or magnitude. That is, when we do not expect the effects of the intervention to be similar in each treated region.

But what about a more complex scenario? We could imagine a situation where one intervention (e.g. a store refurbishment programme) was deployed in some geos, and a different intervention (e.g. a marketing campaign) was deployed in others. In this case, we could use a combination of the two methods we’ve shown here: aggregate the treated geos that received the same intervention, and analyze separately those that received a different intervention.
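
A minimal sketch of how that hybrid approach could look in code, using purely hypothetical groupings of the treated geos (refurb_geos for geos that received one intervention, campaign_geo for a geo that received a different one):

# Hypothetical grouping of treated geos by intervention type (illustrative only)
refurb_geos = ["t1", "t2", "t3"]  # received the same intervention -> pooled analysis
campaign_geo = "t4"  # received a different intervention -> analysed on its own

# Pooled model for the geos that shared an intervention
df["refurb_agg"] = df[refurb_geos].median(axis=1)
refurb_result = cp.SyntheticControl(
    df,
    treatment_time,
    formula=f"refurb_agg ~ 0 + {' + '.join(untreated)}",
    model=cp.pymc_models.WeightedSumFitter(
        sample_kwargs={"target_accept": 0.95, "random_seed": seed}
    ),
)

# Separate model for the geo that received a different intervention
campaign_result = cp.SyntheticControl(
    df,
    treatment_time,
    formula=f"{campaign_geo} ~ 0 + {' + '.join(untreated)}",
    model=cp.pymc_models.WeightedSumFitter(
        sample_kwargs={"target_accept": 0.95, "random_seed": seed}
    ),
)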