# Load some packages we are going to need
library(marginaleffects) # For predictions
library(splines) # For splines
library(ggplot2) # For plotting
library(dplyr) # For some renaming

# Use the classic ggplot2 theme as the default for every plot in this script
theme_set(new = theme_classic())


# Read the data
# Start questionnaire of a diary study we conducted,
# see Rohrer et al. (2024), https://doi.org/10.1525/collabra.121238
# Data on the OSF: https://osf.io/gj6x5/

dat <- read.csv("start.csv")

# Build a single row filter that combines the sample restrictions with the
# complete-case requirement on the variables used below.
model_vars <- c("age", "sex", "partner_any", "IMP_friends_Start")
is_complete <- complete.cases(dat[, model_vars])
# Restrict to the range that we can model well:
#   sex != 3  excludes people who reported a gender distinct from male/female
#             (5 people according to the original note)
#   age < 60  keeps only people younger than 60
in_range <- dat$sex != 3 & dat$age < 60
# Rows with a missing sex or age are already FALSE in is_complete, so the
# combined mask never contains NA.
dat <- dat[is_complete & in_range, ]

# Quick look at the distributions of the key variables
table(dat$age)
table(dat$sex)
table(dat$partner_any)

# Label the sex codes (1 = "female", 2 = "male") and give the columns
# descriptive names
dat$sex <- factor(dat$sex, levels = c(1, 2), labels = c("female", "male"))
dat <- rename(
  dat,
  friendship_importance = IMP_friends_Start,
  gender = sex,
  partner = partner_any
)


# Two extreme ways to model age:
# - categorical: a separate level for every distinct age value
# - linear: a single slope for age
age_cat <- lm(friendship_importance ~ as.factor(age), data = dat)
age_lin <- lm(friendship_importance ~ age, data = dat)

# Average predictions by age; the estimate and conf.low/conf.high columns
# are used for plotting below
pred_cat <- age_cat |> avg_predictions(by = "age")
pred_lin <- age_lin |> avg_predictions(by = "age")

# Expand categorical predictions to plot them as step function including ribbon.
# Every age a in pred_cat gets `steps_per_age` x-positions spread over
# [a, a + 0.999999], all carrying that age's estimate and interval bounds.
# The offsets are added to each age's own value (instead of spacing one
# sequence evenly from min to max), so every step stays aligned with its
# estimate even if some ages between min and max are missing from pred_cat
# or the rows are not sorted by age.
steps_per_age <- 100
step_offsets <- seq(from = 0, to = 0.999999, length.out = steps_per_age)
pred_cat_expanded <- data.frame(
  age = rep(pred_cat$age, each = steps_per_age) +
    rep(step_offsets, times = nrow(pred_cat)),
  estimate = rep(pred_cat$estimate, each = steps_per_age),
  conf.low = rep(pred_cat$conf.low, each = steps_per_age),
  conf.high = rep(pred_cat$conf.high, each = steps_per_age)
)

# Colors for the two extreme models
col_cat <- "#E69F00"
col_lin <- "#CC79A7"

# Two extremes: categorical age (points at the middle of each year, dashed
# connecting line, step-shaped ribbon) versus linear age (line plus ribbon)
plot_extremes <- ggplot() +
  # categorical
  geom_point(data = pred_cat, aes(x = (age + 0.5), y = estimate), color = col_cat, shape = 4) +
  # `linewidth` replaces `size` for lines (the latter is deprecated since ggplot2 3.4.0)
  geom_line(data = pred_cat, aes(x = (age + 0.5), y = estimate), color = col_cat, linewidth = .2, linetype = "dashed") +
  geom_ribbon(data = pred_cat_expanded, aes(x = age, ymin = conf.low, ymax = conf.high), alpha = .2, fill = col_cat) +
  # linear
  geom_line(data = pred_lin, aes(x = age, y = estimate), color = col_lin) +
  geom_ribbon(data = pred_lin, aes(x = age, ymin = conf.low, ymax = conf.high), alpha = .2, fill = col_lin) +
  # zoom the y-axis without dropping data outside the limits
  coord_cartesian(ylim = c(2.5, 5)) +
  xlab("Age") +
  ylab("Outcome (95% CI)")
plot_extremes
# Pass the plot explicitly instead of relying on last_plot()
ggsave("age_extremes.png", plot = plot_extremes, width = 4, height = 3)


# Smooth solutions: three models that sit between the two extremes

# Quintile boundaries of age (0%, 20%, ..., 100%).
# NOTE(review): `cutoffs` is not used by the bin model below, which relies on
# the hard-coded thresholds 22, 29, 39 and 50 -- presumably read off these
# quantiles; confirm they still match if the data change.
cutoffs <- quantile(dat$age, probs = seq(0, 1, length.out = 6), na.rm = TRUE)

# Splines: B-spline basis for age with df = 4
age_splines <- lm(friendship_importance ~ bs(age, df = 4), data = dat)

# Polynomial: raw powers of age up to the fourth degree
age_poly <- lm(friendship_importance ~ I(age^4) + I(age^3) + I(age^2) + age, data = dat)

# Bins: one indicator per age bin; ages above 50 form the reference group
age_bin <- lm(
  friendship_importance ~
    I(age <= 22) +
    I(age > 22 & age <= 29) +
    I(age > 29 & age <= 39) +
    I(age > 39 & age <= 50),
  data = dat
)

# Average predictions by age for each smooth model
pred_splines <- age_splines |> avg_predictions(by = "age")
pred_poly <- age_poly |> avg_predictions(by = "age")
pred_bin <- age_bin |> avg_predictions(by = "age")


# Colors: the two extreme models are greyed out as a backdrop, the smooth
# model of interest is highlighted
col_neutral_1 <- "darkgrey"
col_neutral_2 <- "lightgrey"
col_smooth <- "#0072B2"

# Build the comparison plot for one smooth model.
#
# pred_smooth: predictions of the smooth model with the columns age, estimate,
#              conf.low and conf.high (as used by the layers below).
# Returns the ggplot object. The categorical and linear predictions
# (pred_cat, pred_lin) and the colors are taken from the script's global
# environment.
build_smooth_plot <- function(pred_smooth) {
  ggplot() +
    # categorical
    geom_point(data = pred_cat, aes(x = (age + 0.5), y = estimate), color = col_neutral_1, shape = 4) +
    # `linewidth` replaces `size` for lines (the latter is deprecated since ggplot2 3.4.0)
    geom_line(data = pred_cat, aes(x = (age + 0.5), y = estimate), color = col_neutral_2, linewidth = .2, linetype = "dashed") +
    # linear
    geom_line(data = pred_lin, aes(x = age, y = estimate), color = col_neutral_1, linewidth = .2) +
    # zoom the y-axis without dropping data outside the limits
    coord_cartesian(ylim = c(2.5, 5)) +
    xlab("Age") +
    ylab("Outcome (95% CI)") +
    # smoothed
    geom_line(data = pred_smooth, aes(x = age, y = estimate), color = col_smooth) +
    geom_ribbon(data = pred_smooth, aes(x = age, ymin = conf.low, ymax = conf.high), fill = col_smooth, alpha = .2)
}

# Splines
plot_splines <- build_smooth_plot(pred_splines)
plot_splines
# Pass the plot explicitly instead of relying on last_plot()
ggsave("age_splines.png", plot = plot_splines, width = 4, height = 3)

# Polynomial
plot_poly <- build_smooth_plot(pred_poly)
plot_poly
ggsave("age_poly.png", plot = plot_poly, width = 4, height = 3)

# Bins
plot_bin <- build_smooth_plot(pred_bin)
plot_bin
ggsave("age_bin.png", plot = plot_bin, width = 4, height = 3)
