Bayesian Survival Models with Stan

Vanderbilt Biostatistics Journal Club

Max Rohde

April 10, 2024

Basics of survival analysis

We define \(X\) to be the time to an event.

\(X\) is a non-negative random variable.

The survival function is the probability of having a survival time greater than \(x\).

\[ S(x) = P(X > x) \]

Define \(F(x)\) to be the CDF of \(X\), then

\[ S(x) = 1 - F(x) \]

Example: Exponential Distribution

PDF

\[ f(x) = \beta \exp(-\beta x) \]

CDF

\[ F(x) = 1 - \exp(-\beta x) \]

Survival function

\[ S(x) = \exp(-\beta x) \]
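As a quick numerical sanity check (a Python sketch, separate from the talk's R code), we can confirm that \(S(x) = \exp(-\beta x)\) matches \(1 - F(x)\) obtained by integrating the PDF:

```python
import math

beta = 3.0  # rate parameter (matches the Expo(3) example used later)
x = 0.5

# Integrate f(t) = beta * exp(-beta * t) from 0 to x with the midpoint rule
n = 100_000
dt = x / n
F = sum(beta * math.exp(-beta * (i + 0.5) * dt) * dt for i in range(n))

S_direct = math.exp(-beta * x)  # closed-form survival function
# 1 - F and S_direct agree to numerical precision
```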

Example: Exponential Distribution

Survival analysis

Let’s assume we know the data is generated as

\[ X \sim \operatorname{Expo}(\beta) \]

We want to learn the parameter \(\beta\).

Bayesian Inference

Bayesian inference review

Update our background knowledge (prior) with data (likelihood) to obtain our current state of knowledge (posterior).

\[ P(\text{Model} \mid \text{Data}) = \frac{P(\text{Data} \mid \text{Model}) \, P(\text{Model})}{P(\text{Data})} \]
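For intuition, Bayes' rule can be carried out directly on a grid for a one-parameter model. This Python sketch (hypothetical data and grid, not part of the talk) computes the posterior for an exponential rate under a flat prior:

```python
import math

# Hypothetical event times and a uniform prior over a grid of beta values
data = [0.2, 0.5, 0.1]
grid = [0.1 * k for k in range(1, 101)]       # beta in (0, 10]
prior = [1.0 / len(grid)] * len(grid)

def likelihood(beta, xs):
    # P(Data | Model): product of exponential densities
    return math.prod(beta * math.exp(-beta * x) for x in xs)

# Numerator of Bayes' rule, then normalize (the denominator P(Data))
unnorm = [likelihood(b, data) * p for b, p in zip(grid, prior)]
posterior = [u / sum(unnorm) for u in unnorm]
```

With a flat prior, the posterior mode lands near the maximum-likelihood estimate \(n / \sum x_i = 3.75\).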

Stan

In practice, we use MCMC methods to obtain samples from the posterior.

Stan is a programming language for creating Bayesian models.

Stan uses an MCMC method called Hamiltonian Monte Carlo (HMC).
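HMC itself uses gradient information and is beyond a quick sketch, but the core MCMC idea — a random walk whose long-run draws follow the posterior — can be illustrated with the simpler random-walk Metropolis algorithm (a Python sketch with simulated data; not what Stan does internally):

```python
import math
import random

random.seed(1)

# Simulated event times from Expo(3); target is the posterior for beta
data = [random.expovariate(3.0) for _ in range(200)]
n, total = len(data), sum(data)

def log_post(beta):
    # Exponential log-likelihood plus a flat (improper) prior on beta > 0
    return n * math.log(beta) - beta * total if beta > 0 else -math.inf

beta, draws = 1.0, []
for _ in range(20_000):
    prop = beta + random.gauss(0, 0.3)  # random-walk proposal
    if math.log(random.random()) < log_post(prop) - log_post(beta):
        beta = prop  # accept; otherwise keep the current value
    draws.append(beta)

# Discard warmup draws; the remainder approximate the posterior
post_mean = sum(draws[5_000:]) / len(draws[5_000:])
```

The posterior mean should land near the true rate of 3.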

Exponential survival model with Stan

Exponential survival model with Stan

data {
  // Number of observations
  int<lower=0> N;
  // Observed survival time
  vector[N] x;
}

Exponential survival model with Stan

data {
  // Number of observations
  int<lower=0> N;
  // Observed survival time
  vector[N] x;
}

parameters {
  // Exponential rate parameter
  real<lower=0> beta;
}

Exponential survival model with Stan

data {
  // Number of observations
  int<lower=0> N;
  // Observed survival time
  vector[N] x;
}

parameters {
  // Exponential rate parameter
  real<lower=0> beta;
}

model {
  // Prior
  beta ~ uniform(0, 50);
  // Likelihood
  x ~ exponential(beta);
}

Data

Assume the true model is

\[ X \sim \operatorname{Expo}(3) \]

This could be the number of years until developing a given disease (e.g., years to first myocardial infarction).

Data

We’ll take 10,000 samples for illustration.

# Generate the data
x <- rexp(n = 10000, rate = 3)

Running the model

cmdstanr is a package that interfaces between R and Stan.

Read and compile the model

mod1 <- cmdstanr::cmdstan_model(stan_file = "expo1.stan")

Format the data

# Format the data for Stan
data_list <- list(N = length(x),
                  x = x)

Running the model

MCMC sampling

# Fit the model
fit <- mod1$sample(data = data_list)
Running MCMC with 4 sequential chains...

Chain 1 Iteration:    1 / 2000 [  0%]  (Warmup) 
Chain 1 Iteration:  100 / 2000 [  5%]  (Warmup) 
Chain 1 Iteration:  200 / 2000 [ 10%]  (Warmup) 
Chain 1 Iteration:  300 / 2000 [ 15%]  (Warmup) 
Chain 1 Iteration:  400 / 2000 [ 20%]  (Warmup) 
Chain 1 Iteration:  500 / 2000 [ 25%]  (Warmup) 
Chain 1 Iteration:  600 / 2000 [ 30%]  (Warmup) 
Chain 1 Iteration:  700 / 2000 [ 35%]  (Warmup) 
Chain 1 Iteration:  800 / 2000 [ 40%]  (Warmup) 
Chain 1 Iteration:  900 / 2000 [ 45%]  (Warmup) 
Chain 1 Iteration: 1000 / 2000 [ 50%]  (Warmup) 
Chain 1 Iteration: 1001 / 2000 [ 50%]  (Sampling) 
Chain 1 Iteration: 1100 / 2000 [ 55%]  (Sampling) 
Chain 1 Iteration: 1200 / 2000 [ 60%]  (Sampling) 
Chain 1 Iteration: 1300 / 2000 [ 65%]  (Sampling) 
Chain 1 Iteration: 1400 / 2000 [ 70%]  (Sampling) 
Chain 1 Iteration: 1500 / 2000 [ 75%]  (Sampling) 
Chain 1 Iteration: 1600 / 2000 [ 80%]  (Sampling) 
Chain 1 Iteration: 1700 / 2000 [ 85%]  (Sampling) 
Chain 1 Iteration: 1800 / 2000 [ 90%]  (Sampling) 
Chain 1 Iteration: 1900 / 2000 [ 95%]  (Sampling) 
Chain 1 Iteration: 2000 / 2000 [100%]  (Sampling) 
Chain 1 finished in 0.1 seconds.
Chain 2 finished in 0.1 seconds.
Chain 3 finished in 0.1 seconds.
Chain 4 finished in 0.1 seconds.

All 4 chains finished successfully.
Mean chain execution time: 0.1 seconds.
Total execution time: 0.8 seconds.

What does the model tell us?

fit$summary()
variable       mean    median         sd        mad         q5        q95     rhat ess_bulk ess_tail
lp__     973.193298 973.45000 0.66971045 0.31727640 971.813800 973.679000 1.004603 1695.512 2134.895
beta       2.995555   2.99564 0.02958197 0.03035624   2.947304   3.043432 1.002130 1352.927 1787.371
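As a cross-check (a Python sketch with freshly simulated data, so the numbers will differ slightly from the table): with an effectively flat prior, the exponential model is conjugate, and the posterior for \(\beta\) has a closed form we can compare against the MCMC summary:

```python
import random

random.seed(42)
# Fresh simulation mimicking the talk's data: 10,000 draws from Expo(3)
x = [random.expovariate(3.0) for _ in range(10_000)]
n, total = len(x), sum(x)

# With a flat prior (uniform(0, 50) is effectively flat here, since the
# posterior concentrates near 3), the posterior is Gamma(n + 1, sum(x))
post_mean = (n + 1) / total
post_sd = ((n + 1) ** 0.5) / total
```

These closed-form values should be close to the MCMC summary above (mean near 3, sd near 0.03).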

Final posterior

Censoring

We forgot about one of the most important aspects of survival analysis: censoring!

Assume the study ends after 6 months so that we don’t observe any survival times greater than 0.5.

How does this impact our inference?

Censoring

This doesn’t look good!

A likelihood solution

Removing the censored observations biased our inference because censoring was dependent on the outcome.
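To see the size of the bias, here is a quick simulation (a Python sketch, separate from the talk's R code) that drops the censored observations and refits by maximum likelihood:

```python
import random

random.seed(7)
x = [random.expovariate(3.0) for _ in range(10_000)]  # true rate is 3

# Naive approach: discard everything censored at 0.5, then use the
# uncensored-data MLE, n_obs / sum(x_obs)
kept = [t for t in x if t < 0.5]
naive_rate = len(kept) / sum(kept)
```

Short survival times are overrepresented among the kept observations, so the naive estimate lands substantially above the true rate of 3.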

We need to account for this in the likelihood!

We need to let the model know that when an observation is censored, the value of \(x\) is greater than 0.5.

How do we change the likelihood?

When we observe a survival time at \(x_{obs}\), it enters the likelihood as \(f(x_{obs})\).

When we observe a right-censoring time at \(x_{rc}\), it enters the likelihood as \(1 - F(x_{rc}) = S(x_{rc})\).

Similarly, if we had left-censoring, it would enter the likelihood as \(F(x_{lc})\).
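These pieces can be checked outside Stan. Here is a Python sketch of the right-censored log-likelihood for the exponential model (simulated data; for this model the maximizer even has a closed form):

```python
import math
import random

random.seed(3)
x = [random.expovariate(3.0) for _ in range(10_000)]
c = 0.5                              # administrative censoring time

x_obs = [t for t in x if t < c]      # fully observed event times
n_cens = sum(t >= c for t in x)      # number right-censored at c

def log_lik(beta):
    # Events contribute log f(x_obs) = log(beta) - beta * x_obs;
    # censored observations contribute log S(c) = -beta * c
    return (len(x_obs) * math.log(beta)
            - beta * sum(x_obs)
            - beta * c * n_cens)

# Setting the derivative to zero gives the closed-form maximizer
beta_hat = len(x_obs) / (sum(x_obs) + c * n_cens)
```

The estimate lands back near the true rate of 3, confirming that entering censored observations through \(S(c)\) removes the bias.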

Our first model

data {
  int<lower=0> N;
  vector[N] x;
}
parameters {
  real<lower=0> beta;
}
model {
  beta ~ uniform(0, 50);
  x ~ exponential(beta);
}

Change the data block

data {
  int<lower=0> N_obs;
  int<lower=0> N_cens;
  vector[N_obs] x_obs;
  vector[N_cens] x_cens;
}
parameters {
  real<lower=0> beta;
}
model {
  beta ~ uniform(0, 50);
  x_obs ~ exponential(beta);
}

Change the syntax for the likelihood

data {
  int<lower=0> N_obs;
  int<lower=0> N_cens;
  vector[N_obs] x_obs;
  vector[N_cens] x_cens;
}
parameters {
  real<lower=0> beta;
}
model {
  beta ~ uniform(0, 50);
  target += exponential_lpdf(x_obs | beta);
}

Likelihood for censored observations

data {
  int<lower=0> N_obs;
  int<lower=0> N_cens;
  vector[N_obs] x_obs;
  vector[N_cens] x_cens;
}
parameters {
  real<lower=0> beta;
}
model {
  beta ~ uniform(0, 50);
  target += exponential_lpdf(x_obs | beta);
  target += exponential_lccdf(x_cens | beta);
}

Data formatting and model fitting

mod_cens <- cmdstanr::cmdstan_model(stan_file = "expo2.stan")

x_obs <- x[x < 0.5]
N_obs <- length(x_obs)
N_cens <- sum(x >= 0.5)
x_cens <- rep(0.5, N_cens)

# Format the data for Stan
data_list <- list(x_obs = x_obs,
                  x_cens = x_cens,
                  N_obs = N_obs,
                  N_cens = N_cens)

fit3 <- mod_cens$sample(data = data_list)

Problem solved!

A brief look into rstanarm

For applied regression modeling, rstanarm is an R package that contains pre-made Stan models.

  • Standard parametric models (exponential, Weibull, Gompertz)
  • Flexible parametric (spline-based) hazard models
  • Accelerated failure time (AFT) models
  • All types of censoring (left, right, interval) and left truncation
  • Time-varying covariates
  • Frailty effects

Questions?