Fitting the birats example

Author

Max Rohde

Published

November 3, 2022

Load and format data

Data is from: https://github.com/stan-dev/example-models/blob/master/bugs_examples/vol2/birats/birats.data.R

Stan code is from: https://github.com/stan-dev/example-models/blob/master/bugs_examples/vol2/birats/birats.stan

# Values transcribed from the `birats.data.R` file.
# N and T are the two dimensions of the response matrix `y`;
# T is also the length of the predictor vector `x`.
N <- 30
T <- 5

# Response matrix, filled column by column: each group of three lines below
# holds the N = 30 values of one column of `y`.
y <- matrix(
  c(
    # column 1
    151, 145, 147, 155, 135, 159, 141, 159, 177, 134,
    160, 143, 154, 171, 163, 160, 142, 156, 157, 152,
    154, 139, 146, 157, 132, 160, 169, 157, 137, 153,
    # column 2
    199, 199, 214, 200, 188, 210, 189, 201, 236, 182,
    208, 188, 200, 221, 216, 207, 187, 203, 212, 203,
    205, 190, 191, 211, 185, 207, 216, 205, 180, 200,
    # column 3
    246, 249, 263, 237, 230, 252, 231, 248, 285, 220,
    261, 220, 244, 270, 242, 248, 234, 243, 259, 246,
    253, 225, 229, 250, 237, 257, 261, 248, 219, 244,
    # column 4
    283, 293, 312, 272, 280, 298, 275, 297, 350, 260,
    313, 273, 289, 326, 281, 288, 280, 283, 307, 286,
    298, 267, 272, 285, 286, 303, 295, 289, 258, 286,
    # column 5
    320, 354, 328, 297, 323, 331, 305, 338, 376, 296,
    352, 314, 325, 358, 312, 324, 316, 317, 336, 321,
    334, 302, 302, 323, 331, 345, 333, 316, 291, 324
  ),
  nrow = N,
  ncol = T
)

# Predictor value attached to each column of `y`.
x <- c(8, 15, 22, 29, 36)

xbar <- 22

# 2 x 2 diagonal matrix; the Stan program declares it as `cov_matrix[2] Omega`.
Omega <- diag(c(0.005, 5))

# Bundle everything under the names the Stan `data` block expects.
data_list <- list(
  N = N, T = T,
  y = y, x = x, xbar = xbar,
  Omega = Omega
)

Compile and fit model

# Compile the Stan program stored in `birats.stan`.
# `cmdstan_model()` is namespace-qualified, matching the `pkg::fun()` style
# used for bayesplot and dplyr elsewhere in this file, so the chunk does not
# rely on cmdstanr having been attached beforehand.
mod <- cmdstanr::cmdstan_model("birats.stan")

# Echo the Stan source so the exact model being fitted is shown in the output.
mod$print()
// Birats example from the WinBUGS examples, Volume 2 (page 23):
// http://www.mrc-bsu.cam.ac.uk/bugs/winbugs/Vol2.pdf
//
// Each row n of y gets its own coefficient vector beta[n]; the beta[n] share
// a bivariate normal prior with mean mu_beta and covariance Sigma_beta.

data {
  int<lower=0> N; // first dimension of y
  int<lower=0> T; // second dimension of y, and length of x
  array[T] real x; // predictor value for each column of y
  real xbar; // only referenced in the commented-out code below
  array[N, T] real y; // observed responses
  cov_matrix[2] Omega; // scale matrix of the inverse-Wishart prior on Sigma_beta
}
parameters {
  array[N] vector[2] beta; // beta[n, 1]: intercept, beta[n, 2]: slope on x
  vector[2] mu_beta; // mean of the prior on beta[n]
  real<lower=0> sigmasq_y; // residual variance
  cov_matrix[2] Sigma_beta; // covariance of the prior on beta[n]
}
// Disabled derived quantities, kept for reference. They are written with the
// `<-` assignment operator, whereas the active code below uses `=`.
//  transformed parameters {
//    real rho; 
//    real alpha0; 
//    //rho <- Sigma_beta[1, 2] / sqrt(Sigma_beta[1, 1] * Sigma_beta[2, 2]);
//    //alpha0 <- mu_beta[1] - mu_beta[2] * xbar; 
//  }
transformed parameters {
  // Residual standard deviation, derived from the sampled variance.
  real<lower=0> sigma_y;
  sigma_y = sqrt(sigmasq_y);
}
model {
  // Priors.
  sigmasq_y ~ inv_gamma(0.001, 0.001);
  mu_beta ~ normal(0, 100);
  Sigma_beta ~ inv_wishart(2, Omega);
  // Every row's coefficient pair is drawn from the shared bivariate normal.
  for (n in 1 : N) {
    beta[n] ~ multi_normal(mu_beta, Sigma_beta);
  }
  // Likelihood: linear in x[t] with row-specific intercept and slope.
  for (n in 1 : N) {
    for (t in 1 : T) {
      // Centered-x variant (disabled):
      // y[n,t] ~ normal(beta[n, 1] + beta[n, 2] * (x[t] - xbar), sqrt(sigmasq_y));
      
      // Active variant: x is NOT centered, so beta[n, 1] is the value at x = 0.
      // NOTE(review): the fit reported in this document shows divergences and
      // rhat well above 1; the uncentered predictor may be contributing by
      // correlating intercept and slope -- confirm by trying the centered form.
      y[n, t] ~ normal(beta[n, 1] + beta[n, 2] * x[t], sigma_y);
    }
  }
}
# Sampling is expensive, so the fitted object is cached on disk:
# sample and save only when no cached fit exists, otherwise load the cache.
if (!file.exists("fitted_model.RDS")) {
  fit <- mod$sample(
    data = data_list,
    seed = 123,
    # four chains, all run in parallel
    chains = 4,
    parallel_chains = 4,
    # equal-length warmup and sampling phases
    iter_warmup = 5000,
    iter_sampling = 5000,
    # sampler tuning
    adapt_delta = 0.99,
    max_treedepth = 20,
    # progress output every 500 iterations
    refresh = 500
  )

  fit$save_object("fitted_model.RDS")
} else {
  fit <- readRDS("fitted_model.RDS")
}

Model summary

# Per-variable posterior summaries (location, spread, quantiles) together
# with the rhat and effective-sample-size convergence diagnostics.
fit$summary()
# A tibble: 69 × 10
   variable   mean median    sd   mad     q5   q95  rhat ess_bulk ess_tail
   <chr>     <dbl>  <dbl> <dbl> <dbl>  <dbl> <dbl> <dbl>    <dbl>    <dbl>
 1 lp__      -346.  -321. 55.5  29.0  -443.  -291.  1.54     7.25     29.6
 2 beta[1,1]  107.   107.  3.01  2.00  102.   111.  1.18  1900.       28.1
 3 beta[2,1]  104.   106.  5.93  2.40   90.2  109.  1.43     8.24     26.6
 4 beta[3,1]  107.   107.  3.08  2.01  103.   113.  1.18   231.       26.4
 5 beta[4,1]  108.   107.  4.63  2.25  104.   119.  1.31    11.5      26.3
 6 beta[5,1]  103.   106.  7.23  2.47   86.6  109.  1.51     7.44     25.8
 7 beta[6,1]  108.   107.  3.81  2.14  104.   116.  1.25    17.2      26.0
 8 beta[7,1]  105.   106.  4.12  2.19   95.5  109.  1.28    13.0      26.2
 9 beta[8,1]  107.   107.  3.00  1.99  102.   111.  1.17  1931.       26.9
10 beta[9,1]  110.   107.  6.99  2.47  104.   126.  1.50     7.56     26.4
# … with 59 more rows

Diagnostic summary

# Sampler diagnostics: divergent transitions, max-treedepth hits and E-BFMI
# (presumably one entry per chain -- four values for the four chains).
fit$diagnostic_summary()
$num_divergent
[1] 16  1  0  0

$num_max_treedepth
[1] 0 0 0 0

$ebfmi
[1] 0.1473662 0.1860517 0.1694880 0.7798340

Plot chains

# Pull the posterior draws as a `draws_df`; its `.chain` column is used
# further down to subset chains.
posterior <- fit$draws(format = "draws_df")

bayesplot::color_scheme_set("viridis")

# Trace plots for a few selected parameters, stacked in a single column with
# the facet labels on the left. The sampler's NUTS parameters are supplied
# through `np`.
bayesplot::mcmc_trace(
  posterior,
  pars = c(
    "beta[1,1]",
    "beta[2,1]",
    "beta[12,2]",
    "sigma_y",
    "Sigma_beta[1,1]"
  ),
  np = bayesplot::nuts_params(fit),
  facet_args = list(ncol = 1, strip.position = "left")
)

Plot only chains 1, 2, and 3

# Trace plots for the same parameters, keeping only draws whose `.chain`
# is not 4. No `np` argument is passed in this call.
bayesplot::mcmc_trace(
  dplyr::filter(posterior, .chain != 4),
  facet_args = list(ncol = 1, strip.position = "left"),
  pars = c(
    "beta[1,1]",
    "beta[2,1]",
    "beta[12,2]",
    "sigma_y",
    "Sigma_beta[1,1]"
  )
)