Bayesian statistics is a sub-field of statistics based on the Bayesian interpretation of probability, in which probability expresses a degree of belief in an event. It differs in a few respects from the more common frequentist approach. In brief:
- The frequentist approach makes inferences about the underlying truths of the experiment using only the data from the current experiment.
- The Bayesian approach is different: past knowledge is also taken into account.
So the main difference between the frequentist and the Bayesian approach is that the latter specifies a prior probability.
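In symbols (a standard statement, not specific to the data analysed below), the posterior combines this prior with the likelihood of the observed data:
\[ p(\theta \mid \text{data}) \; = \; \frac{p(\text{data} \mid \theta) \, p(\theta)}{p(\text{data})} \; \propto \; p(\text{data} \mid \theta) \, p(\theta) \]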
In this markdown we will use the Bayesian approach to build regression models for a dataset containing two groups. As we will see, thanks to Bayesian statistics, and in particular to Markov chain Monte Carlo (MCMC) sampling, we will be able to see how confident we are in the modelling of our data.
Before jumping into the analysis we create two utility functions. create_prediction builds the predictions from a fitted Bayesian workflow, while plot_bayesian_model is a utility function to visualize the model together with its uncertainty bands.
create_prediction <- function(tbl){
  tbl %>%
    predict(
      new_data = sim_data,
      type = "raw",
      opts = list(
        allow_new_levels = TRUE,
        probs = c(0.025, 0.17, 0.83, 0.975)
      )
    ) %>%
    as_tibble() %>%
    janitor::clean_names() %>%
    select(-est_error) %>%
    set_names(c(".pred", ".pred_lower_95", ".pred_lower_66",
                ".pred_upper_66", ".pred_upper_95")) %>%
    bind_cols(sim_data)
}

plot_bayesian_model <- function(tbl){
  ggplot(tbl, aes(x, y, group = type)) +
    geom_ribbon(aes(ymax = .pred_upper_95, ymin = .pred_lower_95),
                alpha = 0.5, fill = "#0096C7") +
    geom_ribbon(aes(ymax = .pred_upper_66, ymin = .pred_lower_66),
                alpha = 0.5, fill = "#0077B6") +
    geom_point(aes(col = type), alpha = 0.5) +
    geom_line(aes(y = .pred), size = 1, col = "#CAF0F8") +
    scale_color_tq() +
    scale_y_continuous(labels = scales::comma) +
    theme(legend.position = "none")
}
First of all we read the CSV file. These are simulated data with no particular meaning, so we will not spend too much time trying to interpret them. The focus of this markdown is to look at the Bayesian models and to see how increasing the complexity of the model leads to better results.
sim_data <- read_csv("data/bayesian_data_2.csv")
sim_data
## # A tibble: 385 × 3
## x y type
## <dbl> <dbl> <chr>
## 1 0.266 4.10 sim_1
## 2 0.697 0.111 sim_1
## 3 0.735 -3.84 sim_1
## 4 0.770 1.94 sim_1
## 5 2.30 -7.83 sim_1
## 6 2.63 -18.2 sim_1
## 7 -0.691 0.810 sim_1
## 8 0.706 2.26 sim_1
## 9 0.964 3.22 sim_1
## 10 2.49 -18.1 sim_1
## # … with 375 more rows
As we can see, our data are clustered in two groups. For this reason we can already say that the performance of a single global model will not be good; nevertheless, to appreciate the improvements, we will begin with this very basic model. As for the other columns, we have an independent variable x and a dependent variable y.
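A quick way to confirm the grouping (a one-line check, assuming dplyr is loaded with the rest of the tidyverse) is to count the observations per level of type:

# how many observations fall in each simulated group
sim_data %>% count(type)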
p1 <- sim_data %>%
  ggplot(aes(x, y, col = type)) +
  geom_point(alpha = 0.5) +
  geom_smooth(method = "gam", show.legend = FALSE) +
  scale_color_tq() +
  labs(x = NULL, y = NULL) +
  guides(colour = guide_legend(override.aes = list(size = 10))) +
  theme(text = element_text(size = 10))

p2 <- sim_data %>%
  ggplot(aes(x, y, col = type)) +
  geom_point(alpha = 0.5) +
  geom_smooth(method = "gam", show.legend = FALSE) +
  scale_color_tq() +
  labs(x = NULL, y = NULL) +
  facet_wrap(vars(type), ncol = 1, scales = "free") +
  guides(colour = guide_legend(override.aes = list(size = 10))) +
  theme(text = element_text(size = 10))

p1 + p2 + plot_layout(guides = 'collect', ncol = 2)
The first model that we are going to create is a Bayesian version of a classical linear regression:
- LINEAR (BAD for a non-linear problem)
- NO HIERARCHY
Since we are using a linear regression, we model our data as
\[y_{i} = \alpha \; + \; \beta x_{i} \; + \; \epsilon_{i} \]
but we fit it with MCMC rather than with ordinary least squares (OLS).
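The chunk that defines this first model is not shown above; here is a minimal sketch of what it could look like, mirroring the pattern used for the later models (the object names model_spec_bayesian_1, recipe_spec_bayesian_1 and workflow_bayesian_1 are assumptions, not taken from the original code):

# hypothetical specification of the plain linear Bayesian model
model_spec_bayesian_1 <- bayesian(
  mode = "regression",
  family = gaussian(),
  engine = "brms",
  formula.override = bayesian_formula(y ~ x)
)

recipe_spec_bayesian_1 <- recipe(y ~ x, sim_data)

workflow_bayesian_1 <-
  workflow() %>%
  add_model(model_spec_bayesian_1, formula = y ~ x) %>%
  add_recipe(recipe_spec_bayesian_1)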
#workflow_bayesian_fit_1 <-
# fit(workflow_bayesian_1, data = sim_data)
#workflow_bayesian_fit_1$fit$fit$fit %>% plot()
#write_rds(workflow_bayesian_fit_1, "models/workflow_bayesian_fit_1.rds")
workflow_bayesian_fit_1 <- read_rds("models/workflow_bayesian_fit_1.rds")
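Before predicting, it can be worth glancing at the posterior of the underlying brmsfit object (an optional check, reusing the same $fit$fit$fit access path shown in the commented chunk above):

# print posterior estimates and convergence diagnostics (Rhat, ESS)
workflow_bayesian_fit_1$fit$fit$fit %>% summary()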
predictions_bayes_tbl_1 <-
  workflow_bayesian_fit_1 %>%
  create_prediction()

plot_model_1 <- predictions_bayes_tbl_1 %>%
  plot_bayesian_model() +
  labs(x = NULL, y = NULL, title = "First Bayesian model",
       subtitle = "Global Linear, NO hierarchy (BAD)")
The second model is NON-LINEAR:
- NON-LINEAR (GOOD for a non-linear problem)
- NO HIERARCHY (BAD)
The linear term is replaced by a smooth, s(x), but the model is still a single global fit.
model_spec_bayesian_2 <- bayesian(
  mode = "regression",
  family = gaussian(),
  engine = "brms",
  formula.override = bayesian_formula(y ~ s(x))
)

recipe_spec_bayesian_2 <- recipe(y ~ x, sim_data)

workflow_bayesian_2 <-
  workflow() %>%
  add_model(model_spec_bayesian_2,
            # unfortunately tidymodels has a bug, so we have to enter a formula;
            # the formula will not be considered, so it does not matter what we write
            formula = y ~ x) %>%
  add_recipe(recipe_spec_bayesian_2)
#workflow_bayesian_fit_2 <-
# fit(workflow_bayesian_2, data = sim_data)
#workflow_bayesian_fit_2$fit$fit$fit %>% plot()
#write_rds(workflow_bayesian_fit_2, "models/workflow_bayesian_fit_2")
workflow_bayesian_fit_2 <- read_rds("models/workflow_bayesian_fit_2")

predictions_bayes_tbl_2 <-
  workflow_bayesian_fit_2 %>%
  create_prediction()

plot_model_2 <- predictions_bayes_tbl_2 %>%
  plot_bayesian_model() +
  labs(x = NULL, y = NULL, title = "Second Bayesian model",
       subtitle = "Global NON-Linear (Better but still BAD)")
The third model is NON-LINEAR AND HIERARCHICAL:
- NON-LINEAR (GOOD for a non-linear problem)
- HIERARCHICAL (GOOD)
- GLOBAL VARIANCE (BAD)
The smooth now varies by group, s(x, by = type), and a group-level intercept (1 | type) is added, but the residual variance is still shared across groups.
model_spec_bayesian_3 <- bayesian(
  mode = "regression",
  family = gaussian(),
  engine = "brms",
  formula.override = bayesian_formula(
    y ~ s(x, by = type) + (1 | type)
  )
)

recipe_spec_bayesian_3 <- recipe(y ~ x + type, sim_data)

workflow_bayesian_3 <-
  workflow() %>%
  add_model(model_spec_bayesian_3,
            # unfortunately tidymodels has a bug, so we have to enter a formula;
            # the formula will not be considered, so it does not matter what we write
            formula = y ~ x) %>%
  add_recipe(recipe_spec_bayesian_3)
#workflow_bayesian_fit_3 <-
# fit(workflow_bayesian_3, data = sim_data)
#workflow_bayesian_fit_3$fit$fit$fit %>% plot()
#write_rds(workflow_bayesian_fit_3, "models/workflow_bayesian_fit_3")
workflow_bayesian_fit_3 <- read_rds("models/workflow_bayesian_fit_3")

predictions_bayes_tbl_3 <-
  workflow_bayesian_fit_3 %>%
  create_prediction()

plot_model_3 <- predictions_bayes_tbl_3 %>%
  plot_bayesian_model() +
  labs(x = NULL, y = NULL, title = "Third Bayesian model",
       subtitle =
         "Groupwise Bayesian GAM (Hierarchical Good, global variance BAD)")
The fourth model is NON-LINEAR AND HIERARCHICAL WITH GROUP VARIANCE:
- NON-LINEAR (GOOD for a non-linear problem)
- HIERARCHICAL (GOOD)
- GROUP VARIANCE (GOOD)
In addition to the mean, sigma (the residual standard deviation) is now also modelled as a function of x and type, so each group gets its own variance structure.
model_spec_bayesian_4 <- bayesian(
  mode = "regression",
  family = gaussian(),
  engine = "brms",
  formula.override = bayesian_formula(
    bf(
      y ~ s(x, by = type) + (1 | type),
      sigma ~ s(x, by = type) + (1 | type)
    )
  )
)

recipe_spec_bayesian_4 <- recipe(y ~ x + type, sim_data)

workflow_bayesian_4 <-
  workflow() %>%
  add_model(model_spec_bayesian_4,
            # unfortunately tidymodels has a bug, so we have to enter a formula;
            # the formula will not be considered, so it does not matter what we write
            formula = y ~ x) %>%
  add_recipe(recipe_spec_bayesian_4)
#workflow_bayesian_fit_4 <-
# fit(workflow_bayesian_4, data = sim_data)
#workflow_bayesian_fit_4$fit$fit$fit %>% plot()
#write_rds(workflow_bayesian_fit_4, "models/workflow_bayesian_fit_4")
workflow_bayesian_fit_4 <- read_rds("models/workflow_bayesian_fit_4")

predictions_bayes_tbl_4 <-
  workflow_bayesian_fit_4 %>%
  create_prediction()

plot_model_4 <- predictions_bayes_tbl_4 %>%
  plot_bayesian_model() +
  labs(x = NULL, y = NULL, title = "Fourth Bayesian model",
       subtitle =
         "Groupwise Bayesian GAM (Hierarchical Good, Group variance Good)")
Here all four models are reported together, in order to see the differences and how the performance improves going from the simplest model to the most complex.
plot_model_1 + plot_model_2 + plot_model_3 + plot_model_4
## Warning: Removed 6 rows containing missing values (geom_point).
## Removed 6 rows containing missing values (geom_point).
## Removed 6 rows containing missing values (geom_point).
## Removed 6 rows containing missing values (geom_point).
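To go beyond the visual comparison, one option (not used above) is approximate leave-one-out cross-validation via the loo functionality re-exported by brms. A hedged sketch, assuming the four fitted workflows from above are still in memory; the list names are arbitrary labels:

# extract the underlying brmsfit objects and compare them with PSIS-LOO;
# loo_compare() ranks the models by expected log predictive density (elpd)
library(brms)

fits <- list(
  linear         = workflow_bayesian_fit_1$fit$fit$fit,
  gam            = workflow_bayesian_fit_2$fit$fit$fit,
  gam_hier       = workflow_bayesian_fit_3$fit$fit$fit,
  gam_hier_sigma = workflow_bayesian_fit_4$fit$fit$fit
)

loos <- lapply(fits, loo)
loo_compare(loos)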