vignettes/advanced-features.Rmd
Before we jump into coding, let's load the package and the data that we will be using: the `lifeexpect` data set.
This data set was simulated using official statistics and research on life expectancy in the US (see `?lifeexpect` for more details). Each row corresponds to the observed age of an individual at the time of disease.
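The chunk that loads the package and the data did not survive extraction. A minimal version, assuming the data set ships with fmcmc under the name `lifeexpect` (as the `?lifeexpect` reference suggests), could look like this:

```r
library(fmcmc)
data(lifeexpect)   # simulated life-expectancy data shipped with fmcmc
head(lifeexpect)   # first rows: smoke, female, age
```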
```
##   smoke female      age
## 1     0      0 78.10840
## 2     0      1 83.42009
## 3     0      1 85.46404
## 4     1      0 69.83565
## 5     0      1 80.40131
## 6     1      0 72.18115
```
The data is characterized by the following model:
\[\begin{equation*} y_i \sim \mbox{N}\left(\theta_0 + \theta_{smk}smoke_i + \theta_{fem}female_i, \sigma^2\right) \end{equation*}\]
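With flat priors, the log-posterior equals, up to an additive constant, the log-likelihood of this model:
\[\begin{equation*} \sum_i \left[\log\phi\left(\frac{y_i - (\theta_0 + \theta_{smk}smoke_i + \theta_{fem}female_i)}{\sigma}\right) - \log\sigma\right] \end{equation*}\]
where \(y_i\) is the age of the \(i\)-th individual and \(\phi\) is the standard normal density. In R, we write this log-posterior as a function `logpost(p, D)` of the parameter vector `p` and the data `D`.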
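The chunk that defined `logpost()` did not survive extraction, although the examples below call it as `logpost(p, D)`. A minimal sketch, assuming the parameter order \(\theta_0, \theta_{smk}, \theta_{fem}, \sigma\) (consistent with the starting point `c(70, 0, 0, sd(lifeexpect$age))` used below) and the column names shown by `head(lifeexpect)`:

```r
# Log-posterior under flat priors, up to an additive constant.
# Assumed layout: p = c(theta0, theta_smk, theta_fem, sigma);
# D has columns age, smoke and female.
logpost <- function(p, D) {
  mu <- p[1] + p[2] * D$smoke + p[3] * D$female
  # dnorm(..., log = TRUE) returns log(phi((y - mu) / sigma)) - log(sigma)
  sum(dnorm(D$age, mean = mu, sd = p[4], log = TRUE))
}
```

Because `dnorm(..., log = TRUE)` already includes the \(-\log\sigma\) term, this function matches the sum above. It is only meaningful for \(\sigma > 0\) (a negative `sd` makes `dnorm()` return `NaN`), which is one reason to give the kernel a positive lower bound for \(\sigma\).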
In some cases, the user may want to go beyond what `MCMC()` does. In those cases, we can directly access the environment in which the main loop of the `MCMC()` call is executed, using the function `ith_step()`.
With `ith_step()`, we can access the objects that exist in that environment while the MCMC loop runs. Among these are `i` (the step number), `logpost` (a vector storing the trace of the unnormalized log-posterior), and `draws` (a matrix storing the kernel's proposed states). The complete list of available objects is given in the manual, and is also shown when printing the function:
```r
# This will show the available objects
ith_step
```
```
## Available objects via the ith_step() function:
## - i : (int) Step (iteration) number.
## - nsteps : (int) Number of steps.
## - chain_id : (int) Id of the chain (goes from 1 to -nchains-)
## - theta0 : (double vector) Current state of the chain.
## - theta1 : (double vector) Proposed state of the chain.
## - ans : (double matrix) Set of accepted states (it will be NA for rows >= i).
## - draws : (double matrix) Set of proposed states (it will be NA for rows >= i).
## - logpost : (double vector) Value of -fun- (it will be NA for elements >= i).
## - R : (double vector) Random values from U(0,1). This is used with the Hastings ratio.
## - thin : (int) Thinning (applied after the last step).
## - burnin : (int) Burn-in (applied after the last step).
## - conv_checker : (function) Convergence checker function.
## - kernel : (fmcmc_kernel) Kernel object.
## - fun : (function) Passed function to MCMC.
## - f : (function) Wrapper of -fun-.
## - initial : (double vector) Starting point of the chain.
##
## The following objects always have fixed values (see ?ith_step): nchains, cl, multicore
## Other available objects: cnames, funargs, MCMC_OUTPUT, passedargs, progress
```
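To make "accessing the environment of the running loop" concrete, here is a toy, base-R sketch of how such an accessor can work. It is not fmcmc's implementation: `toy_mcmc()` and `toy_ith_step()` are hypothetical names, and the "sampler" only calls `fun()` repeatedly while exposing its own frame.

```r
# Toy illustration only: this is NOT how fmcmc implements ith_step().
# A package-level environment remembers which loop is currently running.
.loop_state <- new.env()

# Hypothetical accessor: return the running loop's environment, or one
# object from it when a name is given.
toy_ith_step <- function(x) {
  env <- .loop_state$current
  if (is.null(env)) stop("toy_ith_step() called outside of toy_mcmc().")
  if (missing(x)) env else get(x, envir = env)
}

# Hypothetical "sampler": it only calls fun() nsteps times, but exposes its
# own frame (where i, nsteps and values live) while the loop runs.
toy_mcmc <- function(fun, nsteps) {
  i <- 0L
  values <- numeric(nsteps)
  .loop_state$current <- environment()
  on.exit(.loop_state$current <- NULL)  # clean up, even after an error
  for (i in seq_len(nsteps)) values[i] <- fun()
  values
}

# The user function reads the step number from the running loop
toy_mcmc(function() toy_ith_step("i") * 10, nsteps = 3)
# [1] 10 20 30
```

The design choice worth noting is that the user function needs no extra arguments: it looks up the loop state on demand, which is what lets `ith_step()` be dropped into an existing objective function.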
For example, accidents happen, and your computing environment (R, your PC, your server, etc.) could crash. It is a good idea to keep track of the current state of the chain. One way to do this is to print the state of the process every n-th step.
Using the `lifeexpect` data, let's rewrite the `logpost()` function using `ith_step()`. We will print the latest accepted state every 1,000 steps:
```r
logpost2 <- function(p, D) {
  # Get the current step number
  i <- ith_step("i")

  # After the first iteration, every 1000 steps:
  if (i > 1L && !(i %% 1000)) {
    # The last accepted state. Accepted states are
    # stored in -ans-.
    s <- ith_step()$ans[i - 1L, ]
    cat("Step: ", i, " state: c(\n ", paste(s, collapse = ", "), "\n)\n", sep = "")
  }

  # Return the value of the original log-posterior
  logpost(p, D)
}
```
Note that the accepted states, i.e., the draws from the posterior distribution, are stored in the matrix `ans` within the MCMC loop. Let's fit this model using the Robust Adaptive Metropolis (RAM) kernel, `kernel_ram()`. Since the standard deviation \(\sigma\) must be positive, we set lower bounds for the parameters (0.001 for \(\sigma\)). As a starting point, let's use the vector `c(70, 0, 0, sd(age))` (a rough but reasonable guess):
```r
# Generating the kernel (the lower bounds keep sigma positive)
kern <- kernel_ram(warmup = 1000, lb = c(-100, -100, -100, .001))

# Running MCMC
ans0 <- MCMC(
  initial = c(70, 0, 0, sd(lifeexpect$age)),
  fun = logpost2,
  nsteps = 10000,
  kernel = kern,
  seed = 555,
  D = lifeexpect,
  progress = FALSE
)
```
```
## Step: 1000 state: c(
## 69.9961016297189, -0.00345239203214055, 0.0070167653826948, 5.90378993611787
## )
## Step: 2000 state: c(
## 70.0504366122137, -0.0292853676293392, 0.076260722355002, 5.97687773201868
## )
## Step: 2000 state: c(
## 70.0504366122137, -0.0292853676293392, 0.076260722355002, 5.97687773201868
## )
## Step: 3000 state: c(
## 70.8055075956862, 0.279930458592591, 0.309906484934879, 6.51371636035569
## )
## Step: 3000 state: c(
## 70.8055075956862, 0.279930458592591, 0.309906484934879, 6.51371636035569
## )
## Step: 4000 state: c(
## 74.4132756921824, -1.22342712322476, 2.14666613219295, 5.7067947055272
## )
## Step: 4000 state: c(
## 74.4132756921824, -1.22342712322476, 2.14666613219295, 5.7067947055272
## )
## Step: 5000 state: c(
## 80.0205621762772, -9.83179753872789, 5.37801732288901, 1.99929550969985
## )
## Step: 5000 state: c(
## 80.0205621762772, -9.83179753872789, 5.37801732288901, 1.99929550969985
## )
## Step: 6000 state: c(
## 79.7724276801441, -9.56814688878825, 5.6486569436619, 2.04970696423494
## )
## Step: 6000 state: c(
## 79.7724276801441, -9.56814688878825, 5.6486569436619, 2.04970696423494
## )
## Step: 7000 state: c(
## 79.8413102462833, -9.80272237981432, 5.63324297617831, 1.97489510035287
## )
## Step: 7000 state: c(
## 79.8413102462833, -9.80272237981432, 5.63324297617831, 1.97489510035287
## )
## Step: 8000 state: c(
## 80.1744325184171, -9.79630960138713, 5.38165644274545, 1.99091898767099
## )
## Step: 8000 state: c(
## 80.1744325184171, -9.79630960138713, 5.38165644274545, 1.99091898767099
## )
## Step: 9000 state: c(
## 79.93479373063, -9.63123904834719, 5.49489884045119, 2.07352104788007
## )
## Step: 9000 state: c(
## 79.93479373063, -9.63123904834719, 5.49489884045119, 2.07352104788007
## )
## Step: 10000 state: c(
## 80.1499993519629, -9.74186365776657, 5.25019766088498, 1.99986537561054
## )
## Step: 10000 state: c(
## 80.1499993519629, -9.74186365776657, 5.25019766088498, 1.99986537561054
## )
```
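Console output only helps if it survives the crash. A sketch of a more durable variant writes each checkpoint to a file instead; `write_checkpoint()` and `read_last_checkpoint()` are hypothetical helpers, not fmcmc functions, and the sketch assumes that appending one short line every 1,000 steps is cheap enough.

```r
# Hypothetical helpers (not part of fmcmc). Each checkpoint is one CSV line:
# the step number followed by the parameter values.
write_checkpoint <- function(step, state, file) {
  # 17 significant digits keep (essentially) full double precision
  line <- paste(c(step, sprintf("%.17g", state)), collapse = ",")
  cat(line, "\n", file = file, append = TRUE, sep = "")
}

# Read the file back and return the most recent checkpoint
read_last_checkpoint <- function(file) {
  x <- utils::read.csv(file, header = FALSE)
  last <- x[nrow(x), ]
  list(step = last[[1]], state = unlist(last[-1], use.names = FALSE))
}

# Sketch of the usage (requires fmcmc and the objects defined above):
# inside logpost2(), replace the cat() call with
#   write_checkpoint(i, s, "checkpoint.csv")
# and, after a crash, restart from the last saved state with
#   ck <- read_last_checkpoint("checkpoint.csv")
#   MCMC(initial = ck$state, fun = logpost, nsteps = 10000 - ck$step,
#        kernel = kernel_ram(warmup = 1000, lb = c(-100, -100, -100, .001)),
#        D = lifeexpect, progress = FALSE)
```

A resumed run uses a different random number stream and, with a new `kernel_ram()` object, starts its adaptation from scratch, so it will not reproduce the uninterrupted chain exactly.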
`ith_step()` makes `MCMC()` easy to tailor. But what happens when we run multiple chains?
`ith_step()` is also helpful when running multiple chains in a single call. In that case, we can retrieve the variable `chain_id` with `ith_step()`. Extending the previous example:
```r
logpost3 <- function(p, D) {
  # Get the current step number
  i <- ith_step("i")

  # After the first iteration, every 1000 steps:
  if (i > 1L && !(i %% 1000)) {
    # The last accepted state. Accepted states are
    # stored in -ans-.
    s <- ith_step()$ans[i - 1L, ]
    chain <- ith_step("chain_id")
    cat("Step: ", i, " chain: ", chain, " state: c(\n ",
      paste(s, collapse = ",\n "), "\n)\n", sep = ""
    )
  }

  # Return the value of the original log-posterior
  logpost(p, D)
}

# Rerunning with two chains, continuing from ans0
ans1 <- MCMC(
  initial = ans0,
  fun = logpost3,    # This version also prints the chain id
  nsteps = 1000,
  kernel = kern,     # Reusing the kernel
  thin = 1,
  nchains = 2L,      # Two chains, two different prints
  multicore = FALSE,
  seed = 555,
  progress = FALSE,
  D = lifeexpect
)
```
```
## Step: 1000 chain: 1 state: c(
## 80.3915443781619,
## -10.0470301689399,
## 5.21253077834454,
## 2.0313564405203
## )
## Step: 1000 chain: 1 state: c(
## 80.3915443781619,
## -10.0470301689399,
## 5.21253077834454,
## 2.0313564405203
## )
## Step: 1000 chain: 2 state: c(
## 80.1338926312349,
## -9.83033416603663,
## 5.34896148596224,
## 1.99612829900285
## )
## Step: 1000 chain: 2 state: c(
## 80.1338926312349,
## -9.83033416603663,
## 5.34896148596224,
## 1.99612829900285
## )
```
Using `ith_step()` increases the computational burden of the process. Yet, since most of the load usually lies in the objective function itself, the additional time is typically negligible.
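How large the overhead is depends on the machine and on the cost of the objective function. One way to check it is to time the same run with and without an `ith_step()` call, as in this sketch; `f_plain()`, `f_step()` and `time_run()` are illustrative names, and the sketch relies on `MCMC()`'s default kernel. Because this objective is much cheaper than `logpost()`, the relative overhead it shows should be larger than in the examples above.

```r
library(fmcmc)

# Cheap objective: log-density of independent standard normals
f_plain <- function(p) sum(dnorm(p, log = TRUE))

# Same objective plus one ith_step() lookup per evaluation
f_step <- function(p) {
  i <- ith_step("i")
  sum(dnorm(p, log = TRUE))
}

# Elapsed seconds for a 5,000-step run with MCMC()'s default kernel
time_run <- function(f) {
  system.time(
    MCMC(initial = c(0, 0), fun = f, nsteps = 5000, seed = 1, progress = FALSE)
  )[["elapsed"]]
}

t_plain <- time_run(f_plain)
t_step  <- time_run(f_step)
c(plain = t_plain, with_ith_step = t_step)
```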
Another thing the user may need to do is store data as the MCMC algorithm runs. In such cases, you can use the `set_userdata()` function, which, as the name suggests, stores the required data. For a simple example, suppose we want to store the first three components of each proposed state; we could do it as follows:
```r
logpost4 <- function(p, D) {
  # Store the first three components of the proposed state
  set_userdata(
    p1 = p[1],
    p2 = p[2],
    p3 = p[3]
  )

  # Return the value of the original log-posterior
  logpost(p, D)
}

# Rerunning with two chains
ans1 <- MCMC(
  initial = ans0,
  fun = logpost4,    # This version stores user data
  nsteps = 1000,
  kernel = kern,     # Reusing the kernel
  thin = 10,         # We are adding thinning
  nchains = 2L,      # Two chains, one set of user data each
  multicore = FALSE,
  seed = 555,
  progress = FALSE,
  D = lifeexpect
)
```
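Since `nchains == 2`, `MCMC()` stores the user data as a list of length two, with one data frame per chain. Thinning applies to the user data as well: 1,000 steps with `thin = 10` leave 100 rows per chain. To retrieve the data, we can call `get_userdata()`. We can also inspect it, along with the rest of the output of the last call, by printing `MCMC_OUTPUT`: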
```r
print(MCMC_OUTPUT)
```
```
## Last call to MCMC holds the following elements:
## Chain N: 1
## draws : num [1:100, 1:4] 80.5 80 80.3 80.4 80 ...
## logpost : Named num [1:100] -2151 -2150 -2142 -2142 -2122 ...
## Chain N: 2
## draws : num [1:100, 1:4] 80.1 80.2 80.3 79.9 79.9 ...
## logpost : Named num [1:100] -2127 -2125 -2124 -2122 -2125 ...
##
## Including the following userdata (use -get_userdata()- to access it):
## Chain N 1
##          p1         p2       p3
## 10 80.49477  -9.387659 4.783305
## 20 79.97738  -9.806571 4.881653
## 30 80.33199 -10.182043 5.669943
## 40 80.39235  -9.788603 5.192905
## 50 79.98730  -9.608717 5.231799
## 60 80.60093  -9.843509 5.278163
## ... 94 more...
## Chain N 2
##          p1         p2       p3
## 10 80.13857  -9.772963 5.038519
## 20 80.18688 -10.174983 5.463652
## 30 80.32215  -9.929603 5.138446
## 40 79.93485  -9.581106 5.514230
## 50 79.90929  -9.502281 5.229895
## 60 80.04301  -9.991739 5.558207
## ... 94 more...
```
```r
str(get_userdata())
```
```
## List of 2
## $ :'data.frame': 100 obs. of 3 variables:
## ..$ p1: num [1:100] 80.5 80 80.3 80.4 80 ...
## ..$ p2: num [1:100] -9.39 -9.81 -10.18 -9.79 -9.61 ...
## ..$ p3: num [1:100] 4.78 4.88 5.67 5.19 5.23 ...
## $ :'data.frame': 100 obs. of 3 variables:
## ..$ p1: num [1:100] 80.1 80.2 80.3 79.9 79.9 ...
## ..$ p2: num [1:100] -9.77 -10.17 -9.93 -9.58 -9.5 ...
## ..$ p3: num [1:100] 5.04 5.46 5.14 5.51 5.23 ...
```
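When the per-chain data frames are needed as one table (for example, to summarize or plot them by chain), a small base-R helper can stack them. `stack_userdata()` is a hypothetical name, not an fmcmc function. It assumes `get_userdata()` returns a list of data frames when there are several chains, as the `str()` output above shows; wrapping a lone data frame in a list covers the single-chain case, whose return type is an assumption here.

```r
# Hypothetical helper (not part of fmcmc): stack per-chain user data into a
# single data frame with a `chain` column identifying the source chain.
stack_userdata <- function(ud) {
  # Assumption: with a single chain, a data frame may come back directly
  if (is.data.frame(ud)) ud <- list(ud)
  pieces <- Map(
    function(d, k) data.frame(chain = k, d, row.names = NULL),
    ud, seq_along(ud)
  )
  do.call(rbind, pieces)
}

# Usage after the run above (requires fmcmc):
# ud_all <- stack_userdata(get_userdata())
# table(ud_all$chain)
```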