Fitting the birats example
Load and format data
Data is from: https://github.com/stan-dev/example-models/blob/master/bugs_examples/vol2/birats/birats.data.R
Stan code is from: https://github.com/stan-dev/example-models/blob/master/bugs_examples/vol2/birats/birats.stan
# Data transcribed from the `birats.data.R` file.
# NOTE: assigning `T` masks R's built-in `T` (shorthand for TRUE) for the rest
# of the session; the name is kept because it matches the Stan data block.
N <- 30
T <- 5

# N x T response matrix (30 rows, 5 columns), filled column by column to match
# `array[N, T] real y` in the Stan program. Each group of three lines below is
# one column t.
y <- matrix(
  c(
    # t = 1
    151, 145, 147, 155, 135, 159, 141, 159, 177, 134,
    160, 143, 154, 171, 163, 160, 142, 156, 157, 152,
    154, 139, 146, 157, 132, 160, 169, 157, 137, 153,
    # t = 2
    199, 199, 214, 200, 188, 210, 189, 201, 236, 182,
    208, 188, 200, 221, 216, 207, 187, 203, 212, 203,
    205, 190, 191, 211, 185, 207, 216, 205, 180, 200,
    # t = 3
    246, 249, 263, 237, 230, 252, 231, 248, 285, 220,
    261, 220, 244, 270, 242, 248, 234, 243, 259, 246,
    253, 225, 229, 250, 237, 257, 261, 248, 219, 244,
    # t = 4
    283, 293, 312, 272, 280, 298, 275, 297, 350, 260,
    313, 273, 289, 326, 281, 288, 280, 283, 307, 286,
    298, 267, 272, 285, 286, 303, 295, 289, 258, 286,
    # t = 5
    320, 354, 328, 297, 323, 331, 305, 338, 376, 296,
    352, 314, 325, 358, 312, 324, 316, 317, 336, 321,
    334, 302, 302, 323, 331, 345, 333, 316, 291, 324
  ),
  nrow = N,
  ncol = T
)

# Covariate value at each of the T occasions; xbar is their mean (110 / 5).
x <- c(8, 15, 22, 29, 36)
xbar <- 22

# Diagonal 2 x 2 scale matrix for the inverse-Wishart prior on Sigma_beta.
Omega <- diag(c(0.005, 5))

# Bundle everything into the named list passed to `mod$sample()`.
data_list <- list(
  N = N, T = T,
  y = y, x = x, xbar = xbar,
  Omega = Omega
)
Compile and fit model
# Compile the Stan program in `birats.stan`, then echo its source text.
mod <- cmdstan_model(stan_file = "birats.stan")
mod$print()
// Birats: hierarchical linear growth model (WinBUGS examples, Vol. 2).
// Each of N units has its own (intercept, slope) vector beta[n]; the beta[n]
// share a bivariate normal population distribution.
// Source: http://www.mrc-bsu.cam.ac.uk/bugs/winbugs/Vol2.pdf
// Page 23: Birats
//#
data {
int<lower=0> N; // number of units (rows of y)
int<lower=0> T; // number of measurement occasions (columns of y)
array[T] real x; // covariate value at each occasion t
real xbar; // only referenced by the commented-out centered code below
array[N, T] real y; // response of unit n at occasion t
cov_matrix[2] Omega; // scale matrix of the inverse-Wishart prior on Sigma_beta
}
parameters {
array[N] vector[2] beta; // per-unit coefficients: beta[n, 1] intercept, beta[n, 2] slope on x
vector[2] mu_beta; // population mean of beta[n]
real<lower=0> sigmasq_y; // residual variance of y
cov_matrix[2] Sigma_beta; // population covariance of beta[n]
}
// Disabled derived quantities from the original example: rho is the
// intercept/slope correlation implied by Sigma_beta, and alpha0 converts an
// intercept at x = xbar (centered model) into the intercept at x = 0.
// NOTE(review): if re-enabled, replace the old `<-` assignment with `=`
// (`<-` is deprecated/removed in recent Stan releases).
// transformed parameters {
// real rho;
// real alpha0;
// //rho <- Sigma_beta[1, 2] / sqrt(Sigma_beta[1, 1] * Sigma_beta[2, 2]);
// //alpha0 <- mu_beta[1] - mu_beta[2] * xbar;
// }
transformed parameters {
real<lower=0> sigma_y; // residual standard deviation used in the likelihood
sigma_y = sqrt(sigmasq_y);
}
model {
// Priors on the residual variance, population mean and population covariance.
sigmasq_y ~ inv_gamma(0.001, 0.001);
mu_beta ~ normal(0, 100);
Sigma_beta ~ inv_wishart(2, Omega);
// Hierarchical prior: each unit's coefficients are drawn from the population.
for (n in 1 : N) {
beta[n] ~ multi_normal(mu_beta, Sigma_beta);
}
// Likelihood: y is linear in x for each unit.
for (n in 1 : N) {
for (t in 1 : T) {
// Centered x[] (beta[n, 1] would be the value at x = xbar):
// y[n,t] ~ normal(beta[n, 1] + beta[n, 2] * (x[t] - xbar), sqrt(sigmasq_y));
// Uncentered x[] (active; beta[n, 1] is the value at x = 0):
// NOTE(review): when x lies far from 0 (the data use 8 to 36), intercept and
// slope are strongly correlated a posteriori. This likely contributes to the
// divergences and low E-BFMI seen in the fit, so the centered form above is
// worth trying.
y[n, t] ~ normal(beta[n, 1] + beta[n, 2] * x[t], sigma_y);
}
}
}
# Fit the model, caching the result on disk so re-rendering this document does
# not re-run the (slow) sampler.
# NOTE: the cache is keyed only on the file name. If the Stan program,
# `data_list`, or any sampler setting below changes, delete the cache file to
# force a refit; otherwise the stale fit is silently reused.
fit_cache <- "fitted_model.RDS"
if (file.exists(fit_cache)) {
  # Make cache reuse visible in the rendered output.
  message("Loading cached fit from ", fit_cache)
  fit <- readRDS(fit_cache)
} else {
  fit <- mod$sample(
    data = data_list,
    seed = 123,
    chains = 4,
    parallel_chains = 4,
    iter_warmup = 5000,
    iter_sampling = 5000,
    refresh = 500,
    max_treedepth = 20, # CmdStan default is 10
    adapt_delta = 0.99  # CmdStan default is 0.8; raised to curb divergences
  )
  # save_object() reads all draws and diagnostics into memory before
  # serializing, so the cached object remains usable after CmdStan's CSV
  # output files are gone.
  fit$save_object(file = fit_cache)
}
Model summary
# Posterior summary (mean, median, sd, mad, q5, q95, rhat, ess_bulk, ess_tail)
# for every variable in the fit.
fit$summary(variables = NULL)
# A tibble: 69 × 10
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 lp__ -346. -321. 55.5 29.0 -443. -291. 1.54 7.25 29.6
2 beta[1,1] 107. 107. 3.01 2.00 102. 111. 1.18 1900. 28.1
3 beta[2,1] 104. 106. 5.93 2.40 90.2 109. 1.43 8.24 26.6
4 beta[3,1] 107. 107. 3.08 2.01 103. 113. 1.18 231. 26.4
5 beta[4,1] 108. 107. 4.63 2.25 104. 119. 1.31 11.5 26.3
6 beta[5,1] 103. 106. 7.23 2.47 86.6 109. 1.51 7.44 25.8
7 beta[6,1] 108. 107. 3.81 2.14 104. 116. 1.25 17.2 26.0
8 beta[7,1] 105. 106. 4.12 2.19 95.5 109. 1.28 13.0 26.2
9 beta[8,1] 107. 107. 3.00 1.99 102. 111. 1.17 1931. 26.9
10 beta[9,1] 110. 107. 6.99 2.47 104. 126. 1.50 7.56 26.4
# … with 59 more rows
Diagnostic summary
# Per-chain sampler diagnostics: divergent transitions, hits of the maximum
# tree depth, and E-BFMI.
fit$diagnostic_summary(diagnostics = c("divergences", "treedepth", "ebfmi"))
$num_divergent
[1] 16 1 0 0
$num_max_treedepth
[1] 0 0 0 0
$ebfmi
[1] 0.1473662 0.1860517 0.1694880 0.7798340
Plot chains
# Extract all posterior draws as a posterior::draws_df (one row per draw).
posterior <- fit$draws(format = "draws_df")

# Trace plots of selected parameters for all chains. Passing the NUTS
# diagnostics via `np` marks divergent transitions on the plots.
bayesplot::color_scheme_set("viridis")
bayesplot::mcmc_trace(
  x = posterior,
  pars = c(
    "beta[1,1]", "beta[2,1]", "beta[12,2]",
    "sigma_y", "Sigma_beta[1,1]"
  ),
  np = bayesplot::nuts_params(fit),
  facet_args = list(ncol = 1, strip.position = "left")
)
Plot only chains 1, 2, and 3
# Trace plots restricted to chains 1-3 (chain 4 excluded). The diagnostic
# summary shows chain 4's E-BFMI (~0.78) differing markedly from chains 1-3
# (~0.15-0.19).
# The NUTS diagnostics are filtered to the same chains and passed via `np`, so
# divergent transitions (reported for chains 1 and 2) are still marked, as in
# the all-chains plot.
keep_chains <- c(1, 2, 3)
np_keep <- bayesplot::nuts_params(fit)
np_keep <- np_keep[np_keep$Chain %in% keep_chains, ]
bayesplot::mcmc_trace(
  dplyr::filter(posterior, .chain %in% keep_chains),
  pars = c(
    "beta[1,1]", "beta[2,1]", "beta[12,2]",
    "sigma_y", "Sigma_beta[1,1]"
  ),
  np = np_keep,
  facet_args = list(ncol = 1, strip.position = "left")
)