Prerequisites

We start by loading the required packages and setting up the parameters for our optimization.

library(latex2exp)
library(foreach)
library(doParallel)
## Loading required package: iterators
## Loading required package: parallel
# Silence all warnings for the remainder of the session.
# NOTE(review): warn=-1 also hides deprecation and numerical warnings raised
# by the code below -- consider removing it while debugging.
options(warn=-1)
# Disabled alternative: take the worker count from the first command-line
# argument and keep one core free. A fixed single worker is used instead.
# numCores <- commandArgs(trailingOnly=TRUE)[1]
# numCores <- as.numeric(numCores) - 1
numCores = 1
# Register the parallel backend used by the %dopar% loops in this file.
registerDoParallel(cores=numCores)
# print(paste('number of cores is ',numCores, ' for YV'))

library(doRNG)
## Loading required package: rngtools
registerDoRNG(625904618)
library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.1 ──
## ✓ ggplot2 3.3.5     ✓ purrr   0.3.4
## ✓ tibble  3.1.5     ✓ dplyr   1.0.7
## ✓ tidyr   1.1.4     ✓ stringr 1.4.0
## ✓ readr   1.4.0     ✓ forcats 0.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## x purrr::accumulate() masks foreach::accumulate()
## x dplyr::filter()     masks stats::filter()
## x dplyr::lag()        masks stats::lag()
## x purrr::when()       masks foreach::when()
library(pomp)
## 
## Attaching package: 'pomp'
## The following object is masked from 'package:purrr':
## 
##     map
options(stringsAsFactors=FALSE)
stopifnot(packageVersion("pomp")>="3.0")

# Three run levels trade accuracy for speed: level 1 can be used for quick
# diagnostics and level 3 is used for the actual optimizations.
run_level <- 3

# switch() silently returns NULL for an index outside 1..3, which would only
# surface later as an obscure error, so validate the level up front.
stopifnot(length(run_level) == 1, run_level %in% 1:3)

# One value per run level (1, 2, 3).
# NOTE(review): none of these settings is referenced in the visible part of
# this file (the particle filters below hard-code Np=5000); presumably they
# are consumed by the optimization code -- confirm.
covid_Np           <- switch(run_level, 100, 1e3, 2e4)
covid_Nmif         <- switch(run_level,  10, 100, 100)
covid_Nreps_eval   <- switch(run_level,   2,  10,  10)
covid_Nreps_local  <- switch(run_level,  10,  10,  10)
covid_Nreps_global <- switch(run_level,  10,  20,  20)
covid_Nsim         <- switch(run_level,  50, 100, 100)

Model Components Specifications

We now define the pomp model components including process model, measurement model, initial states, parameters and parameters' transformations.

Process Model

#########################################################################
#--------------------------|  rproc  |----------------------------------#
#########################################################################
# One Euler step of the stochastic SEIR process model (pomp C snippet).
# Reads states S, E, I, R and parameters b0, alpha, iota, sigmaSE, br,
# mu_EI, mu_IR; `dt` is the step length.
#  * force of infection: b0 * (I + iota)^alpha / (S + E + I + R), multiplied
#    by gamma white noise of intensity sigmaSE (rgammawn);
#  * the S->E, E->I and I->R moves are binomial draws with probability
#    1 - exp(-rate * dt);
#  * Poisson(br * dt) births are added to S;
#  * W accumulates the standardized noise increment (dw - dt) / sigmaSE.
# The order of the random draws (rgammawn, rpois, then three rbinom) fixes the
# simulated trajectory for a given seed, so do not reorder these statements.
# NOTE(review): the covariates sdmm/event attached to the pomp objects are not
# referenced here, so beta is constant in time -- confirm this is intended.
rproc <- Csnippet("
                  double beta, foi, dw, births, mu_SE;
                  
                  //we only consider those that participate in the epidemic:
                  double pop = S + E + I + R;
                  
                  // transmission rate
                  beta = b0;
                  
                  // expected force of infection. iota: imported infections
                  // alpha mixing parameter, = 1:homogeneous mixing
                  foi = beta*pow(I+iota, alpha)/pop;
                  
                  // white noise (extrademographic stochasticity)
                  dw = rgammawn(sigmaSE,dt);
                  
                  mu_SE = foi*dw/dt;  // stochastic force of infection
                  
                  // Poisson births: fraction of leak into S from N
                  births = rpois(br*dt);
                  
                  
                  // State Updates:
                  double dN_SE  = rbinom(S , 1-exp(-mu_SE  *dt));
                  double dN_EI  = rbinom(E , 1-exp(-mu_EI  *dt));
                  double dN_IR  = rbinom(I , 1-exp(-mu_IR *dt));
                  S += births - dN_SE;
                  E += dN_SE  - dN_EI;
                  I += dN_EI  - dN_IR;
                  R += dN_IR;
                  W += (dw - dt)/sigmaSE;  // standardized i.i.d. white noise
                  ")

Initial Distribution

#########################################################################
#--------------------------|  rinit  |----------------------------------#
#########################################################################
# Initial-state snippet (pomp C snippet): each compartment starts at its
# initial fraction (S_0, E_0, I_0, R_0) scaled by eta*N and rounded to the
# nearest integer; the noise accumulator W starts at 0.
rinit <- Csnippet("
                  double m = eta*N;
                  S = nearbyint(m*S_0);
                  E = nearbyint(m*E_0);
                  I = nearbyint(m*I_0);
                  R = nearbyint(m*R_0);
                  W = 0;
                  ")

Measurement Model

SEIR-VY

#########################################################################
#--------------------------|  dmeas  |----------------------------------#
#########################################################################
# Measurement density for the joint model (pomp C snippet).
#  * V is Gaussian around rho_V * (E + I) with standard deviation sd_V.
#  * Y is a discretized Gaussian around rho_Y * I with variance
#    rho_Y * I * (1 + od_Y): P(Y) = F(Y + 0.5) - F(Y - 0.5), or F(0.5) when
#    Y is not positive.
# The two log-densities are summed, i.e. V and Y are treated as independent
# given the state. `lik` is returned on the log scale when give_log is set.
dmeas_VY <- Csnippet("
                  // Model for Viral Load
                  double shed_cases = E + I;
                  double mu_V = rho_V*shed_cases;
                  double lik_V = dnorm(V, mu_V, sd_V, 1);

                  // Model for Case Counts
                  double mu_Y = rho_Y*I;
                  double std_Y = sqrt(mu_Y*(1+od_Y));
                  double lik_Y;

                  // The interval mass must be formed on the natural scale
                  // (log_p = 0) and logged afterwards: subtracting two
                  // log-CDFs gives the log of their ratio, not of the mass.
                  if (Y > 0.0) {
                  lik_Y = log(pnorm(Y+0.5,mu_Y,std_Y,1,0)
                  - pnorm(Y-0.5,mu_Y,std_Y,1,0));
                  } else {
                  lik_Y = pnorm(Y+0.5,mu_Y,std_Y,1,1);
                  }

                  // Combined log-likelihood
                  lik = lik_V + lik_Y;
                  lik = (give_log) ? lik : exp(lik);
                  ")

#########################################################################
#--------------------------|  rmeas  |----------------------------------#
#########################################################################
# Measurement simulator for the joint model (pomp C snippet).
#  * V is drawn from a Gaussian with mean rho_V * (E + I) and sd sd_V.
#  * Y is drawn from a Gaussian with mean rho_Y * I and variance
#    rho_Y * I * (1 + od_Y), rounded to the nearest integer; draws that are
#    not positive are set to 0.
# V is drawn before Y; the draw order matters for seed reproducibility.
rmeas_VY <- Csnippet("
                  // Viral Load
                  double shed_cases = E + I;
                  double mu_V = rho_V*shed_cases;
                  //double std_V = sqrt(mu_V*(1+od_V));
                  V = rnorm(mu_V, sd_V);
                  
                  // Case Counts
                  double mu_Y = rho_Y*I;
                  double std_Y = sqrt(mu_Y*(1+od_Y));
                  Y = rnorm(mu_Y, std_Y);
                  if (Y > 0.0) {
                  Y = nearbyint(Y);
                  } else {
                  Y = 0.0;
                  }
                  
                  ")

SEIR-V

#########################################################################
#--------------------------|  dmeas  |----------------------------------#
#########################################################################
# Measurement density for the viral-load-only model (pomp C snippet).
# Only V contributes to the likelihood: V is Gaussian around rho_V * (E + I)
# with standard deviation sd_V. The case-count density that used to be
# computed here was never added to `lik`, so it has been removed.
# `lik` is returned on the log scale when give_log is set.
dmeas_V <- Csnippet("
                  // Model for Viral Load
                  double shed_cases = E + I;
                  double mu_V = rho_V*shed_cases;

                  // Likelihood of the viral-load observation only
                  lik = dnorm(V, mu_V, sd_V, 1);
                  lik = (give_log) ? lik : exp(lik);
                  ")

#########################################################################
#--------------------------|  rmeas  |----------------------------------#
#########################################################################
# Measurement simulator attached to the viral-load-only model (pomp C
# snippet). Its body matches rmeas_VY and rmeas_Y: both observables are
# simulated even though only V enters this model's likelihood.
#  * V is drawn from a Gaussian with mean rho_V * (E + I) and sd sd_V.
#  * Y is drawn from a Gaussian with mean rho_Y * I and variance
#    rho_Y * I * (1 + od_Y), rounded to the nearest integer; draws that are
#    not positive are set to 0.
rmeas_V <- Csnippet("
                  // Viral Load
                  double shed_cases = E + I;
                  double mu_V = rho_V*shed_cases;
                  //double std_V = sqrt(mu_V*(1+od_V));
                  V = rnorm(mu_V, sd_V);
                  
                  // Case Counts
                  double mu_Y = rho_Y*I;
                  double std_Y = sqrt(mu_Y*(1+od_Y));
                  Y = rnorm(mu_Y, std_Y);
                  if (Y > 0.0) {
                  Y = nearbyint(Y);
                  } else {
                  Y = 0.0;
                  }
                  
                  ")

SEIR-Y

#########################################################################
#--------------------------|  dmeas  |----------------------------------#
#########################################################################
# Measurement density for the case-count-only model (pomp C snippet).
# Only Y contributes to the likelihood: Y is a discretized Gaussian around
# rho_Y * I with variance rho_Y * I * (1 + od_Y), i.e.
# P(Y) = F(Y + 0.5) - F(Y - 0.5), or F(0.5) when Y is not positive.
# The viral-load density that used to be computed here was never added to
# `lik`, so it has been removed.
# `lik` is returned on the log scale when give_log is set.
dmeas_Y <- Csnippet("
                  // Model for Case Counts
                  double mu_Y = rho_Y*I;
                  double std_Y = sqrt(mu_Y*(1+od_Y));

                  // The interval mass must be formed on the natural scale
                  // (log_p = 0) and logged afterwards: subtracting two
                  // log-CDFs gives the log of their ratio, not of the mass.
                  if (Y > 0.0) {
                  lik = log(pnorm(Y+0.5,mu_Y,std_Y,1,0)
                  - pnorm(Y-0.5,mu_Y,std_Y,1,0));
                  } else {
                  lik = pnorm(Y+0.5,mu_Y,std_Y,1,1);
                  }

                  lik = (give_log) ? lik : exp(lik);
                  ")

#########################################################################
#--------------------------|  rmeas  |----------------------------------#
#########################################################################
# Measurement simulator attached to the case-count-only model (pomp C
# snippet). Its body matches rmeas_VY and rmeas_V: both observables are
# simulated even though only Y enters this model's likelihood.
#  * V is drawn from a Gaussian with mean rho_V * (E + I) and sd sd_V.
#  * Y is drawn from a Gaussian with mean rho_Y * I and variance
#    rho_Y * I * (1 + od_Y), rounded to the nearest integer; draws that are
#    not positive are set to 0.
rmeas_Y <- Csnippet("
                  // Viral Load
                  double shed_cases = E + I;
                  double mu_V = rho_V*shed_cases;
                  //double std_V = sqrt(mu_V*(1+od_V));
                  V = rnorm(mu_V, sd_V);
                  
                  // Case Counts
                  double mu_Y = rho_Y*I;
                  double std_Y = sqrt(mu_Y*(1+od_Y));
                  Y = rnorm(mu_Y, std_Y);
                  if (Y > 0.0) {
                  Y = nearbyint(Y);
                  } else {
                  Y = 0.0;
                  }
                  
                  ")

Parameters

#########################################################################
#-------------------------|  Parameters  |------------------------------#
#########################################################################
# Names of all model parameters exposed to the C snippets.
parameters <- c(
  "b0", "alpha", "iota",           # force of infection
  "sigmaSE",                       # intensity of the gamma white noise
  "br",                            # Poisson birth rate into S
  "mu_EI", "mu_IR",                # E->I and I->R transition rates
  "N",                             # population size used by rinit
  "eta",                           # multiplies N in rinit
  "rho_V", "sd_V",                 # viral-load measurement model
  "rho_Y", "od_Y",                 # case-count measurement model
  "S_0", "E_0", "I_0", "R_0"       # initial compartment fractions
)

# Transformations applied for estimation: log, logit, and barycentric for the
# four initial fractions. N is not transformed.
par_trans <- parameter_trans(
  log = c("b0", "alpha", "iota", "sigmaSE", "br", "rho_V", "sd_V", "od_Y"),
  logit = c("mu_EI", "mu_IR", "eta", "rho_Y"),
  barycentric = c("S_0", "E_0", "I_0", "R_0")
)

# Latent state variables; W accumulates the process noise.
states <- c("S", "E", "I", "R", "W")

Data

Here we read our data. We then shift the case counts by a reporting delay of 5 days (as discussed in the paper).

# Read the data set from the project's GitHub repository; the parsed column
# types are listed in the output below.
data <- read_csv(
  "https://raw.githubusercontent.com/Shakeri-Lab/COVID-SEIR/main/Data/abm.csv"
)
## 
## ── Column specification ────────────────────────────────────────────────────────
## cols(
##   dates = col_date(format = ""),
##   VAX_count = col_double(),
##   day = col_double(),
##   sdm = col_double(),
##   events = col_double(),
##   I_1 = col_double(),
##   I_2 = col_double(),
##   I_3 = col_double(),
##   Y_1 = col_double(),
##   Y_2 = col_double(),
##   Y_3 = col_double(),
##   V_1 = col_double(),
##   V_2 = col_double(),
##   V_3 = col_double(),
##   Infected = col_double(),
##   Y = col_double(),
##   V = col_double(),
##   logV = col_double()
## )
#########################################################################
#-------------------------|  Covariates  |------------------------------#
#########################################################################

# Covariate table indexed by day, held constant between time points. It is
# built from the full data set, before the case counts are shifted and the
# rows are restricted to the first 70.
# NOTE(review): neither sdmm nor event is referenced by the C snippets
# visible in this file -- confirm whether the covariates are still needed.
sdm_covar <- covariate_table(
  t     = data[["day"]],
  sdmm  = data[["sdm"]],
  event = data[["events"]],
  times = "t",
  order = "constant"
)

# Shift every case-count series forward by the assumed reporting delay, so
# that the value stored at row t is the count found rep_del rows later.
# The last rep_del rows of each shifted column become NA.
rep_del <- 5    # reporting delay

# A single across() replaces four separate (superseded) mutate_at() calls;
# the shifted columns keep their names and positions.
data_c <- data %>%
  mutate(
    across(
      all_of(c("Y_1", "Y_2", "Y_3", "Y")),
      ~ dplyr::lead(.x, n = rep_del)
    )
  )

# focusing on the first peak for now
data <- data_c[1:70, ]

pomp Model

Now that we have the model components along with the data, we can define our pomp model.

SEIR-VY

# pomp object for the joint model: the likelihood uses both V and Y.
covidSEIRsR_VY <- data %>%
  select(-logV) %>%
  pomp(
    times      = "day",   # data column holding the observation times
    t0         = 0,       # time at which the initial state is set
    # Euler scheme with six steps per time unit, i.e. one step every four
    # hours when `day` is measured in days.
    rprocess   = euler(rproc, delta.t = 1/6),
    rinit      = rinit,
    rmeasure   = rmeas_VY,
    dmeasure   = dmeas_VY,
    accumvars  = c("W"),
    partrans   = par_trans,
    statenames = states,
    paramnames = parameters,
    covar      = sdm_covar
  )

SEIR-V

# pomp object for the viral-load-only model: the likelihood uses V alone.
covidSEIRsR_V <- data %>%
  select(-logV) %>%
  pomp(
    times      = "day",   # data column holding the observation times
    t0         = 0,       # time at which the initial state is set
    # Euler scheme with six steps per time unit, i.e. one step every four
    # hours when `day` is measured in days.
    rprocess   = euler(rproc, delta.t = 1/6),
    rinit      = rinit,
    rmeasure   = rmeas_V,
    dmeasure   = dmeas_V,
    accumvars  = c("W"),
    partrans   = par_trans,
    statenames = states,
    paramnames = parameters,
    covar      = sdm_covar
  )

SEIR-Y

# pomp object for the case-count-only model: the likelihood uses Y alone.
covidSEIRsR_Y <- data %>%
  select(-logV) %>%
  pomp(
    times      = "day",   # data column holding the observation times
    t0         = 0,       # time at which the initial state is set
    # Euler scheme with six steps per time unit, i.e. one step every four
    # hours when `day` is measured in days.
    rprocess   = euler(rproc, delta.t = 1/6),
    rinit      = rinit,
    rmeasure   = rmeas_Y,
    dmeasure   = dmeas_Y,
    accumvars  = c("W"),
    partrans   = par_trans,
    statenames = states,
    paramnames = parameters,
    covar      = sdm_covar
  )

Simulations

As a quick check of the model we run some simulations with a set of guess parameters.

SEIR-VY

# Guess parameters for a quick visual check of the joint (V & Y) model.
params_guess <- c(
  b0 = 0.013, alpha = 1.62, iota = 5,
  sigmaSE = 0.8,
  br = 2,
  mu_EI = .16, mu_IR = 0.13,          # state transition
  rho_V = 150, sd_V = 1000,           # measurement V
  rho_Y = .14, od_Y = 0,              # measurement Y
  eta = .05, N = 50000,               # initial value parameters
  S_0 = .95, E_0 = .04, I_0 = .01, R_0 = .0
)

# 250 simulated trajectories, returned as one long data frame (.id = run).
y <- covidSEIRsR_VY %>%
  simulate(params = params_guess, nsim = 250, format = "data.frame")

# Per-day mean over the simulations (across() replaces summarize_at()).
y_avg <- y %>%
  group_by(day) %>%
  summarize(across(c(S:R, V, Y), mean))

# Observed series; the case counts are divided by the guessed rho_Y.
observed <- data %>%
  mutate(actual.cases = Y / params_guess['rho_Y']) %>%
  select(day, V = V, Y = actual.cases) %>%
  pivot_longer(c(V, Y))

# Simulations (thin coloured lines), their mean (blue) and the data (black),
# one panel per observable.
y %>%
  pivot_longer(c(V, Y)) %>%
  ggplot(aes(x = day, y = value)) +
  geom_line(aes(color = factor(.id))) +
  geom_line(data = y_avg %>% pivot_longer(c(V, Y)),
            size = 2, color = "blue") +
  geom_line(data = observed, color = "black", size = 2) +
  scale_color_brewer(type = 'qual', palette = 3) +
  # "none" replaces the deprecated guides(color = FALSE)
  guides(color = "none") +
  facet_wrap(~name, scales = "free_y")

SEIR-V

# Guess parameters for a quick visual check of the viral-load-only model.
params_guess <- c(
  b0 = 0.013, alpha = 1.62, iota = 5,
  sigmaSE = 0.8,
  br = 2,
  mu_EI = .16, mu_IR = 0.13,          # state transition
  rho_V = 150, sd_V = 1000,           # measurement V
  rho_Y = .14, od_Y = 0,              # measurement Y
  eta = .05, N = 50000,               # initial value parameters
  S_0 = .95, E_0 = .04, I_0 = .01, R_0 = .0
)

# 250 simulated trajectories, returned as one long data frame (.id = run).
y <- covidSEIRsR_V %>%
  simulate(params = params_guess, nsim = 250, format = "data.frame")

# Per-day mean over the simulations (across() replaces summarize_at()).
y_avg <- y %>%
  group_by(day) %>%
  summarize(across(c(S:R, V, Y), mean))

# Observed series; the case counts are divided by the guessed rho_Y.
observed <- data %>%
  mutate(actual.cases = Y / params_guess['rho_Y']) %>%
  select(day, V = V, Y = actual.cases) %>%
  pivot_longer(c(V, Y))

# Simulations (thin coloured lines), their mean (blue) and the data (black),
# one panel per observable.
y %>%
  pivot_longer(c(V, Y)) %>%
  ggplot(aes(x = day, y = value)) +
  geom_line(aes(color = factor(.id))) +
  geom_line(data = y_avg %>% pivot_longer(c(V, Y)),
            size = 2, color = "blue") +
  geom_line(data = observed, color = "black", size = 2) +
  scale_color_brewer(type = 'qual', palette = 3) +
  # "none" replaces the deprecated guides(color = FALSE)
  guides(color = "none") +
  facet_wrap(~name, scales = "free_y")

SEIR-Y

# Guess parameters for a quick visual check of the case-count-only model.
params_guess <- c(
  b0 = 0.013, alpha = 1.62, iota = 5,
  sigmaSE = 0.8,
  br = 2,
  mu_EI = .16, mu_IR = 0.13,          # state transition
  rho_V = 150, sd_V = 1000,           # measurement V
  rho_Y = .14, od_Y = 0,              # measurement Y
  eta = .05, N = 50000,               # initial value parameters
  S_0 = .95, E_0 = .04, I_0 = .01, R_0 = .0
)

# 250 simulated trajectories, returned as one long data frame (.id = run).
y <- covidSEIRsR_Y %>%
  simulate(params = params_guess, nsim = 250, format = "data.frame")

# Per-day mean over the simulations (across() replaces summarize_at()).
y_avg <- y %>%
  group_by(day) %>%
  summarize(across(c(S:R, V, Y), mean))

# Observed series; the case counts are divided by the guessed rho_Y.
observed <- data %>%
  mutate(actual.cases = Y / params_guess['rho_Y']) %>%
  select(day, V = V, Y = actual.cases) %>%
  pivot_longer(c(V, Y))

# Simulations (thin coloured lines), their mean (blue) and the data (black),
# one panel per observable.
y %>%
  pivot_longer(c(V, Y)) %>%
  ggplot(aes(x = day, y = value)) +
  geom_line(aes(color = factor(.id))) +
  geom_line(data = y_avg %>% pivot_longer(c(V, Y)),
            size = 2, color = "blue") +
  geom_line(data = observed, color = "black", size = 2) +
  scale_color_brewer(type = 'qual', palette = 3) +
  # "none" replaces the deprecated guides(color = FALSE)
  guides(color = "none") +
  facet_wrap(~name, scales = "free_y")

Particle Filtering

As a secondary check, we apply the particle filter and make sure that our model is capable of filtering.

SEIR-VY

#########################################################################
#----------------------|  Particle Filtering  |-------------------------#
#########################################################################

# Start of the timed region.
# NOTE(review): the elapsed time (toc - tic) is never reported in the visible
# part of the file.
tic <- Sys.time()

# Ten replicate particle filters (5000 particles each) of the joint model at
# the guess parameters, combined with c().
pf <- foreach(i=1:10, .combine=c) %dopar% {
  library(pomp)
  pfilter(covidSEIRsR_VY, params=params_guess, Np=5000)
}

# Log of the mean likelihood across the replicates, with its standard error.
L_pf <- logmeanexp(logLik(pf), se=TRUE)
L_pf
##                        se 
## -846.5972127    0.5133548
toc <- Sys.time()

SEIR-V

#########################################################################
#----------------------|  Particle Filtering  |-------------------------#
#########################################################################

# Start of the timed region.
# NOTE(review): the elapsed time (toc - tic) is never reported in the visible
# part of the file.
tic <- Sys.time()

# Ten replicate particle filters (5000 particles each) of the viral-load-only
# model at the guess parameters, combined with c().
pf <- foreach(i=1:10, .combine=c) %dopar% {
  library(pomp)
  pfilter(covidSEIRsR_V, params=params_guess, Np=5000)
}

# Log of the mean likelihood across the replicates, with its standard error.
L_pf <- logmeanexp(logLik(pf), se=TRUE)
L_pf
##                    se 
## -799.66973    1.86635
toc <- Sys.time()

SEIR-Y

#########################################################################
#----------------------|  Particle Filtering  |-------------------------#
#########################################################################

# Start of the timed region.
# NOTE(review): the elapsed time (toc - tic) is never reported in the visible
# part of the file.
tic <- Sys.time()

# Ten replicate particle filters (5000 particles each) of the case-count-only
# model at the guess parameters, combined with c().
pf <- foreach(i=1:10, .combine=c) %dopar% {
  library(pomp)
  pfilter(covidSEIRsR_Y, params=params_guess, Np=5000)
}

# Log of the mean likelihood across the replicates, with its standard error.
L_pf <- logmeanexp(logLik(pf), se=TRUE)
L_pf
##                      se 
## -16.7722685   0.1589747
toc <- Sys.time()

Results

We can check out all of the parameters explored...

SEIR-VY

# Scatterplot matrix of the log-likelihood against selected parameters.
# NOTE(review): res_VY is not defined in the visible part of this file;
# presumably it holds the optimization results -- confirm where it is loaded.
pairs(
  ~ loglik + b0 + alpha + iota + br + rho_V + sd_V,
  data = res_VY,
  pch = 16
)

SEIR-V

# Scatterplot matrix of the log-likelihood against selected parameters.
# NOTE(review): res_V is not defined in the visible part of this file;
# presumably it holds the optimization results -- confirm where it is loaded.
pairs(
  ~ loglik + b0 + alpha + iota + br + sigmaSE,
  data = res_V,
  pch = 16
)

SEIR-Y

# Scatterplot matrix of the log-likelihood against selected parameters.
# NOTE(review): res_Y is not defined in the visible part of this file;
# presumably it holds the optimization results -- confirm where it is loaded.
pairs(
  ~ loglik + b0 + alpha + iota + br + sigmaSE,
  data = res_Y,
  pch = 16
)

Poorman's profiles

We can plot the poor man's profile for different parameters to get a sense of the MLE values for the parameters. We show the poor man's profiles for three of the parameters, \(\sigma_{SE}\), \(\alpha\), and \(\beta_0\), for each of the three models. The other parameters can be investigated in the same way.

SEIR-VY

\(\sigma_{SE}\)

# Poor man's profile over sigmaSE for res_VY: keep the fits within 10
# log-likelihood units of the maximum, bin sigmaSE to two decimals, and plot
# the top-ranked fit of each bin.
res_VY %>%
  filter(loglik > max(loglik) - 10) %>%
  group_by(cut = round(sigmaSE, 2)) %>%
  filter(rank(-loglik) < 2) %>%
  ggplot(aes(x = sigmaSE, y = loglik)) +
  geom_point() +
  labs(
    title = TeX("Log-lik vs.$\\sigma_{SE}$"),
    x = TeX("$\\sigma_{SE}$")
  ) +
  theme(
    legend.title = element_blank(),
    legend.text = element_text(size = 10, family = "Times New Roman"),
    axis.text = element_text(size = 12, family = "Times New Roman"),
    axis.title.y = element_text(size = 15, family = "Times New Roman"),
    axis.title.x = element_text(size = 15, family = "Times New Roman"),
    plot.title = element_text(
      size = 18, face = "bold", family = "Times New Roman", hjust = 0.5
    )
  )

\(\alpha\)

# Poor man's profile over alpha for res_VY: keep the fits within 5
# log-likelihood units of the maximum, bin alpha to two decimals, and plot
# the top-ranked fit of each bin.
res_VY %>%
  filter(loglik > max(loglik) - 5) %>%
  group_by(cut = round(alpha, 2)) %>%
  filter(rank(-loglik) < 2) %>%
  ggplot(aes(x = alpha, y = loglik)) +
  geom_point() +
  labs(
    title = TeX("Log-lik vs.$\\alpha$"),
    x = TeX("$\\alpha$")
  ) +
  theme(
    legend.title = element_blank(),
    legend.text = element_text(size = 10, family = "Times New Roman"),
    axis.text = element_text(size = 12, family = "Times New Roman"),
    axis.title.y = element_text(size = 15, family = "Times New Roman"),
    axis.title.x = element_text(size = 15, family = "Times New Roman"),
    plot.title = element_text(
      size = 18, face = "bold", family = "Times New Roman", hjust = 0.5
    )
  )

\(\beta_0\)

# Poor man's profile over b0 for res_VY: keep the fits within 3
# log-likelihood units of the maximum, bin b0 to two decimals, and plot the
# top-ranked fit of each bin.
res_VY %>%
  filter(loglik > max(loglik) - 3) %>%
  group_by(cut = round(b0, 2)) %>%
  filter(rank(-loglik) < 2) %>%
  ggplot(aes(x = b0, y = loglik)) +
  geom_point() +
  labs(
    title = TeX("Log-lik vs.$\\beta_0$"),
    x = TeX("$\\beta_0$")
  ) +
  theme(
    legend.title = element_blank(),
    legend.text = element_text(size = 10, family = "Times New Roman"),
    axis.text = element_text(size = 12, family = "Times New Roman"),
    axis.title.y = element_text(size = 15, family = "Times New Roman"),
    axis.title.x = element_text(size = 15, family = "Times New Roman"),
    plot.title = element_text(
      size = 18, face = "bold", family = "Times New Roman", hjust = 0.5
    )
  )

SEIR-V

\(\sigma_{SE}\)

# Poor man's profile over sigmaSE for res_V: keep the fits within 10
# log-likelihood units of the maximum, bin sigmaSE to two decimals, and plot
# the top-ranked fit of each bin.
res_V %>%
  filter(loglik > max(loglik) - 10) %>%
  group_by(cut = round(sigmaSE, 2)) %>%
  filter(rank(-loglik) < 2) %>%
  ggplot(aes(x = sigmaSE, y = loglik)) +
  geom_point() +
  labs(
    title = TeX("Log-lik vs.$\\sigma_{SE}$"),
    x = TeX("$\\sigma_{SE}$")
  ) +
  theme(
    legend.title = element_blank(),
    legend.text = element_text(size = 10, family = "Times New Roman"),
    axis.text = element_text(size = 12, family = "Times New Roman"),
    axis.title.y = element_text(size = 15, family = "Times New Roman"),
    axis.title.x = element_text(size = 15, family = "Times New Roman"),
    plot.title = element_text(
      size = 18, face = "bold", family = "Times New Roman", hjust = 0.5
    )
  )

\(\alpha\)

# Poor man's profile over alpha for res_V: keep the fits within 5
# log-likelihood units of the maximum, bin alpha to two decimals, and plot
# the top-ranked fit of each bin.
res_V %>%
  filter(loglik > max(loglik) - 5) %>%
  group_by(cut = round(alpha, 2)) %>%
  filter(rank(-loglik) < 2) %>%
  ggplot(aes(x = alpha, y = loglik)) +
  geom_point() +
  labs(
    title = TeX("Log-lik vs.$\\alpha$"),
    x = TeX("$\\alpha$")
  ) +
  theme(
    legend.title = element_blank(),
    legend.text = element_text(size = 10, family = "Times New Roman"),
    axis.text = element_text(size = 12, family = "Times New Roman"),
    axis.title.y = element_text(size = 15, family = "Times New Roman"),
    axis.title.x = element_text(size = 15, family = "Times New Roman"),
    plot.title = element_text(
      size = 18, face = "bold", family = "Times New Roman", hjust = 0.5
    )
  )

\(\beta_0\)

# Poor man's profile over b0 for res_V: keep the fits within 3
# log-likelihood units of the maximum, bin b0 to two decimals, and plot the
# top-ranked fit of each bin.
res_V %>%
  filter(loglik > max(loglik) - 3) %>%
  group_by(cut = round(b0, 2)) %>%
  filter(rank(-loglik) < 2) %>%
  ggplot(aes(x = b0, y = loglik)) +
  geom_point() +
  labs(
    title = TeX("Log-lik vs.$\\beta_0$"),
    x = TeX("$\\beta_0$")
  ) +
  theme(
    legend.title = element_blank(),
    legend.text = element_text(size = 10, family = "Times New Roman"),
    axis.text = element_text(size = 12, family = "Times New Roman"),
    axis.title.y = element_text(size = 15, family = "Times New Roman"),
    axis.title.x = element_text(size = 15, family = "Times New Roman"),
    plot.title = element_text(
      size = 18, face = "bold", family = "Times New Roman", hjust = 0.5
    )
  )

SEIR-Y

\(\sigma_{SE}\)

# Poor man's profile over sigmaSE for res_Y: keep the fits within 10
# log-likelihood units of the maximum, bin sigmaSE to two decimals, and plot
# the top-ranked fit of each bin.
res_Y %>%
  filter(loglik > max(loglik) - 10) %>%
  group_by(cut = round(sigmaSE, 2)) %>%
  filter(rank(-loglik) < 2) %>%
  ggplot(aes(x = sigmaSE, y = loglik)) +
  geom_point() +
  labs(
    title = TeX("Log-lik vs.$\\sigma_{SE}$"),
    x = TeX("$\\sigma_{SE}$")
  ) +
  theme(
    legend.title = element_blank(),
    legend.text = element_text(size = 10, family = "Times New Roman"),
    axis.text = element_text(size = 12, family = "Times New Roman"),
    axis.title.y = element_text(size = 15, family = "Times New Roman"),
    axis.title.x = element_text(size = 15, family = "Times New Roman"),
    plot.title = element_text(
      size = 18, face = "bold", family = "Times New Roman", hjust = 0.5
    )
  )

\(\alpha\)

# Poor man's profile over alpha for res_Y: keep the fits within 5
# log-likelihood units of the maximum, bin alpha to two decimals, and plot
# the top-ranked fit of each bin.
res_Y %>%
  filter(loglik > max(loglik) - 5) %>%
  group_by(cut = round(alpha, 2)) %>%
  filter(rank(-loglik) < 2) %>%
  ggplot(aes(x = alpha, y = loglik)) +
  geom_point() +
  labs(
    title = TeX("Log-lik vs.$\\alpha$"),
    x = TeX("$\\alpha$")
  ) +
  theme(
    legend.title = element_blank(),
    legend.text = element_text(size = 10, family = "Times New Roman"),
    axis.text = element_text(size = 12, family = "Times New Roman"),
    axis.title.y = element_text(size = 15, family = "Times New Roman"),
    axis.title.x = element_text(size = 15, family = "Times New Roman"),
    plot.title = element_text(
      size = 18, face = "bold", family = "Times New Roman", hjust = 0.5
    )
  )

\(\beta_0\)

# Poor man's profile over b0 for res_Y: keep the fits within 3
# log-likelihood units of the maximum, bin b0 to two decimals, and plot the
# top-ranked fit of each bin.
res_Y %>%
  filter(loglik > max(loglik) - 3) %>%
  group_by(cut = round(b0, 2)) %>%
  filter(rank(-loglik) < 2) %>%
  ggplot(aes(x = b0, y = loglik)) +
  geom_point() +
  labs(
    title = TeX("Log-lik vs.$\\beta_0$"),
    x = TeX("$\\beta_0$")
  ) +
  theme(
    legend.title = element_blank(),
    legend.text = element_text(size = 10, family = "Times New Roman"),
    axis.text = element_text(size = 12, family = "Times New Roman"),
    axis.title.y = element_text(size = 15, family = "Times New Roman"),
    axis.title.x = element_text(size = 15, family = "Times New Roman"),
    plot.title = element_text(
      size = 18, face = "bold", family = "Times New Roman", hjust = 0.5
    )
  )

Simulation comparison

# Calibration and prediction windows of the delay-shifted data.
# NOTE(review): row 70 belongs to both windows -- confirm whether the
# prediction window should start at row 71.
data_cal  <- data_c[1:70, ]
data_pred <- data_c[70:100, ]

# Stack the three result tables, labelling each row with its model.
res <- bind_rows("V & Y" = res_VY, "V" = res_V, "Y" = res_Y, .id = "type")

# Best-fitting parameter set of the joint model. The two filters must stay
# separate so that max(loglik) is taken within the model only.
params_opt_VY <- res %>%
  filter(type == "V & Y") %>%
  filter(loglik == max(loglik)) %>%
  select(-loglik.se, -loglik, -type)

# Best-fitting parameter set of the viral-load-only model.
params_opt_V <- res %>%
  filter(type == "V") %>%
  filter(loglik == max(loglik)) %>%
  select(-loglik.se, -loglik, -type)

# NOTE(review): for the case-count-only model the SECOND-ranked fit is
# taken, unlike the two models above -- confirm this is deliberate.
params_opt_Y <- res %>%
  filter(type == "Y") %>%
  filter(rank(-loglik) == 2) %>%
  select(-loglik.se, -loglik, -type)

# Simulate 1000 trajectories from each model at its selected parameter set.
# NOTE(review): SEIR_Y, SEIR_V and SEIR_VY are not defined in the visible part
# of this file (the pomp objects built above are named covidSEIRsR_*) --
# confirm where they come from.
params_sim = params_opt_Y
y_Y = SEIR_Y %>%
  simulate(params=params_sim, nsim=1000, format="data.frame")

params_sim = params_opt_V
y_V = SEIR_V %>%
  simulate(params=params_sim, nsim=1000, format="data.frame")

params_sim = params_opt_VY
y_VY = SEIR_VY %>%
  simulate(params=params_sim, nsim=1000, format="data.frame")

# Stack the three simulation sets, labelled by model.
y <- bind_rows("VY" = y_VY, "V" = y_V, "Y" = y_Y, .id = "type")

# Per-day, per-model mean and standard deviation across simulations. Both
# tables use the same grouping, so their rows line up one to one; the band
# computation below relies on that alignment.
y %>% group_by(day, type) %>% summarize_at(vars(S:R, V, Y), mean) -> y_avg
y %>% group_by(day, type) %>% summarize_at(vars(V, Y), sd) -> y_sd

# Observed series over the calibration window; the case counts are divided by
# 0.14, the rho_Y value used in params_guess.
observed = data_cal %>%
  mutate(actual.cases = Y / 0.14) %>%
  select(day, V = V, Y = actual.cases)


# Band of half a standard deviation on either side of the mean.
y_avg$V_low <- y_avg$V - 0.5*y_sd$V
y_avg$V_up  <- y_avg$V + 0.5*y_sd$V
y_avg$Y_low <- y_avg$Y - 0.5*y_sd$Y
y_avg$Y_up  <- y_avg$Y + 0.5*y_sd$Y

# The data get a zero-width band so they can share the plotting code.
observed$V_low <- observed$V
observed$V_up  <- observed$V
observed$Y_low <- observed$Y
observed$Y_up  <- observed$Y
observed$type  <- "data"

y_avg <- y_avg %>%
  select(day, type, V, V_low, V_up, Y, Y_low, Y_up)

# Append the observed series as a fourth "type".
y_avg <- rbind(y_avg, observed)

# Express viral load in thousands, matching the "(x 1e3)" axis label used in
# the viral-load plot.
y_avg$V_low <- y_avg$V_low/1000
y_avg$V <- y_avg$V/1000
y_avg$V_up <- y_avg$V_up/1000
# Palette shared by the line and ribbon layers of the comparison plots.
cbPalette <- c(
  "#000000", "#CC0000", "#56B4E9", "#009E73",
  "#0072B2", "#D55E00", "#E69F00"
)

# Viral load by day: one line and one shaded band per type.
y_avg %>%
  group_by(type) %>%
  ggplot(aes(x = day, y = V)) +
  geom_line(aes(color = type), size = 1.5) +
  geom_ribbon(aes(ymin = V_low, ymax = V_up, fill = type), alpha = 0.3) +
  scale_fill_manual(values = cbPalette) +
  scale_color_manual(values = cbPalette) +
  labs(y = TeX("Viral Load ($\\times 1e3$)")) +
  theme_bw() +
  theme(
    legend.title = element_blank(),
    legend.text = element_text(size = 12, family = "Times New Roman"),
    axis.text = element_text(size = 21, family = "Times New Roman"),
    axis.title.y = element_text(size = 22, family = "Times New Roman"),
    axis.title.x = element_text(size = 22, family = "Times New Roman"),
    plot.title = element_text(
      size = 16, face = "bold", family = "Times New Roman", hjust = 0.5
    )
  )

# Reported cases by day: one line and one shaded band per type, coloured
# with cbPalette.
y_avg %>%
  group_by(type) %>%
  ggplot(aes(x = day, y = Y)) +
  geom_line(aes(color = type), size = 1.5) +
  geom_ribbon(aes(ymin = Y_low, ymax = Y_up, fill = type), alpha = 0.3) +
  scale_fill_manual(values = cbPalette) +
  scale_color_manual(values = cbPalette) +
  labs(y = "Reported Cases") +
  theme_bw() +
  theme(
    legend.title = element_blank(),
    legend.text = element_text(size = 12, family = "Times New Roman"),
    axis.text = element_text(size = 21, family = "Times New Roman"),
    axis.title.y = element_text(size = 22, family = "Times New Roman"),
    axis.title.x = element_text(size = 22, family = "Times New Roman"),
    plot.title = element_text(
      size = 16, face = "bold", family = "Times New Roman", hjust = 0.5
    )
  )

Forecasts

# Simulate a model over its calibration period and project it h steps ahead,
# for a set of parameter vectors drawn from a box.
#
# h:      number of time steps to project beyond the last observation time.
# model:  pomp object to filter and simulate.
# ranges: matrix with one row per parameter and columns "min" and "max".
#
# Returns the row-bound simulations for all parameter vectors, with columns
# .id, day, Y, V, period ("calibration" or "projection") and loglik (the
# particle-filter log-likelihood of the parameter vector that produced the
# row).
#
# Side effects: registerDoParallel() is called without arguments, replacing
# the backend registered at the top of the file, and the doRNG seed is reset
# on every call.
forecast_pomp <- function(h, model, ranges){
  
  # 20 parameter vectors from a Sobol' sequence inside the box.
  sobol_design(
    lower=ranges[,"min"],
    upper=ranges[,"max"],
    nseq=20
  ) -> params
  
  library(foreach)
  library(doParallel)
  library(iterators)
  library(doRNG)
  
  registerDoParallel()
  registerDoRNG(887851050L)
  
  ## ----forecasts2b----------------------------------------------------------
  # One task per parameter vector; results may be combined in any order.
  foreach(p=iter(params,by="row"),
          .inorder=FALSE,
          .combine=bind_rows
  ) %dopar% {
    
    library(pomp)
    
    ## ----forecasts2c----------------------------------------------------------
    M1 <- model
    
    # Particle filter with 1000 particles, keeping the particle states.
    M1 %>% pfilter(params=p, Np=1000, save.states=TRUE) -> pf
    
    ## ----forecasts2d----------------------------------------------------------
    # Turn the particle states at the last time point into initial fractions:
    # x has rows S_0, E_0, I_0, R_0 and one column per particle.
    pf %>%
      saved.states() %>% ## latent state for each particle
      tail(1) %>%        ## last timepoint only
      melt() %>%         ## reshape and rename the state variables
      spread(variable,value) %>%
      group_by(rep) %>%
      summarize(S_0=S/(S+E+I+R), E_0=E/(S+E+I+R), I_0=I/(S+E+I+R), R_0=R/(S+E+I+R)) %>%
      gather(variable,value,-rep) %>%
      spread(rep,value) %>%
      column_to_rownames("variable") %>%
      as.matrix() -> x
    
    ## ----forecasts2e1----------------------------------------------------------
    # Parameter matrix with one copy of p per particle.
    pp <- parmat(unlist(p),ncol(x))
    
    ## ----forecasts2e2----------------------------------------------------------
    # Calibration period: simulate from the original initial conditions.
    M1 %>%
      simulate(params=pp,format="data.frame") %>%
      select(.id,day,Y, V) %>%
      mutate(
        period="calibration",
        loglik=logLik(pf)
      ) -> calib
    
    ## ----forecasts2f----------------------------------------------------------
    # Projection model: starts at the last observation time and covers the
    # following h time steps.
    M2 <- M1
    time(M2) <- max(time(M1))+seq_len(h)
    timezero(M2) <- max(time(M1))
    
    ## ----forecasts2g----------------------------------------------------------
    # Start each projection from the filtered state fractions.
    # NOTE(review): rinit scales these fractions by eta*N, so the projection
    # restarts with a total population of eta*N rather than the filtered
    # S+E+I+R, which may differ because of births -- confirm this is intended.
    pp[rownames(x),] <- x
    
    M2 %>%
      simulate(params=pp,format="data.frame") %>%
      select(.id,day,Y,V) %>%
      mutate(
        period="projection",
        loglik=logLik(pf)
      ) -> proj
    
    ## ----forecasts2h----------------------------------------------------------
    bind_rows(calib,proj) -> sims
    return(sims)
  }
}


# Summarise forecast simulations into likelihood-weighted quantiles of Y.
#
# sims: data frame with columns .id, day, Y, period and loglik (as produced
#       by forecast_pomp()).
#
# Returns a tibble with one row per (day, period) holding the weighted 2.5%,
# 50% and 97.5% quantiles of Y in columns lower, median and upper, plus
# `date`, a copy of `day`.
simq_f <- function(sims){
  # Weight of each simulation, taken relative to the mean log-likelihood so
  # that exp() stays in a representable range.
  sims <- sims %>%
    mutate(weight=exp(loglik-mean(loglik))) %>%
    arrange(day,.id)

  ## ----forecasts2k----------------------------------------------------------
  # Effective sample size of the weights on the last day.
  # NOTE(review): this value is computed but discarded; assign or print it if
  # it is meant to be inspected.
  sims %>%
    filter(day==max(day)) %>%
    summarize(ess=sum(weight)^2/sum(weight^2))

  ## ----forecasts2l----------------------------------------------------------
  # stats::quantile() has no `weights` argument and silently ignores it, which
  # made these quantiles unweighted. pomp::wquant() applies the weights; it
  # requires a pomp release that exports wquant().
  simq <- sims %>%
    group_by(day,period) %>%
    summarize(
      p=c(0.025,0.5,0.975),
      q=pomp::wquant(Y,weights=weight,probs=p),
      label=c("lower","median","upper")
    ) %>%
    select(-p) %>%
    spread(label,q) %>%
    ungroup() %>%
    mutate(date=day)

  return(simq)
}

##############################

# Range (min, max) of every parameter over the fits whose log-likelihood lies
# within half the 95% chi-squared (df = 1) quantile of the best fit.
#
# Args:
#   res: data frame of fit results containing `loglik`, `loglik.se` and one
#     column per parameter.
#
# Returns:
#   A matrix with one row per parameter (row names are the parameter names,
#   in sorted order) and the columns `min` and `max`.
loglik_window_ranges <- function(res) {
  near_best <- res %>%
    select(-loglik.se) %>%
    filter(loglik > max(loglik) - qchisq(p = 0.95, df = 1) / 2)

  near_best %>%
    select(-loglik) %>%
    gather(parameters, value) %>%
    group_by(parameters) %>%
    summarize(min = min(value), max = max(value)) %>%
    column_to_rownames("parameters") %>%
    as.matrix()
}

ranges_VY <- loglik_window_ranges(res_VY)
ranges_V <- loglik_window_ranges(res_V)
ranges_Y <- loglik_window_ranges(res_Y)

# Run the forecast for each fitted model. The first argument (30) is presumably
# the forecast horizon -- confirm against forecast_pomp(). The VY, V, Y order
# is kept as in the original script.
sims_VY <- forecast_pomp(30, SEIR_VY, ranges_VY)
sims_V <- forecast_pomp(30, SEIR_V, ranges_V)
sims_Y <- forecast_pomp(30, SEIR_Y, ranges_Y)

# Collapse each set of simulations into lower/median/upper bands.
simq_VY <- simq_f(sims_VY)
simq_V <- simq_f(sims_V)
simq_Y <- simq_f(sims_Y)

# Stack the three summaries, tag each row with its model in `type`, and keep
# only the projection period (the `date` column is dropped).
simq <- list(VY = simq_VY, V = simq_V, Y = simq_Y) %>%
  bind_rows(.id = "type") %>%
  filter(period == "projection") %>%
  select(-date)

#-----------------------| ARIMA |------------------------#
library(forecast)
## Registered S3 method overwritten by 'quantmod':
##   method            from
##   as.zoo.data.frame zoo
## 
## Attaching package: 'forecast'
## The following object is masked from 'package:pomp':
## 
##     forecast
# ARIMA benchmark fitted to the first 70 observations.
# NOTE(review): 0.14 looks like a reporting fraction used to scale reported
# counts up to actual cases -- confirm.
cases <- data_c[1:70, ]$Y / 0.14
viral <- data_c[1:70, ]$V

arima_case <- forecast::Arima(cases, order = c(4, 1, 0))
arima_viral <- forecast::Arima(viral, order = c(1, 1, 1))

# pomp also exports a `forecast`, so the namespace is spelled out here.
case_f_arima <- forecast::forecast(arima_case, h = 30)

# Column 2 of `lower`/`upper` is the wider of the two default prediction
# intervals (level = c(80, 95)). The `ts` outputs are converted to plain
# numerics so they combine cleanly with the pomp summaries. The lower bound is
# floored at zero because case counts cannot be negative; this replaces a
# hard-coded `lower[3:30] = 0` that zeroed those entries whatever their sign
# and only made sense for a 30-step horizon.
arima_case_f <- tibble(
  type = "arima",
  day = 71:100,
  period = "projection",
  lower = pmax(as.numeric(case_f_arima$lower[, 2]), 0),
  median = as.numeric(case_f_arima$mean),
  upper = as.numeric(case_f_arima$upper[, 2])
)

simq <- bind_rows(simq, arima_case_f)


# Observed series for the first 100 rows of data_c, shaped like the forecast
# summaries: the scaled case count (Y / 0.14) is used for median, lower and
# upper alike, so the ribbon has zero width, and the rows are tagged with
# type "data".
observed <- data_c[1:100, ] %>%
  transmute(
    day,
    median = Y / 0.14,
    lower = median,
    upper = median,
    period = "projection",
    type = "data"
  )

simq <- rbind(simq, observed)
#--------------------------------------------------------#

# Manual palette shared by the colour (line) and fill (ribbon) scales.
cbPalette <- c("#F0E442", "#000000", "#CC0000", "#56B4E9", "#009E73", "#0072B2", "#D55E00", "#E69F00")

# Single font family for all text in both panels.
forecast_plot_font <- "Times New Roman"

# Main panel: median line and lower/upper ribbon for every `type` in simq.
# The y label is set once; an earlier labs(y = "cases") was immediately
# overridden by "Case Count" and has been removed.
# `size` is kept for the line width to match the ggplot2 version the file
# loads (3.3.5).
p <- ggplot(simq, aes(x = day)) +
  geom_ribbon(aes(ymin = lower, ymax = upper, fill = type), alpha = 0.3, color = NA) +
  geom_line(aes(y = median, color = type), size = 1.5) +
  scale_color_manual(values = cbPalette) +
  scale_fill_manual(values = cbPalette) +
  labs(y = "Case Count") +
  coord_cartesian(ylim = c(0, 500), xlim = c(0, 100)) +
  theme_bw() +
  theme(
    plot.title = element_text(size = 16, face = "bold", family = forecast_plot_font, hjust = 0.5),
    axis.title.x = element_text(size = 22, family = forecast_plot_font),
    axis.title.y = element_text(size = 22, family = forecast_plot_font),
    axis.text = element_text(size = 21, family = forecast_plot_font),
    legend.text = element_text(size = 12, family = forecast_plot_font),
    legend.title = element_blank()
  )

# Inset: the same layers zoomed to days 65-100 and 0-50 on the y axis, with no
# legend and no y-axis title.
p.zoom <- ggplot(simq, aes(x = day)) +
  geom_ribbon(aes(ymin = lower, ymax = upper, fill = type), alpha = 0.3, color = NA) +
  geom_line(aes(y = median, color = type), size = 1.5) +
  scale_color_manual(values = cbPalette) +
  scale_fill_manual(values = cbPalette) +
  coord_cartesian(xlim = c(65, 100), ylim = c(0, 50)) +
  theme_bw() +
  theme(
    axis.title.y = element_blank(),
    axis.title.x = element_text(family = forecast_plot_font),
    legend.position = "none",
    axis.text = element_text(family = forecast_plot_font)
  )

# Draw the main panel with the zoomed panel embedded in its upper-left area
# (x 0-65, y 200-500 in data coordinates).
p + annotation_custom(
  ggplotGrob(p.zoom),
  xmin = 0, xmax = 65,
  ymin = 200, ymax = 500
)