Bayesian SIR

In this post I review how to build a compartmental model using the Stan probabilistic programming language. This is based largely on the case study Bayesian workflow for disease transmission modeling in Stan; here the model is expanded to include an additional compartment for exposed individuals and to utilise case incidence data rather than prevalence.

Michael DeWitt https://michaeldewittjr.com
09-05-2020

library(cmdstanr)
library(magrittr)
library(ggplot2)
library(data.table)
library(deSolve)

Compartmental models are commonly used in epidemiology to model epidemics. They are composed of differential equations and capture some “knowns” regarding disease transmission. Because these models seek to simulate the epidemic process directly, they are somewhat more resistant to some biases (e.g. missing data). Strong-ish assumptions must be made regarding disease transmission, and varying levels of detail can be included in order to bring the models closer to reality.

This post is largely an extension of Bayesian workflow for disease transmission modeling in Stan.

Data Generating Process

First we need to define our data generating process. Here we will start with a four-compartment model (susceptible, exposed, infectious, recovered) with no births or deaths. This will represent an immunizing infection with a latent phase.

library(DiagrammeR)

a_graph <-
  create_graph(directed = TRUE) %>%
  add_node(label = "S", node_aes = node_aes(fill = "orange")) %>%
  add_node(label = "E", node_aes = node_aes(fill = "orange")) %>%
  add_node(label = "I", node_aes = node_aes(fill = "orange")) %>%
  add_node(label = "R", node_aes = node_aes(fill = "orange")) %>%
  add_edge(from = 1, to = 2) %>%
  add_edge(from = 2, to = 3) %>%
  add_edge(from = 3, to = 4)

render_graph(a_graph, layout = "nicely")
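Written out, the flow in the diagram corresponds to the system below. It uses the same symbols as the seir_model function in the next section: $\beta$ is the transmission rate, $\delta$ the rate of leaving the latent phase, and $\gamma$ the recovery rate.

$$
\begin{aligned}
\frac{dS}{dt} &= -\frac{\beta S I}{N} \\
\frac{dE}{dt} &= \frac{\beta S I}{N} - \delta E \\
\frac{dI}{dt} &= \delta E - \gamma I \\
\frac{dR}{dt} &= \gamma I, \qquad N = S + E + I + R
\end{aligned}
$$

The four derivatives sum to zero, which is the "no births or deaths" assumption: the population size $N$ stays constant.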

Simulate an Epidemic with Knowns

Now we can build this hypothetical epidemic using the “deSolve” package in R. Ideally, we will be able to recover our model parameters using our Bayesian model. This gives us confidence when fitting the real data that we observe. This is a key step in the Bayesian workflow: we generate fake data, fit the fake data, and then examine the fit to ensure that we recover the true parameters before we fit our observed data.

seir_model = function (current_timepoint, state_values, parameters)
{
  # create state variables (local variables)
  S = state_values[1]         # susceptibles
  E = state_values[2]         # exposed
  I = state_values[3]         # infectious
  R = state_values[4]         # recovered
  N = state_values[1] + state_values[2] + state_values[3] + state_values[4]

  with (
    as.list (parameters),     # variable names within parameters can be used
    {
      # compute derivatives
      dS = (-beta * S * I)/N
      dE = (beta * S * I)/N - (delta * E)
      dI = (delta * E) - (gamma * I)
      dR = (gamma * I)

      # combine results
      results = c (dS, dE, dI, dR)
      list (results)
    }
  )
}

beta_value <- 1/3
gamma_value <- 1/10
delta_value <- 1/4

parameter_list <- c (beta = beta_value, gamma = gamma_value, delta = delta_value)
times <- 1:120
initial_values <- c(S= 1000-2, E = 2, I =0, R = 0)
output = lsoda (initial_values, times, seir_model, parameter_list)

matplot(output[,2:5], main = "Observed Epidemic", type = "l", adj = 0)

We can extract the daily incidence using the following equation. This represents what would typically be reported by authorities. To make it more realistic, it would be good to convolve the cases with a delay distribution to indicate the lag we observe in case reporting. A deconvolution step could then be written into Stan in order to account for this delay distribution.

cases <-  ceiling(output[,2] - shift(output[,2],-1))
cases <- cases[-length(cases)]
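The reporting-delay idea mentioned above can be sketched in base R. This is only an illustration under stated assumptions: the helper `convolve_delay`, the discretised gamma delay (shape 2, rate 0.5, truncated at 14 days), and the toy incidence vector are all hypothetical and are not part of the model fitted in this post.

```r
# Spread each day's infections over later reporting days.
# delay_pmf[d] is the probability of a report (d - 1) days after infection.
convolve_delay <- function(incidence, delay_pmf) {
  n <- length(incidence)
  reported <- numeric(n)
  for (t in seq_len(n)) {
    for (d in seq_along(delay_pmf)) {
      src <- t - (d - 1)  # day on which the reported infections occurred
      if (src >= 1) {
        reported[t] <- reported[t] + incidence[src] * delay_pmf[d]
      }
    }
  }
  reported
}

# Assumed delay distribution: gamma discretised to whole days 0..14
max_delay <- 14
delay_pmf <- diff(pgamma(0:(max_delay + 1), shape = 2, rate = 0.5))
delay_pmf <- delay_pmf / sum(delay_pmf)  # renormalise after truncation

# Toy incidence curve, padded with zeros so every report lands in the window
toy_incidence <- c(0, 1, 3, 8, 15, 20, 15, 8, 3, 1, rep(0, 15))
reported <- convolve_delay(toy_incidence, delay_pmf)

round(reported, 1)
```

The reported curve keeps the same total as the toy incidence but peaks later and is flatter, which is the lag a deconvolution step in Stan would need to undo.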

Now let’s calculate our basic reproduction number, $$R_0$$:

beta_value/gamma_value

[1] 3.333333
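For this model, with no births or deaths, the latent phase delays infections but removes no one before they become infectious. The ratio computed above is therefore the transmission rate multiplied by the mean infectious period $1/\gamma$:

$$
R_0 = \frac{\beta}{\gamma} = \frac{1/3}{1/10} \approx 3.33
$$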

Build Model in Stan

Now we can build the model in Stan as shown below. Ideally, I would write the ODE solver using the new syntax, but I’ll leave that to next time. We can see that the differential equations have been built into the “sir” function.
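As a reading aid for the Stan program that follows, the observation model it encodes can be written as below. The symbols $c_t$ (observed cases on day $t$) and $\mu_t$ (modelled incidence) are notation introduced for this sketch; the remaining names match the Stan code.

$$
\begin{aligned}
\mu_t &= S(t) - S(t+1) \\
c_t &\sim \text{NegBinomial2}(\mu_t, \phi), \qquad \phi = 1/\phi_{inv} \\
\phi_{inv} &\sim \text{Exponential}(2)
\end{aligned}
$$

Under Stan's neg_binomial_2 parameterisation the mean is $\mu_t$ and the variance is $\mu_t + \mu_t^2/\phi$, so larger values of $\phi$ mean less overdispersion relative to a Poisson model.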

mod <- cmdstan_model("sir.stan")

mod$print()

// Based on https://mc-stan.org/users/documentation/case-studies/boarding_school_case_study.html
functions {
  real[] sir(real t, real[] y, real[] theta,
             real[] x_r, int[] x_i) {

    real S = y[1];
    real E = y[2];
    real I = y[3];
    real R = y[4];
    real N = x_i[1];

    real beta = theta[1];
    real delta = theta[2];
    real gamma = theta[3];

    real dS_dt = -beta * I * S / N;
    real dE_dt = beta * I * S / N - delta * E;
    real dI_dt = delta * E - gamma * I;
    real dR_dt = gamma * I;

    return {dS_dt, dE_dt, dI_dt, dR_dt};
  }
}
data {
  int<lower=1> n_days;
  real y0[4];
  real t0;
  real ts[n_days];
  int N;
  int cases[n_days-1];
  int n_pred;
  real ts_pred[n_pred];
  real delta_mu;
}
transformed data {
  real x_r[0];
  int x_i[1] = { N };
}
parameters {
  real<lower=0> theta[3];
  real<lower=0> phi_inv;
}
transformed parameters{
  real y[n_days, 4];
  real<lower=0> phi = 1. / phi_inv;
  real<lower=0> incidence[n_days-1];
  {
    y = integrate_ode_rk45(sir, y0, t0, ts, theta, x_r, x_i);
  }
  //for (i in 2:(n_days-1)) {
  for (i in 1:(n_days-1))
    incidence[i] = y[i, 1] - y[i + 1, 1];
}
model {
  //priors
  theta[1] ~ normal(2, 1); //beta
  theta[2] ~ normal(delta_mu, .1); //delta
  theta[3] ~ normal(0.1, 0.7); //gamma
  phi_inv ~ exponential(2);

  cases ~ neg_binomial_2(incidence, phi);
}
generated quantities {
  real R0 = theta[1] / theta[3];
  real recovery_time = 1 / theta[3];
  real pred_cases[n_days-1];
  real pred_cases_out[n_pred-1];
  real pred_incidence[n_pred-1];

  // future prediction parameters
  real y_pred[n_pred, 4];
  real y_init_pred[4] = y[n_days, ]; // New initial conditions
  real t0_pred = ts[n_days]; // New time zero is the last observed time

  pred_cases = neg_binomial_2_rng(incidence, phi);

  y_pred = integrate_ode_rk45(sir, y_init_pred, t0_pred, ts_pred, theta, x_r, x_i);

  for (i in 1:(n_pred-1))
    pred_incidence[i] = y_pred[i, 1] - y_pred[i + 1, 1];

  pred_cases_out = neg_binomial_2_rng(pred_incidence, phi);
}

Build Dataset

Now we can prep our dataset for Stan.
# total count
N <- 1000;
# times
n_days <- length(cases) + 1
t <- seq(0, n_days, by = 1)
t0 = 0
t <- t[-1]

#initial conditions
i0 <- 1
e0 <- 0
s0 <- N - i0
r0 <- 0
y0 = c(S = s0, E = e0, I = i0, R = r0)

Run the Model

Then we can run it using CmdStanR.

# number of MCMC steps
niter <- 2000
n_pred <- 21
data_sir <- list(n_days = n_days, y0 = y0, t0 = t0, ts = t, N = N,
                 cases = cases, n_pred = n_pred, delta_mu = .2,
                 ts_pred = seq(n_days + 1, n_days + n_pred, by = 1))

fit <- mod$sample(data = data_sir,
                  chains = 2,
                  max_treedepth = 12,
                  parallel_chains = 2,
                  iter_sampling = niter/2,
                  iter_warmup = niter/2)

Running MCMC with 2 parallel chains...

Chain 2 Iteration:    1 / 2000 [  0%]  (Warmup)
Chain 1 Iteration:    1 / 2000 [  0%]  (Warmup)
Chain 2 Iteration:  100 / 2000 [  5%]  (Warmup)
Chain 1 Iteration:  100 / 2000 [  5%]  (Warmup)
Chain 2 Iteration:  200 / 2000 [ 10%]  (Warmup)
Chain 1 Iteration:  200 / 2000 [ 10%]  (Warmup)
Chain 2 Iteration:  300 / 2000 [ 15%]  (Warmup)
Chain 1 Iteration:  300 / 2000 [ 15%]  (Warmup)
Chain 2 Iteration:  400 / 2000 [ 20%]  (Warmup)
Chain 1 Iteration:  400 / 2000 [ 20%]  (Warmup)
Chain 2 Iteration:  500 / 2000 [ 25%]  (Warmup)
Chain 2 Iteration:  600 / 2000 [ 30%]  (Warmup)
Chain 1 Iteration:  500 / 2000 [ 25%]  (Warmup)
Chain 2 Iteration:  700 / 2000 [ 35%]  (Warmup)
Chain 1 Iteration:  600 / 2000 [ 30%]  (Warmup)
Chain 2 Iteration:  800 / 2000 [ 40%]  (Warmup)
Chain 1 Iteration:  700 / 2000 [ 35%]  (Warmup)
Chain 2 Iteration:  900 / 2000 [ 45%]  (Warmup)
Chain 1 Iteration:  800 / 2000 [ 40%]  (Warmup)
Chain 2 Iteration: 1000 / 2000 [ 50%]  (Warmup)
Chain 2 Iteration: 1001 / 2000 [ 50%]  (Sampling)
Chain 1 Iteration:  900 / 2000 [ 45%]  (Warmup)
Chain 1 Iteration: 1000 / 2000 [ 50%]  (Warmup)
Chain 1 Iteration: 1001 / 2000 [ 50%]  (Sampling)
Chain 2 Iteration: 1100 / 2000 [ 55%]  (Sampling)
Chain 1 Iteration: 1100 / 2000 [ 55%]  (Sampling)
Chain 2 Iteration: 1200 / 2000 [ 60%]  (Sampling)
Chain 1 Iteration: 1200 / 2000 [ 60%]  (Sampling)
Chain 2 Iteration: 1300 / 2000 [ 65%]  (Sampling)
Chain 1 Iteration: 1300 / 2000 [ 65%]  (Sampling)
Chain 2 Iteration: 1400 / 2000 [ 70%]  (Sampling)
Chain 1 Iteration: 1400 / 2000 [ 70%]  (Sampling)
Chain 2 Iteration: 1500 / 2000 [ 75%]  (Sampling)
Chain 1 Iteration: 1500 / 2000 [ 75%]  (Sampling)
Chain 2 Iteration: 1600 / 2000 [ 80%]  (Sampling)
Chain 1 Iteration: 1600 / 2000 [ 80%]  (Sampling)
Chain 2 Iteration: 1700 / 2000 [ 85%]  (Sampling)
Chain 1 Iteration: 1700 / 2000 [ 85%]  (Sampling)
Chain 1 Iteration: 1800 / 2000 [ 90%]  (Sampling)
Chain 2 Iteration: 1800 / 2000 [ 90%]  (Sampling)
Chain 1 Iteration: 1900 / 2000 [ 95%]  (Sampling)
Chain 2 Iteration: 1900 / 2000 [ 95%]  (Sampling)
Chain 1 Iteration: 2000 / 2000 [100%]  (Sampling)
Chain 1 finished in 201.5 seconds.
Chain 2 Iteration: 2000 / 2000 [100%]  (Sampling)
Chain 2 finished in 204.3 seconds.

Both chains finished successfully.
Mean chain execution time: 202.9 seconds.
Total execution time: 204.4 seconds.

fit_sir <- rstan::read_stan_csv(fit$output_files())
rstan::extract(fit_sir, "y_pred") -> y_pred
str(y_pred)

List of 1
 $ y_pred: num [1:2000, 1:21, 1:3] 39.4 40.1 40.9 32.5 38.9 ...
  ..- attr(*, "dimnames")=List of 3
  .. ..$ iterations: NULL
  .. ..$           : NULL
  .. ..$           : NULL

med_estimates <- colMeans(y_pred[[1]])

Review Model Outputs

Let’s see whether we recovered our actual parameters. The interval for R0 covers the simulated value of 3.33, although the individual rates in theta land away from the values used in the simulation.

pars = c('theta', "R0", "recovery_time")
fit$summary(variables = pars)

# A tibble: 5 x 10
  variable   mean   median sd     mad     q5     q95    rhat ess_bulk
  <chr>      <dbl>  <dbl>  <dbl>  <dbl>   <dbl>  <dbl>  <dbl>    <dbl>
1 theta[1]   1.99   1.92   0.461  0.432   1.35   2.86    1.00     585.
2 theta[2]   0.0514 0.0504 0.0101 0.00956 0.0370 0.0698  1.00     745.
3 theta[3]   0.499  0.480  0.122  0.112   0.341  0.720   1.00     656.
4 R0         4.03   3.98   0.519  0.504   3.25   4.93    1.00     847.
5 recover…   2.12   2.08   0.486  0.500   1.39   2.93    1.00     656.
# … with 1 more variable: ess_tail <dbl>
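One way to make the comparison with the simulation explicit is to check whether each 90% interval (q5 to q95) covers the value used to generate the data. The sketch below copies the interval bounds by hand from the summary table and assumes the three theta rows are, in order, beta, delta and gamma, as in the prior comments of the Stan program; the data frame `recovery_check` is a hypothetical helper for illustration.

```r
# Interval bounds copied from the summary table; truth from the simulation
recovery_check <- data.frame(
  variable = c("theta[1] (beta)", "theta[2] (delta)", "theta[3] (gamma)", "R0"),
  truth    = c(1/3, 1/4, 1/10, (1/3) / (1/10)),
  q5       = c(1.35, 0.0370, 0.341, 3.25),
  q95      = c(2.86, 0.0698, 0.720, 4.93)
)

# TRUE when the simulated value falls inside the 90% interval
recovery_check$covered <- with(recovery_check, truth >= q5 & truth <= q95)

print(recovery_check)
```

With these numbers only R0 is covered. That suggests the fit pins down the ratio of the rates more firmly than the rates themselves, so the priors and the initial conditions passed to Stan are worth revisiting before moving to real data.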

Visualise the Epidemic Curve

Finally, we can visualise our model outputs and see if we capture our actual cases. Additionally, I am using ggdist to capture the full range of possibilities, as discussed in my previous blog post.

library(rstan)
library(ggdist)
library(dplyr)

extract(fit_sir, pars = "pred_cases")[[1]] %>%
  as.data.frame() %>%
  mutate(.draw = 1:n()) %>%
  tidyr::gather(key, value, -.draw) %>%
  # columns are named V1, V2, ...; strip the prefix to recover the time step
  mutate(step = as.integer(gsub("\\D", "", key))) %>%
  group_by(step) %>%
  curve_interval(value, .width = c(.5, .8, .95)) %>%
  ggplot(aes(x = step, y = value)) +
  geom_hline(yintercept = 1, color = "gray75", linetype = "dashed") +
  geom_lineribbon(aes(ymin = .lower, ymax = .upper)) +
  scale_fill_brewer() +
  labs(
    title = "Simulated SIR Curve for Infections",
    y = "Cases"
  ) +
  geom_point(data = tibble(cases = cases, t = 1:length(cases)),
             aes(t, cases), inherit.aes = FALSE, colour = "orange") +
  theme_minimal()

Looks like this model adequately captured our fake data! Part 2 will fit some real data.

Corrections

If you see mistakes or want to suggest changes, please create an issue on the source repository.

Reuse

Text and figures are licensed under Creative Commons Attribution CC BY 4.0. Source code is available at https://github.com/medewitt/medewitt.github.io, unless otherwise noted. The figures that have been reused from other sources don't fall under this license and can be recognized by a note in their caption: "Figure from ...".